diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 733a4dc..3fc29f8 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -238,6 +238,9 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[ case pv.IsNil(): // This is an input. inputVisited.Set(curr, -1) + case pv.IsArg(): + // This is an injector argument. + inputVisited.Set(curr, -1) case pv.IsProvider(): // Try to see if any args haven't been visited. p := pv.Provider() diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index 1657bf2..32cf119 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -80,37 +80,14 @@ type call struct { // solve finds the sequence of calls required to produce an output type // with an optional set of provided inputs. -func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, []error) { +func solve(fset *token.FileSet, out types.Type, given *types.Tuple, set *ProviderSet) ([]call, []error) { ec := new(errorCollector) - for i, g := range given { - for _, h := range given[:i] { - if types.Identical(g, h) { - ec.add(fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))) - } - } - } // Start building the mapping of type to local variable of the given type. // The first len(given) local variables are the given types. index := new(typeutil.Map) - for i, g := range given { - if pv := set.For(g); !pv.IsNil() { - switch { - case pv.IsProvider(): - ec.add(fmt.Errorf("input of %s conflicts with provider %s at %s", - types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos))) - case pv.IsValue(): - ec.add(fmt.Errorf("input of %s conflicts with value at %s", - types.TypeString(g, nil), fset.Position(pv.Value().Pos))) - default: - panic("unknown return value from ProviderSet.For") - } - } else { - index.Set(g, i) - } - } - if len(ec.errors) > 0 { - return nil, ec.errors + for i := 0; i < given.Len(); i++ { + index.Set(given.At(i).Type(), i) } // Topological sort of the directed graph defined by the providers @@ -149,6 +126,19 @@ dfs: ec.add(errors.New(sb.String())) index.Set(curr.t, errAbort) continue + case pv.IsArg(): + src := set.srcMap.At(curr.t).(*providerSetSrc) + used = append(used, src) + if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) { + // Interface binding. + i := index.At(concrete) + if i == nil { + stk = append(stk, curr, frame{t: concrete, from: curr.t, up: &curr}) + continue + } + index.Set(curr.t, i) + } + continue case pv.IsProvider(): p := pv.Provider() src := set.srcMap.At(curr.t).(*providerSetSrc) @@ -192,7 +182,7 @@ dfs: } args[i] = v.(int) } - index.Set(curr.t, len(given)+len(calls)) + index.Set(curr.t, given.Len()+len(calls)) kind := funcProviderCall if p.IsStruct { kind = structProvider @@ -222,7 +212,7 @@ dfs: } src := set.srcMap.At(curr.t).(*providerSetSrc) used = append(used, src) - index.Set(curr.t, len(given)+len(calls)) + index.Set(curr.t, given.Len()+len(calls)) calls = append(calls, call{ kind: valueExpr, out: curr.t, @@ -308,8 +298,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider srcMap := new(typeutil.Map) // to *providerSetSrc srcMap.SetHasher(hasher) - // Process imports first, verifying that there are no conflicts between sets. ec := new(errorCollector) + // Process injector arguments. + if set.InjectorArgs != nil { + givens := set.InjectorArgs.Tuple + for i := 0; i < givens.Len(); i++ { + typ := givens.At(i).Type() + arg := &InjectorArg{Args: set.InjectorArgs, Index: i} + src := &providerSetSrc{InjectorArg: arg} + if prevSrc := srcMap.At(typ); prevSrc != nil { + ec.add(bindingConflictError(fset, typ, set, src, prevSrc.(*providerSetSrc))) + continue + } + providerMap.Set(typ, &ProvidedType{t: typ, a: arg}) + srcMap.Set(typ, src) + } + } + // Process imports, verifying that there are no conflicts between sets. for _, imp := range set.Imports { src := &providerSetSrc{Import: imp} imp.providerMap.Iterate(func(k types.Type, v interface{}) { @@ -407,6 +412,10 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error { // Leaf: values do not have dependencies. continue } + if pt.IsArg() { + // Injector arguments do not have dependencies. + continue + } if !pt.IsProvider() { panic("invalid provider map value") } diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 05864f0..558551e 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -32,10 +32,11 @@ import ( // A providerSetSrc captures the source for a type provided by a ProviderSet. // Exactly one of the fields will be set. type providerSetSrc struct { - Provider *Provider - Binding *IfaceBinding - Value *Value - Import *ProviderSet + Provider *Provider + Binding *IfaceBinding + Value *Value + Import *ProviderSet + InjectorArg *InjectorArg } // description returns a string describing the source of p, including line numbers. @@ -59,6 +60,9 @@ func (p *providerSetSrc) description(fset *token.FileSet, typ types.Type) string return fmt.Sprintf("wire.Value (%s)", fset.Position(p.Value.Pos)) case p.Import != nil: return fmt.Sprintf("provider set %s(%s)", quoted(p.Import.VarName), fset.Position(p.Import.Pos)) + case p.InjectorArg != nil: + args := p.InjectorArg.Args + return fmt.Sprintf("argument %s to injector function %s (%s)", args.Tuple.At(p.InjectorArg.Index).Name(), args.Name, fset.Position(args.Pos)) } panic("providerSetSrc with no fields set") } @@ -93,6 +97,8 @@ type ProviderSet struct { Bindings []*IfaceBinding Values []*Value Imports []*ProviderSet + // InjectorArgs is only filled in for wire.Build. + InjectorArgs *InjectorArgs // providerMap maps from provided type to a *ProvidedType. // It includes all of the imported types. @@ -190,6 +196,24 @@ type Value struct { info *types.Info } +// InjectorArg describes a specific argument passed to an injector function. +type InjectorArg struct { + // Args is the full set of arguments. + Args *InjectorArgs + // Index is the index into Args.Tuple for this argument. + Index int +} + +// InjectorArgs describes the arguments passed to an injector function. +type InjectorArgs struct { + // Name is the name of the injector function. + Name string + // Tuple represents the arguments. + Tuple *types.Tuple + // Pos is the source position of the injector function. + Pos token.Pos +} + // Load finds all the provider sets in the packages that match the given // patterns, as well as the provider sets' transitive dependencies. It // may return both errors and Info. The patterns are defined by the @@ -252,11 +276,6 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf if buildCall == nil { continue } - set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, "") - if len(errs) > 0 { - ec.add(notePositionAll(fset.Position(fn.Pos()), errs)...) - continue - } sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature) ins, out, err := injectorFuncSignature(sig) if err != nil { @@ -267,6 +286,16 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf } continue } + injectorArgs := &InjectorArgs{ + Name: fn.Name.Name, + Tuple: ins, + Pos: fn.Pos(), + } + set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "") + if len(errs) > 0 { + ec.add(notePositionAll(fset.Position(fn.Pos()), errs)...) + continue + } _, errs = solve(fset, out.out, ins, set) if len(errs) > 0 { ec.add(mapErrors(errs, func(e error) error { @@ -482,7 +511,7 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex } switch fnObj.Name() { case "NewSet": - pset, errs := oc.processNewSet(info, pkgPath, call, varName) + pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName) return pset, notePositionAll(exprPos, errs) case "Bind": b, err := processBind(oc.fset, info, call) @@ -516,13 +545,14 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))} } -func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, varName string) (*ProviderSet, []error) { +func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (*ProviderSet, []error) { // Assumes that call.Fun is wire.NewSet or wire.Build. pset := &ProviderSet{ - Pos: call.Pos(), - PkgPath: pkgPath, - VarName: varName, + Pos: call.Pos(), + InjectorArgs: args, + PkgPath: pkgPath, + VarName: varName, } ec := new(errorCollector) for _, arg := range call.Args { @@ -626,17 +656,12 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro return provider, nil } -func injectorFuncSignature(sig *types.Signature) ([]types.Type, outputSignature, error) { +func injectorFuncSignature(sig *types.Signature) (*types.Tuple, outputSignature, error) { out, err := funcOutput(sig) if err != nil { return nil, outputSignature{}, err } - params := sig.Params() - given := make([]types.Type, params.Len()) - for i := 0; i < params.Len(); i++ { - given[i] = params.At(i).Type() - } - return given, out, nil + return sig.Params(), out, nil } type outputSignature struct { @@ -893,49 +918,66 @@ func isProviderSetType(t types.Type) bool { return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet" } -// ProvidedType is a pointer to a Provider or a Value. The zero value is -// a nil pointer. It also holds the concrete type that the Provider or Value -// provided. +// ProvidedType represents a type provided from a source. The source +// can be a *Provider (a provider function), a *Value (wire.Value), or an +// *InjectorArgs (arguments to the injector function). The zero value has +// none of the above, and returns true for IsNil. type ProvidedType struct { + // t is the provided concrete type. t types.Type p *Provider v *Value + a *InjectorArg } -// IsNil reports whether pv is the zero value. -func (pv ProvidedType) IsNil() bool { - return pv.p == nil && pv.v == nil +// IsNil reports whether pt is the zero value. +func (pt ProvidedType) IsNil() bool { + return pt.p == nil && pt.v == nil && pt.a == nil } // ConcreteType returns the concrete type that was provided. -func (pv ProvidedType) ConcreteType() types.Type { - return pv.t +func (pt ProvidedType) ConcreteType() types.Type { + return pt.t } -// IsProvider reports whether pv points to a Provider. -func (pv ProvidedType) IsProvider() bool { - return pv.p != nil +// IsProvider reports whether pt points to a Provider. +func (pt ProvidedType) IsProvider() bool { + return pt.p != nil } -// IsValue reports whether pv points to a Value. -func (pv ProvidedType) IsValue() bool { - return pv.v != nil +// IsValue reports whether pt points to a Value. +func (pt ProvidedType) IsValue() bool { + return pt.v != nil } -// Provider returns pv as a Provider pointer. It panics if pv points to a -// Value. -func (pv ProvidedType) Provider() *Provider { - if pv.v != nil { - panic("Value pointer converted to a Provider") +// IsArg reports whether pt points to an injector argument. +func (pt ProvidedType) IsArg() bool { + return pt.a != nil +} + +// Provider returns pt as a Provider pointer. It panics if pt does not point +// to a Provider. +func (pt ProvidedType) Provider() *Provider { + if pt.p == nil { + panic("ProvidedType does not hold a Provider") } - return pv.p + return pt.p } -// Value returns pv as a Value pointer. It panics if pv points to a -// Provider. -func (pv ProvidedType) Value() *Value { - if pv.p != nil { - panic("Provider pointer converted to a Value") +// Value returns pt as a Value pointer. It panics if pt does not point +// to a Value. +func (pt ProvidedType) Value() *Value { + if pt.v == nil { + panic("ProvidedType does not hold a Value") } - return pv.v + return pt.v +} + +// Arg returns pt as an *InjectorArg representing an injector argument. It +// panics if pt does not point to an arg. +func (pt ProvidedType) Arg() *InjectorArg { + if pt.a == nil { + panic("ProvidedType does not hold an Arg") + } + return pt.a } diff --git a/internal/wire/testdata/BindInjectorArg/want/program_out.txt b/internal/wire/testdata/BindInjectorArg/want/program_out.txt new file mode 100644 index 0000000..ce01362 --- /dev/null +++ b/internal/wire/testdata/BindInjectorArg/want/program_out.txt @@ -0,0 +1 @@ +hello diff --git a/internal/wire/testdata/BindInjectorArg/want/wire_errs.txt b/internal/wire/testdata/BindInjectorArg/want/wire_errs.txt deleted file mode 100644 index 31c37a8..0000000 --- a/internal/wire/testdata/BindInjectorArg/want/wire_errs.txt +++ /dev/null @@ -1 +0,0 @@ -example.com/foo/wire.go:x:y: no binding for *example.com/foo.Foo \ No newline at end of file diff --git a/internal/wire/testdata/BindInjectorArg/want/wire_gen.go b/internal/wire/testdata/BindInjectorArg/want/wire_gen.go new file mode 100644 index 0000000..abbd9c6 --- /dev/null +++ b/internal/wire/testdata/BindInjectorArg/want/wire_gen.go @@ -0,0 +1,13 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate wire +//+build !wireinject + +package main + +// Injectors from wire.go: + +func inject(foo *Foo) *Bar { + bar := NewBar(foo) + return bar +} diff --git a/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt b/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt index f3b2c2c..7ab4d29 100644 --- a/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt +++ b/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt @@ -1 +1,6 @@ -example.com/foo/wire.go:x:y: inject injectBar: input of example.com/foo.Foo conflicts with provider provideFoo at example.com/foo/foo.go:x:y \ No newline at end of file +example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo +current: +<- provider "provideFoo" (example.com/foo/foo.go:x:y) +<- provider set "Set" (example.com/foo/foo.go:x:y) +previous: +<- argument foo to injector function injectBar (example.com/foo/wire.go:x:y) \ No newline at end of file diff --git a/internal/wire/testdata/InvalidInjector/foo/wire.go b/internal/wire/testdata/InvalidInjector/foo/wire.go index 6ddd3bd..203f81e 100644 --- a/internal/wire/testdata/InvalidInjector/foo/wire.go +++ b/internal/wire/testdata/InvalidInjector/foo/wire.go @@ -22,7 +22,7 @@ import ( func injectFoo() Foo { // This non-call statement makes this an invalid injector. - _ = 42 + _ = 42 panic(wire.Build(provideFoo)) } diff --git a/internal/wire/testdata/MultipleArgsSameType/want/wire_errs.txt b/internal/wire/testdata/MultipleArgsSameType/want/wire_errs.txt index ad3bf9c..7a56e43 100644 --- a/internal/wire/testdata/MultipleArgsSameType/want/wire_errs.txt +++ b/internal/wire/testdata/MultipleArgsSameType/want/wire_errs.txt @@ -1 +1,5 @@ -example.com/foo/wire.go:x:y: inject inject: multiple inputs of the same type string \ No newline at end of file +example.com/foo/wire.go:x:y: wire.Build has multiple bindings for string +current: +<- argument b to injector function inject (example.com/foo/wire.go:x:y) +previous: +<- argument a to injector function inject (example.com/foo/wire.go:x:y) \ No newline at end of file diff --git a/internal/wire/testdata/UnexportedStruct/bar/bar.go b/internal/wire/testdata/UnexportedStruct/bar/bar.go index dbc5f97..7215742 100644 --- a/internal/wire/testdata/UnexportedStruct/bar/bar.go +++ b/internal/wire/testdata/UnexportedStruct/bar/bar.go @@ -17,5 +17,3 @@ package bar var foo struct { X int } - - diff --git a/internal/wire/wire.go b/internal/wire/wire.go index f143cb9..64db9e8 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -132,12 +132,26 @@ func generateInjectors(g *gen, pkg *packages.Package) (injectorFiles []*ast.File g.p("// Injectors from %s:\n\n", name) injectorFiles = append(injectorFiles, f) } - set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, "") + sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature) + ins, _, err := injectorFuncSignature(sig) + if err != nil { + if w, ok := err.(*wireErr); ok { + ec.add(notePosition(w.position, fmt.Errorf("inject %s: %v", fn.Name.Name, w.error))) + } else { + ec.add(notePosition(g.pkg.Fset.Position(fn.Pos()), fmt.Errorf("inject %s: %v", fn.Name.Name, err))) + } + continue + } + injectorArgs := &InjectorArgs{ + Name: fn.Name.Name, + Tuple: ins, + Pos: fn.Pos(), + } + set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "") if len(errs) > 0 { ec.add(notePositionAll(g.pkg.Fset.Position(fn.Pos()), errs)...) continue } - sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature) if errs := g.inject(fn.Pos(), fn.Name.Name, sig, set); len(errs) > 0 { ec.add(errs...) continue @@ -249,11 +263,7 @@ func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *Prov fmt.Errorf("inject %s: %v", name, err))} } params := sig.Params() - given := make([]types.Type, params.Len()) - for i := 0; i < params.Len(); i++ { - given[i] = params.At(i).Type() - } - calls, errs := solve(g.pkg.Fset, injectSig.out, given, set) + calls, errs := solve(g.pkg.Fset, injectSig.out, params, set) if len(errs) > 0 { return mapErrors(errs, func(e error) error { if w, ok := e.(*wireErr); ok {