diff --git a/x/io/_demo/asyncdemo/async.go b/x/io/_demo/asyncdemo/async.go index 2fdbfa85..ff75e943 100644 --- a/x/io/_demo/asyncdemo/async.go +++ b/x/io/_demo/asyncdemo/async.go @@ -169,10 +169,11 @@ func GetUserCompiled(name string) *io.PromiseImpl[User] { state2.Exec = P.Exec state2.Parent = P state2.Call() - log.Printf("TextCompiled state2: %v\n", state2) + log.Printf("TextCompiled state2: %+v\n", state2) return case 2: P.Next = -1 + log.Printf("TextCompiled state2: %+v\n", state2) body, err := state2.Value, state2.Err if err != nil { resolve(User{}, err) @@ -380,10 +381,10 @@ func DemoCompiled() *io.PromiseImpl[io.Void] { case 1: P.Next = 2 user, err := state1.Value, state1.Err - log.Printf("user: %v, err: %v\n", user, err) + log.Printf("user: %+v, err: %v\n", user, err) state2 = io.Race[User](GetUserCompiled("2"), GetUserCompiled("3"), GetUserCompiled("4")) - log.Printf("state2: %v\n", state2) + log.Printf("state2: %+v\n", state2) state2.Exec = P.Exec state2.Parent = P state2.Call() @@ -393,10 +394,10 @@ func DemoCompiled() *io.PromiseImpl[io.Void] { P.Next = 3 user, err := state2.Value, state2.Err - log.Printf("race user: %v, err: %v\n", user, err) + log.Printf("race user: %+v, err: %v\n", user, err) state3 = io.All[User]([]io.AsyncCall[User]{GetUserCompiled("5"), GetUserCompiled("6"), GetUserCompiled("7")}) - log.Printf("state3: %v\n", state3) + log.Printf("state3: %+v\n", state3) state3.Exec = P.Exec state3.Parent = P state3.Call() @@ -408,7 +409,7 @@ func DemoCompiled() *io.PromiseImpl[io.Void] { log.Println(users, err) state4 = io.Await3Compiled[User, float64, io.Void](GetUserCompiled("8"), GetScoreCompiled(), DoUpdateCompiled("update sth.")) - log.Printf("state4: %v\n", state4) + log.Printf("state4: %+v\n", state4) state4.Exec = P.Exec state4.Parent = P state4.Call() diff --git a/x/io/extra.go b/x/io/extra.go index cd09f66a..163b20e6 100644 --- a/x/io/extra.go +++ b/x/io/extra.go @@ -52,59 +52,98 @@ type Result[T any] struct { // llgo:link Race llgo.race func Race[OutT any](acs ...AsyncCall[OutT]) *PromiseImpl[OutT] { + if len(acs) == 0 { + panic("face: no promise") + } + ps := make([]*PromiseImpl[OutT], len(acs)) + for idx, ac := range acs { + ps[idx] = ac.(*PromiseImpl[OutT]) + } + remaining := len(acs) + returned := false P := &PromiseImpl[OutT]{} P.Debug = "Race" P.Func = func(resolve func(OutT, error)) { - P.Next = -1 - rc := make(chan Result[OutT], len(acs)) - for _, ac := range acs { - ac := ac - go func(ac AsyncCall[OutT]) { - v, err := Run[OutT](ac) - rc <- Result[OutT]{v, err} - }(ac) - } - - v := <-rc - if debugAsync { - log.Printf("io.Race done: %+v won the race\n", v) - } - resolve(v.V, v.Err) - go func() { - count := 1 - for count < len(acs) { - <-rc - count++ + switch P.Next { + case 0: + P.Next = 1 + for _, p := range ps { + p.Exec = P.Exec + p.Parent = P + p.Call() } - close(rc) - }() + return + case 1: + remaining-- + if remaining < 0 { + log.Fatalf("race: remaining < 0: %+v\n", remaining) + } + if returned { + return + } + + for _, p := range ps { + if p.Done() { + if debugAsync { + log.Printf("io.Race done: %+v won the race\n", p) + } + returned = true + resolve(p.Value, p.Err) + return + } + } + log.Fatalf("no promise done: %+v\n", ps) + return + default: + panic("unreachable") + } } return P } func All[OutT any](acs []AsyncCall[OutT]) *PromiseImpl[[]Result[OutT]] { + ps := make([]*PromiseImpl[OutT], len(acs)) + for idx, ac := range acs { + ps[idx] = ac.(*PromiseImpl[OutT]) + } + done := 0 P := &PromiseImpl[[]Result[OutT]]{} P.Debug = "All" P.Func = func(resolve func([]Result[OutT], error)) { - P.Next = -1 - wg := sync.WaitGroup{} - ret := make([]Result[OutT], len(acs)) - for idx, ac := range acs { - idx := idx - ac := ac - wg.Add(1) - go func(ac AsyncCall[OutT]) { - v, err := Run[OutT](ac) - ret[idx] = Result[OutT]{v, err} - wg.Done() - }(ac) - } + switch P.Next { + case 0: + P.Next = 1 + for _, p := range ps { + p.Exec = P.Exec + p.Parent = P + p.Call() + } + return + case 1: + done++ + if done < len(acs) { + return + } + P.Next = -1 - wg.Wait() - if debugAsync { - log.Printf("io.All done: %+v\n", ret) + for _, p := range ps { + if !p.Done() { + log.Fatalf("io.All: not done: %+v\n", p) + } + } + + ret := make([]Result[OutT], len(acs)) + for idx, p := range ps { + ret[idx] = Result[OutT]{p.Value, p.Err} + } + if debugAsync { + log.Printf("io.All done: %+v\n", ret) + } + resolve(ret, nil) + return + default: + panic("unreachable") } - resolve(ret, nil) } return P } @@ -142,35 +181,116 @@ type Await3Result[T1 any, T2 any, T3 any] struct { Err error } -// TODO(lijie): rewrite to unblock and avoid goroutine func Await3Compiled[OutT1, OutT2, OutT3 any]( ac1 AsyncCall[OutT1], ac2 AsyncCall[OutT2], ac3 AsyncCall[OutT3], timeout ...time.Duration) *PromiseImpl[Await3Result[OutT1, OutT2, OutT3]] { + p1 := ac1.(*PromiseImpl[OutT1]) + p2 := ac2.(*PromiseImpl[OutT2]) + p3 := ac3.(*PromiseImpl[OutT3]) + remaining := 3 P := &PromiseImpl[Await3Result[OutT1, OutT2, OutT3]]{} P.Debug = "Await3" P.Func = func(resolve func(Await3Result[OutT1, OutT2, OutT3], error)) { - P.Next = -1 + switch P.Next { + case 0: + P.Next = 1 + p1.Exec = P.Exec + p1.Parent = P + p1.Call() - ret := Await3Result[OutT1, OutT2, OutT3]{} - wg := sync.WaitGroup{} - wg.Add(3) + p2.Exec = P.Exec + p2.Parent = P + p2.Call() - go func() { - defer wg.Done() - ret.V1, ret.Err = Run[OutT1](ac1) - }() - go func() { - defer wg.Done() - ret.V2, ret.Err = Run[OutT2](ac2) - }() - go func() { - defer wg.Done() - ret.V3, ret.Err = Run[OutT3](ac3) - }() - wg.Wait() - if debugAsync { - log.Printf("Await3 done: %+v\n", ret) + p3.Exec = P.Exec + p3.Parent = P + p3.Call() + return + case 1: + remaining-- + if remaining > 0 { + return + } + P.Next = -1 + // TODO(lijie): return every error? + if !p1.Done() || !p2.Done() || !p3.Done() { + log.Fatalf("io.Await3: not done: %+v, %+v, %+v\n", p1, p2, p3) + } + + var err error + if p1.Err != nil { + err = p1.Err + } else if p2.Err != nil { + err = p2.Err + } else if p3.Err != nil { + err = p3.Err + } + + resolve(Await3Result[OutT1, OutT2, OutT3]{ + V1: p1.Value, V2: p2.Value, V3: p3.Value, + Err: err, + }, err) + return + default: + panic("unreachable") } + } + return P +} + +// / PAll is a parallel version of All. +func PAll[OutT any](acs ...AsyncCall[OutT]) (resolve Promise[[]Result[OutT]]) { + panic("todo: PAll") +} + +func PAllCompiled[OutT any](acs ...AsyncCall[OutT]) *PromiseImpl[[]Result[OutT]] { + P := &PromiseImpl[[]Result[OutT]]{} + P.Debug = "Parallel" + P.Func = func(resolve func([]Result[OutT], error)) { + ret := make([]Result[OutT], len(acs)) + wg := sync.WaitGroup{} + for idx, ac := range acs { + idx := idx + ac := ac + wg.Add(1) + go func(ac AsyncCall[OutT]) { + v, err := Run[OutT](ac) + ret[idx] = Result[OutT]{v, err} + wg.Done() + }(ac) + } + wg.Wait() + resolve(ret, nil) + } + return P +} + +// / PAwait3 is a parallel version of Await3. +func PAwait3[OutT1, OutT2, OutT3 any](ac1 AsyncCall[OutT1], ac2 AsyncCall[OutT2], ac3 AsyncCall[OutT3]) (resolve Promise[Await3Result[OutT1, OutT2, OutT3]]) { + panic("todo: PAwait2") +} + +func PAwait3Compiled[OutT1, OutT2, OutT3 any]( + ac1 AsyncCall[OutT1], ac2 AsyncCall[OutT2], ac3 AsyncCall[OutT3]) *PromiseImpl[Await3Result[OutT1, OutT2, OutT3]] { + P := &PromiseImpl[Await3Result[OutT1, OutT2, OutT3]]{} + P.Debug = "Parallel3" + P.Func = func(resolve func(Await3Result[OutT1, OutT2, OutT3], error)) { + ret := Await3Result[OutT1, OutT2, OutT3]{} + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + ret.V1, ret.Err = Run[OutT1](ac1) + wg.Done() + }() + go func() { + ret.V2, ret.Err = Run[OutT2](ac2) + wg.Done() + }() + go func() { + ret.V3, ret.Err = Run[OutT3](ac3) + wg.Done() + }() + wg.Wait() resolve(ret, nil) } return P