wire: make solver iterative instead of recursive (google/go-cloud#137)

Primary reason is to make it easier to allow the process to continue and
collect errors. This has the side-effect of allowing larger depth graphs
since the solver no longer pushes Go stack frames.

Updates google/go-cloud#5
This commit is contained in:
Ross Light
2018-06-26 09:08:33 -07:00
parent f7658c8a13
commit 2eb9d5ea1f

View File

@@ -108,38 +108,57 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
} }
// Topological sort of the directed graph defined by the providers // Topological sort of the directed graph defined by the providers
// using a depth-first search. Provider set graphs are guaranteed to // using a depth-first search using a stack. Provider set graphs are
// be acyclic. // guaranteed to be acyclic.
var calls []call var calls []call
var visit func(trail []ProviderInput) error type frame struct {
visit = func(trail []ProviderInput) error { t types.Type
typ := trail[len(trail)-1].Type from types.Type
if index.At(typ) != nil { }
return nil stk := []frame{{t: out}}
for len(stk) > 0 {
curr := stk[len(stk)-1]
stk = stk[:len(stk)-1]
if index.At(curr.t) != nil {
continue
} }
switch pv := set.For(typ); { switch pv := set.For(curr.t); {
case pv.IsNil(): case pv.IsNil():
if len(trail) == 1 { if curr.from == nil {
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) return nil, fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(curr.t, nil))
} }
// TODO(light): Give name of provider. // TODO(light): Give name of provider.
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil)) return nil, fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(curr.t, nil), types.TypeString(curr.from, nil))
case pv.IsProvider(): case pv.IsProvider():
p := pv.Provider() p := pv.Provider()
if !types.Identical(p.Out, typ) { if !types.Identical(p.Out, curr.t) {
// Interface binding. Don't create a call ourselves. // Interface binding. Don't create a call ourselves.
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil { i := index.At(p.Out)
return err if i == nil {
stk = append(stk, curr, frame{t: p.Out, from: curr.t})
continue
} }
index.Set(typ, index.At(p.Out)) index.Set(curr.t, i)
return nil continue
} }
for _, a := range p.Args { // Ensure that all argument types have been visited. If not, push them
// TODO(light): This will discard grown trail arrays. // on the stack in reverse order so that calls are added in argument
if err := visit(append(trail, a)); err != nil { // order.
return err visitedArgs := true
for i := len(p.Args) - 1; i >= 0; i-- {
a := p.Args[i]
if index.At(a.Type) == nil {
if visitedArgs {
// Make sure to re-visit this type after visiting all arguments.
stk = append(stk, curr)
visitedArgs = false
} }
stk = append(stk, frame{t: a.Type, from: curr.t})
}
}
if !visitedArgs {
continue
} }
args := make([]int, len(p.Args)) args := make([]int, len(p.Args))
ins := make([]types.Type, len(p.Args)) ins := make([]types.Type, len(p.Args))
@@ -147,7 +166,7 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
ins[i] = p.Args[i].Type ins[i] = p.Args[i].Type
args[i] = index.At(p.Args[i].Type).(int) args[i] = index.At(p.Args[i].Type).(int)
} }
index.Set(typ, len(given)+len(calls)) index.Set(curr.t, len(given)+len(calls))
kind := funcProviderCall kind := funcProviderCall
if p.IsStruct { if p.IsStruct {
kind = structProvider kind = structProvider
@@ -159,34 +178,32 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
args: args, args: args,
fieldNames: p.Fields, fieldNames: p.Fields,
ins: ins, ins: ins,
out: typ, out: curr.t,
hasCleanup: p.HasCleanup, hasCleanup: p.HasCleanup,
hasErr: p.HasErr, hasErr: p.HasErr,
}) })
case pv.IsValue(): case pv.IsValue():
v := pv.Value() v := pv.Value()
if !types.Identical(v.Out, typ) { if !types.Identical(v.Out, curr.t) {
// Interface binding. Don't create a call ourselves. // Interface binding. Don't create a call ourselves.
if err := visit(append(trail, ProviderInput{Type: v.Out})); err != nil { i := index.At(v.Out)
return err if i == nil {
stk = append(stk, curr, frame{t: v.Out, from: curr.t})
continue
} }
index.Set(typ, index.At(v.Out)) index.Set(curr.t, i)
return nil continue
} }
index.Set(typ, len(given)+len(calls)) index.Set(curr.t, len(given)+len(calls))
calls = append(calls, call{ calls = append(calls, call{
kind: valueExpr, kind: valueExpr,
out: typ, out: curr.t,
valueExpr: v.expr, valueExpr: v.expr,
valueTypeInfo: v.info, valueTypeInfo: v.info,
}) })
default: default:
panic("unknown return value from ProviderSet.For") panic("unknown return value from ProviderSet.For")
} }
return nil
}
if err := visit([]ProviderInput{{Type: out}}); err != nil {
return nil, err
} }
return calls, nil return calls, nil
} }