wire: give wire.Bind access to the arguments to the injector function (google/go-cloud#715)
This commit is contained in:
committed by
Ross Light
parent
67170e739d
commit
6ea381b3fe
@@ -238,6 +238,9 @@ func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[
|
||||
case pv.IsNil():
|
||||
// This is an input.
|
||||
inputVisited.Set(curr, -1)
|
||||
case pv.IsArg():
|
||||
// This is an injector argument.
|
||||
inputVisited.Set(curr, -1)
|
||||
case pv.IsProvider():
|
||||
// Try to see if any args haven't been visited.
|
||||
p := pv.Provider()
|
||||
|
||||
@@ -80,37 +80,14 @@ type call struct {
|
||||
|
||||
// solve finds the sequence of calls required to produce an output type
|
||||
// with an optional set of provided inputs.
|
||||
func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, []error) {
|
||||
func solve(fset *token.FileSet, out types.Type, given *types.Tuple, set *ProviderSet) ([]call, []error) {
|
||||
ec := new(errorCollector)
|
||||
for i, g := range given {
|
||||
for _, h := range given[:i] {
|
||||
if types.Identical(g, h) {
|
||||
ec.add(fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start building the mapping of type to local variable of the given type.
|
||||
// The first len(given) local variables are the given types.
|
||||
index := new(typeutil.Map)
|
||||
for i, g := range given {
|
||||
if pv := set.For(g); !pv.IsNil() {
|
||||
switch {
|
||||
case pv.IsProvider():
|
||||
ec.add(fmt.Errorf("input of %s conflicts with provider %s at %s",
|
||||
types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos)))
|
||||
case pv.IsValue():
|
||||
ec.add(fmt.Errorf("input of %s conflicts with value at %s",
|
||||
types.TypeString(g, nil), fset.Position(pv.Value().Pos)))
|
||||
default:
|
||||
panic("unknown return value from ProviderSet.For")
|
||||
}
|
||||
} else {
|
||||
index.Set(g, i)
|
||||
}
|
||||
}
|
||||
if len(ec.errors) > 0 {
|
||||
return nil, ec.errors
|
||||
for i := 0; i < given.Len(); i++ {
|
||||
index.Set(given.At(i).Type(), i)
|
||||
}
|
||||
|
||||
// Topological sort of the directed graph defined by the providers
|
||||
@@ -149,6 +126,19 @@ dfs:
|
||||
ec.add(errors.New(sb.String()))
|
||||
index.Set(curr.t, errAbort)
|
||||
continue
|
||||
case pv.IsArg():
|
||||
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
||||
used = append(used, src)
|
||||
if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) {
|
||||
// Interface binding.
|
||||
i := index.At(concrete)
|
||||
if i == nil {
|
||||
stk = append(stk, curr, frame{t: concrete, from: curr.t, up: &curr})
|
||||
continue
|
||||
}
|
||||
index.Set(curr.t, i)
|
||||
}
|
||||
continue
|
||||
case pv.IsProvider():
|
||||
p := pv.Provider()
|
||||
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
||||
@@ -192,7 +182,7 @@ dfs:
|
||||
}
|
||||
args[i] = v.(int)
|
||||
}
|
||||
index.Set(curr.t, len(given)+len(calls))
|
||||
index.Set(curr.t, given.Len()+len(calls))
|
||||
kind := funcProviderCall
|
||||
if p.IsStruct {
|
||||
kind = structProvider
|
||||
@@ -222,7 +212,7 @@ dfs:
|
||||
}
|
||||
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
||||
used = append(used, src)
|
||||
index.Set(curr.t, len(given)+len(calls))
|
||||
index.Set(curr.t, given.Len()+len(calls))
|
||||
calls = append(calls, call{
|
||||
kind: valueExpr,
|
||||
out: curr.t,
|
||||
@@ -308,8 +298,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
|
||||
srcMap := new(typeutil.Map) // to *providerSetSrc
|
||||
srcMap.SetHasher(hasher)
|
||||
|
||||
// Process imports first, verifying that there are no conflicts between sets.
|
||||
ec := new(errorCollector)
|
||||
// Process injector arguments.
|
||||
if set.InjectorArgs != nil {
|
||||
givens := set.InjectorArgs.Tuple
|
||||
for i := 0; i < givens.Len(); i++ {
|
||||
typ := givens.At(i).Type()
|
||||
arg := &InjectorArg{Args: set.InjectorArgs, Index: i}
|
||||
src := &providerSetSrc{InjectorArg: arg}
|
||||
if prevSrc := srcMap.At(typ); prevSrc != nil {
|
||||
ec.add(bindingConflictError(fset, typ, set, src, prevSrc.(*providerSetSrc)))
|
||||
continue
|
||||
}
|
||||
providerMap.Set(typ, &ProvidedType{t: typ, a: arg})
|
||||
srcMap.Set(typ, src)
|
||||
}
|
||||
}
|
||||
// Process imports, verifying that there are no conflicts between sets.
|
||||
for _, imp := range set.Imports {
|
||||
src := &providerSetSrc{Import: imp}
|
||||
imp.providerMap.Iterate(func(k types.Type, v interface{}) {
|
||||
@@ -407,6 +412,10 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
|
||||
// Leaf: values do not have dependencies.
|
||||
continue
|
||||
}
|
||||
if pt.IsArg() {
|
||||
// Injector arguments do not have dependencies.
|
||||
continue
|
||||
}
|
||||
if !pt.IsProvider() {
|
||||
panic("invalid provider map value")
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type providerSetSrc struct {
|
||||
Binding *IfaceBinding
|
||||
Value *Value
|
||||
Import *ProviderSet
|
||||
InjectorArg *InjectorArg
|
||||
}
|
||||
|
||||
// description returns a string describing the source of p, including line numbers.
|
||||
@@ -59,6 +60,9 @@ func (p *providerSetSrc) description(fset *token.FileSet, typ types.Type) string
|
||||
return fmt.Sprintf("wire.Value (%s)", fset.Position(p.Value.Pos))
|
||||
case p.Import != nil:
|
||||
return fmt.Sprintf("provider set %s(%s)", quoted(p.Import.VarName), fset.Position(p.Import.Pos))
|
||||
case p.InjectorArg != nil:
|
||||
args := p.InjectorArg.Args
|
||||
return fmt.Sprintf("argument %s to injector function %s (%s)", args.Tuple.At(p.InjectorArg.Index).Name(), args.Name, fset.Position(args.Pos))
|
||||
}
|
||||
panic("providerSetSrc with no fields set")
|
||||
}
|
||||
@@ -93,6 +97,8 @@ type ProviderSet struct {
|
||||
Bindings []*IfaceBinding
|
||||
Values []*Value
|
||||
Imports []*ProviderSet
|
||||
// InjectorArgs is only filled in for wire.Build.
|
||||
InjectorArgs *InjectorArgs
|
||||
|
||||
// providerMap maps from provided type to a *ProvidedType.
|
||||
// It includes all of the imported types.
|
||||
@@ -190,6 +196,24 @@ type Value struct {
|
||||
info *types.Info
|
||||
}
|
||||
|
||||
// InjectorArg describes a specific argument passed to an injector function.
|
||||
type InjectorArg struct {
|
||||
// Args is the full set of arguments.
|
||||
Args *InjectorArgs
|
||||
// Index is the index into Args.Tuple for this argument.
|
||||
Index int
|
||||
}
|
||||
|
||||
// InjectorArgs describes the arguments passed to an injector function.
|
||||
type InjectorArgs struct {
|
||||
// Name is the name of the injector function.
|
||||
Name string
|
||||
// Tuple represents the arguments.
|
||||
Tuple *types.Tuple
|
||||
// Pos is the source position of the injector function.
|
||||
Pos token.Pos
|
||||
}
|
||||
|
||||
// Load finds all the provider sets in the packages that match the given
|
||||
// patterns, as well as the provider sets' transitive dependencies. It
|
||||
// may return both errors and Info. The patterns are defined by the
|
||||
@@ -252,11 +276,6 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf
|
||||
if buildCall == nil {
|
||||
continue
|
||||
}
|
||||
set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, "")
|
||||
if len(errs) > 0 {
|
||||
ec.add(notePositionAll(fset.Position(fn.Pos()), errs)...)
|
||||
continue
|
||||
}
|
||||
sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||
ins, out, err := injectorFuncSignature(sig)
|
||||
if err != nil {
|
||||
@@ -267,6 +286,16 @@ func Load(ctx context.Context, wd string, env []string, patterns []string) (*Inf
|
||||
}
|
||||
continue
|
||||
}
|
||||
injectorArgs := &InjectorArgs{
|
||||
Name: fn.Name.Name,
|
||||
Tuple: ins,
|
||||
Pos: fn.Pos(),
|
||||
}
|
||||
set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "")
|
||||
if len(errs) > 0 {
|
||||
ec.add(notePositionAll(fset.Position(fn.Pos()), errs)...)
|
||||
continue
|
||||
}
|
||||
_, errs = solve(fset, out.out, ins, set)
|
||||
if len(errs) > 0 {
|
||||
ec.add(mapErrors(errs, func(e error) error {
|
||||
@@ -482,7 +511,7 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
|
||||
}
|
||||
switch fnObj.Name() {
|
||||
case "NewSet":
|
||||
pset, errs := oc.processNewSet(info, pkgPath, call, varName)
|
||||
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
|
||||
return pset, notePositionAll(exprPos, errs)
|
||||
case "Bind":
|
||||
b, err := processBind(oc.fset, info, call)
|
||||
@@ -516,11 +545,12 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
|
||||
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
|
||||
}
|
||||
|
||||
func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, varName string) (*ProviderSet, []error) {
|
||||
func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (*ProviderSet, []error) {
|
||||
// Assumes that call.Fun is wire.NewSet or wire.Build.
|
||||
|
||||
pset := &ProviderSet{
|
||||
Pos: call.Pos(),
|
||||
InjectorArgs: args,
|
||||
PkgPath: pkgPath,
|
||||
VarName: varName,
|
||||
}
|
||||
@@ -626,17 +656,12 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func injectorFuncSignature(sig *types.Signature) ([]types.Type, outputSignature, error) {
|
||||
func injectorFuncSignature(sig *types.Signature) (*types.Tuple, outputSignature, error) {
|
||||
out, err := funcOutput(sig)
|
||||
if err != nil {
|
||||
return nil, outputSignature{}, err
|
||||
}
|
||||
params := sig.Params()
|
||||
given := make([]types.Type, params.Len())
|
||||
for i := 0; i < params.Len(); i++ {
|
||||
given[i] = params.At(i).Type()
|
||||
}
|
||||
return given, out, nil
|
||||
return sig.Params(), out, nil
|
||||
}
|
||||
|
||||
type outputSignature struct {
|
||||
@@ -893,49 +918,66 @@ func isProviderSetType(t types.Type) bool {
|
||||
return obj.Pkg() != nil && isWireImport(obj.Pkg().Path()) && obj.Name() == "ProviderSet"
|
||||
}
|
||||
|
||||
// ProvidedType is a pointer to a Provider or a Value. The zero value is
|
||||
// a nil pointer. It also holds the concrete type that the Provider or Value
|
||||
// provided.
|
||||
// ProvidedType represents a type provided from a source. The source
|
||||
// can be a *Provider (a provider function), a *Value (wire.Value), or an
|
||||
// *InjectorArgs (arguments to the injector function). The zero value has
|
||||
// none of the above, and returns true for IsNil.
|
||||
type ProvidedType struct {
|
||||
// t is the provided concrete type.
|
||||
t types.Type
|
||||
p *Provider
|
||||
v *Value
|
||||
a *InjectorArg
|
||||
}
|
||||
|
||||
// IsNil reports whether pv is the zero value.
|
||||
func (pv ProvidedType) IsNil() bool {
|
||||
return pv.p == nil && pv.v == nil
|
||||
// IsNil reports whether pt is the zero value.
|
||||
func (pt ProvidedType) IsNil() bool {
|
||||
return pt.p == nil && pt.v == nil && pt.a == nil
|
||||
}
|
||||
|
||||
// ConcreteType returns the concrete type that was provided.
|
||||
func (pv ProvidedType) ConcreteType() types.Type {
|
||||
return pv.t
|
||||
func (pt ProvidedType) ConcreteType() types.Type {
|
||||
return pt.t
|
||||
}
|
||||
|
||||
// IsProvider reports whether pv points to a Provider.
|
||||
func (pv ProvidedType) IsProvider() bool {
|
||||
return pv.p != nil
|
||||
// IsProvider reports whether pt points to a Provider.
|
||||
func (pt ProvidedType) IsProvider() bool {
|
||||
return pt.p != nil
|
||||
}
|
||||
|
||||
// IsValue reports whether pv points to a Value.
|
||||
func (pv ProvidedType) IsValue() bool {
|
||||
return pv.v != nil
|
||||
// IsValue reports whether pt points to a Value.
|
||||
func (pt ProvidedType) IsValue() bool {
|
||||
return pt.v != nil
|
||||
}
|
||||
|
||||
// Provider returns pv as a Provider pointer. It panics if pv points to a
|
||||
// Value.
|
||||
func (pv ProvidedType) Provider() *Provider {
|
||||
if pv.v != nil {
|
||||
panic("Value pointer converted to a Provider")
|
||||
}
|
||||
return pv.p
|
||||
// IsArg reports whether pt points to an injector argument.
|
||||
func (pt ProvidedType) IsArg() bool {
|
||||
return pt.a != nil
|
||||
}
|
||||
|
||||
// Value returns pv as a Value pointer. It panics if pv points to a
|
||||
// Provider.
|
||||
func (pv ProvidedType) Value() *Value {
|
||||
if pv.p != nil {
|
||||
panic("Provider pointer converted to a Value")
|
||||
// Provider returns pt as a Provider pointer. It panics if pt does not point
|
||||
// to a Provider.
|
||||
func (pt ProvidedType) Provider() *Provider {
|
||||
if pt.p == nil {
|
||||
panic("ProvidedType does not hold a Provider")
|
||||
}
|
||||
return pv.v
|
||||
return pt.p
|
||||
}
|
||||
|
||||
// Value returns pt as a Value pointer. It panics if pt does not point
|
||||
// to a Value.
|
||||
func (pt ProvidedType) Value() *Value {
|
||||
if pt.v == nil {
|
||||
panic("ProvidedType does not hold a Value")
|
||||
}
|
||||
return pt.v
|
||||
}
|
||||
|
||||
// Arg returns pt as an *InjectorArg representing an injector argument. It
|
||||
// panics if pt does not point to an arg.
|
||||
func (pt ProvidedType) Arg() *InjectorArg {
|
||||
if pt.a == nil {
|
||||
panic("ProvidedType does not hold an Arg")
|
||||
}
|
||||
return pt.a
|
||||
}
|
||||
|
||||
1
internal/wire/testdata/BindInjectorArg/want/program_out.txt
vendored
Normal file
1
internal/wire/testdata/BindInjectorArg/want/program_out.txt
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hello
|
||||
@@ -1 +0,0 @@
|
||||
example.com/foo/wire.go:x:y: no binding for *example.com/foo.Foo
|
||||
13
internal/wire/testdata/BindInjectorArg/want/wire_gen.go
vendored
Normal file
13
internal/wire/testdata/BindInjectorArg/want/wire_gen.go
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
// Code generated by Wire. DO NOT EDIT.
|
||||
|
||||
//go:generate wire
|
||||
//+build !wireinject
|
||||
|
||||
package main
|
||||
|
||||
// Injectors from wire.go:
|
||||
|
||||
func inject(foo *Foo) *Bar {
|
||||
bar := NewBar(foo)
|
||||
return bar
|
||||
}
|
||||
@@ -1 +1,6 @@
|
||||
example.com/foo/wire.go:x:y: inject injectBar: input of example.com/foo.Foo conflicts with provider provideFoo at example.com/foo/foo.go:x:y
|
||||
example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo
|
||||
current:
|
||||
<- provider "provideFoo" (example.com/foo/foo.go:x:y)
|
||||
<- provider set "Set" (example.com/foo/foo.go:x:y)
|
||||
previous:
|
||||
<- argument foo to injector function injectBar (example.com/foo/wire.go:x:y)
|
||||
@@ -1 +1,5 @@
|
||||
example.com/foo/wire.go:x:y: inject inject: multiple inputs of the same type string
|
||||
example.com/foo/wire.go:x:y: wire.Build has multiple bindings for string
|
||||
current:
|
||||
<- argument b to injector function inject (example.com/foo/wire.go:x:y)
|
||||
previous:
|
||||
<- argument a to injector function inject (example.com/foo/wire.go:x:y)
|
||||
@@ -17,5 +17,3 @@ package bar
|
||||
var foo struct {
|
||||
X int
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -132,12 +132,26 @@ func generateInjectors(g *gen, pkg *packages.Package) (injectorFiles []*ast.File
|
||||
g.p("// Injectors from %s:\n\n", name)
|
||||
injectorFiles = append(injectorFiles, f)
|
||||
}
|
||||
set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, "")
|
||||
sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||
ins, _, err := injectorFuncSignature(sig)
|
||||
if err != nil {
|
||||
if w, ok := err.(*wireErr); ok {
|
||||
ec.add(notePosition(w.position, fmt.Errorf("inject %s: %v", fn.Name.Name, w.error)))
|
||||
} else {
|
||||
ec.add(notePosition(g.pkg.Fset.Position(fn.Pos()), fmt.Errorf("inject %s: %v", fn.Name.Name, err)))
|
||||
}
|
||||
continue
|
||||
}
|
||||
injectorArgs := &InjectorArgs{
|
||||
Name: fn.Name.Name,
|
||||
Tuple: ins,
|
||||
Pos: fn.Pos(),
|
||||
}
|
||||
set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "")
|
||||
if len(errs) > 0 {
|
||||
ec.add(notePositionAll(g.pkg.Fset.Position(fn.Pos()), errs)...)
|
||||
continue
|
||||
}
|
||||
sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||
if errs := g.inject(fn.Pos(), fn.Name.Name, sig, set); len(errs) > 0 {
|
||||
ec.add(errs...)
|
||||
continue
|
||||
@@ -249,11 +263,7 @@ func (g *gen) inject(pos token.Pos, name string, sig *types.Signature, set *Prov
|
||||
fmt.Errorf("inject %s: %v", name, err))}
|
||||
}
|
||||
params := sig.Params()
|
||||
given := make([]types.Type, params.Len())
|
||||
for i := 0; i < params.Len(); i++ {
|
||||
given[i] = params.At(i).Type()
|
||||
}
|
||||
calls, errs := solve(g.pkg.Fset, injectSig.out, given, set)
|
||||
calls, errs := solve(g.pkg.Fset, injectSig.out, params, set)
|
||||
if len(errs) > 0 {
|
||||
return mapErrors(errs, func(e error) error {
|
||||
if w, ok := e.(*wireErr); ok {
|
||||
|
||||
Reference in New Issue
Block a user