diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index a8e5690..4e40417 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -108,46 +108,65 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide } // Topological sort of the directed graph defined by the providers - // using a depth-first search. Provider set graphs are guaranteed to - // be acyclic. + // using a depth-first search using a stack. Provider set graphs are + // guaranteed to be acyclic. var calls []call - var visit func(trail []ProviderInput) error - visit = func(trail []ProviderInput) error { - typ := trail[len(trail)-1].Type - if index.At(typ) != nil { - return nil + type frame struct { + t types.Type + from types.Type + } + stk := []frame{{t: out}} + for len(stk) > 0 { + curr := stk[len(stk)-1] + stk = stk[:len(stk)-1] + if index.At(curr.t) != nil { + continue } - switch pv := set.For(typ); { + switch pv := set.For(curr.t); { case pv.IsNil(): - if len(trail) == 1 { - return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) + if curr.from == nil { + return nil, fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(curr.t, nil)) } // TODO(light): Give name of provider. - return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil)) + return nil, fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(curr.t, nil), types.TypeString(curr.from, nil)) case pv.IsProvider(): p := pv.Provider() - if !types.Identical(p.Out, typ) { + if !types.Identical(p.Out, curr.t) { // Interface binding. Don't create a call ourselves. - if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil { - return err + i := index.At(p.Out) + if i == nil { + stk = append(stk, curr, frame{t: p.Out, from: curr.t}) + continue } - index.Set(typ, index.At(p.Out)) - return nil + index.Set(curr.t, i) + continue } - for _, a := range p.Args { - // TODO(light): This will discard grown trail arrays. - if err := visit(append(trail, a)); err != nil { - return err + // Ensure that all argument types have been visited. If not, push them + // on the stack in reverse order so that calls are added in argument + // order. + visitedArgs := true + for i := len(p.Args) - 1; i >= 0; i-- { + a := p.Args[i] + if index.At(a.Type) == nil { + if visitedArgs { + // Make sure to re-visit this type after visiting all arguments. + stk = append(stk, curr) + visitedArgs = false + } + stk = append(stk, frame{t: a.Type, from: curr.t}) } } + if !visitedArgs { + continue + } args := make([]int, len(p.Args)) ins := make([]types.Type, len(p.Args)) for i := range p.Args { ins[i] = p.Args[i].Type args[i] = index.At(p.Args[i].Type).(int) } - index.Set(typ, len(given)+len(calls)) + index.Set(curr.t, len(given)+len(calls)) kind := funcProviderCall if p.IsStruct { kind = structProvider @@ -159,34 +178,32 @@ func solve(fset *token.FileSet, out types.Type, given []types.Type, set *Provide args: args, fieldNames: p.Fields, ins: ins, - out: typ, + out: curr.t, hasCleanup: p.HasCleanup, hasErr: p.HasErr, }) case pv.IsValue(): v := pv.Value() - if !types.Identical(v.Out, typ) { + if !types.Identical(v.Out, curr.t) { // Interface binding. Don't create a call ourselves. - if err := visit(append(trail, ProviderInput{Type: v.Out})); err != nil { - return err + i := index.At(v.Out) + if i == nil { + stk = append(stk, curr, frame{t: v.Out, from: curr.t}) + continue } - index.Set(typ, index.At(v.Out)) - return nil + index.Set(curr.t, i) + continue } - index.Set(typ, len(given)+len(calls)) + index.Set(curr.t, len(given)+len(calls)) calls = append(calls, call{ kind: valueExpr, - out: typ, + out: curr.t, valueExpr: v.expr, valueTypeInfo: v.info, }) default: panic("unknown return value from ProviderSet.For") } - return nil - } - if err := visit([]ProviderInput{{Type: out}}); err != nil { - return nil, err } return calls, nil }