wire: make solver iterative instead of recursive (google/go-cloud#137)
Primary reason is to make it easier to allow the process to continue and collect errors. This has the side-effect of allowing larger depth graphs since the solver no longer pushes Go stack frames. Updates google/go-cloud#5
This commit is contained in:
@@ -108,46 +108,65 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
}
|
||||
|
||||
// Topological sort of the directed graph defined by the providers
|
||||
// using a depth-first search. Provider set graphs are guaranteed to
|
||||
// be acyclic.
|
||||
// using a depth-first search using a stack. Provider set graphs are
|
||||
// guaranteed to be acyclic.
|
||||
var calls []call
|
||||
var visit func(trail []ProviderInput) error
|
||||
visit = func(trail []ProviderInput) error {
|
||||
typ := trail[len(trail)-1].Type
|
||||
if index.At(typ) != nil {
|
||||
return nil
|
||||
type frame struct {
|
||||
t types.Type
|
||||
from types.Type
|
||||
}
|
||||
stk := []frame{{t: out}}
|
||||
for len(stk) > 0 {
|
||||
curr := stk[len(stk)-1]
|
||||
stk = stk[:len(stk)-1]
|
||||
if index.At(curr.t) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch pv := set.For(typ); {
|
||||
switch pv := set.For(curr.t); {
|
||||
case pv.IsNil():
|
||||
if len(trail) == 1 {
|
||||
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
||||
if curr.from == nil {
|
||||
return nil, fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(curr.t, nil))
|
||||
}
|
||||
// TODO(light): Give name of provider.
|
||||
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil))
|
||||
return nil, fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(curr.t, nil), types.TypeString(curr.from, nil))
|
||||
case pv.IsProvider():
|
||||
p := pv.Provider()
|
||||
if !types.Identical(p.Out, typ) {
|
||||
if !types.Identical(p.Out, curr.t) {
|
||||
// Interface binding. Don't create a call ourselves.
|
||||
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
|
||||
return err
|
||||
i := index.At(p.Out)
|
||||
if i == nil {
|
||||
stk = append(stk, curr, frame{t: p.Out, from: curr.t})
|
||||
continue
|
||||
}
|
||||
index.Set(typ, index.At(p.Out))
|
||||
return nil
|
||||
index.Set(curr.t, i)
|
||||
continue
|
||||
}
|
||||
for _, a := range p.Args {
|
||||
// TODO(light): This will discard grown trail arrays.
|
||||
if err := visit(append(trail, a)); err != nil {
|
||||
return err
|
||||
// Ensure that all argument types have been visited. If not, push them
|
||||
// on the stack in reverse order so that calls are added in argument
|
||||
// order.
|
||||
visitedArgs := true
|
||||
for i := len(p.Args) - 1; i >= 0; i-- {
|
||||
a := p.Args[i]
|
||||
if index.At(a.Type) == nil {
|
||||
if visitedArgs {
|
||||
// Make sure to re-visit this type after visiting all arguments.
|
||||
stk = append(stk, curr)
|
||||
visitedArgs = false
|
||||
}
|
||||
stk = append(stk, frame{t: a.Type, from: curr.t})
|
||||
}
|
||||
}
|
||||
if !visitedArgs {
|
||||
continue
|
||||
}
|
||||
args := make([]int, len(p.Args))
|
||||
ins := make([]types.Type, len(p.Args))
|
||||
for i := range p.Args {
|
||||
ins[i] = p.Args[i].Type
|
||||
args[i] = index.At(p.Args[i].Type).(int)
|
||||
}
|
||||
index.Set(typ, len(given)+len(calls))
|
||||
index.Set(curr.t, len(given)+len(calls))
|
||||
kind := funcProviderCall
|
||||
if p.IsStruct {
|
||||
kind = structProvider
|
||||
@@ -159,34 +178,32 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide
|
||||
args: args,
|
||||
fieldNames: p.Fields,
|
||||
ins: ins,
|
||||
out: typ,
|
||||
out: curr.t,
|
||||
hasCleanup: p.HasCleanup,
|
||||
hasErr: p.HasErr,
|
||||
})
|
||||
case pv.IsValue():
|
||||
v := pv.Value()
|
||||
if !types.Identical(v.Out, typ) {
|
||||
if !types.Identical(v.Out, curr.t) {
|
||||
// Interface binding. Don't create a call ourselves.
|
||||
if err := visit(append(trail, ProviderInput{Type: v.Out})); err != nil {
|
||||
return err
|
||||
i := index.At(v.Out)
|
||||
if i == nil {
|
||||
stk = append(stk, curr, frame{t: v.Out, from: curr.t})
|
||||
continue
|
||||
}
|
||||
index.Set(typ, index.At(v.Out))
|
||||
return nil
|
||||
index.Set(curr.t, i)
|
||||
continue
|
||||
}
|
||||
index.Set(typ, len(given)+len(calls))
|
||||
index.Set(curr.t, len(given)+len(calls))
|
||||
calls = append(calls, call{
|
||||
kind: valueExpr,
|
||||
out: typ,
|
||||
out: curr.t,
|
||||
valueExpr: v.expr,
|
||||
valueTypeInfo: v.info,
|
||||
})
|
||||
default:
|
||||
panic("unknown return value from ProviderSet.For")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := visit([]ProviderInput{{Type: out}}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user