wire: give wire.Bind access to the arguments to the injector function (google/go-cloud#715)

This commit is contained in:
Robert van Gent
2018-11-16 10:26:10 -08:00
committed by Ross Light
parent 67170e739d
commit 6ea381b3fe
11 changed files with 173 additions and 89 deletions

View File

@@ -80,37 +80,14 @@ type call struct {
// solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs.
func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, []error) {
func solve(fset *token.FileSet, out types.Type, given *types.Tuple, set *ProviderSet) ([]call, []error) {
ec := new(errorCollector)
for i, g := range given {
for _, h := range given[:i] {
if types.Identical(g, h) {
ec.add(fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil)))
}
}
}
// Start building the mapping of type to local variable of the given type.
// The first len(given) local variables are the given types.
index := new(typeutil.Map)
for i, g := range given {
if pv := set.For(g); !pv.IsNil() {
switch {
case pv.IsProvider():
ec.add(fmt.Errorf("input of %s conflicts with provider %s at %s",
types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos)))
case pv.IsValue():
ec.add(fmt.Errorf("input of %s conflicts with value at %s",
types.TypeString(g, nil), fset.Position(pv.Value().Pos)))
default:
panic("unknown return value from ProviderSet.For")
}
} else {
index.Set(g, i)
}
}
if len(ec.errors) > 0 {
return nil, ec.errors
for i := 0; i < given.Len(); i++ {
index.Set(given.At(i).Type(), i)
}
// Topological sort of the directed graph defined by the providers
@@ -149,6 +126,19 @@ dfs:
ec.add(errors.New(sb.String()))
index.Set(curr.t, errAbort)
continue
case pv.IsArg():
src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src)
if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) {
// Interface binding.
i := index.At(concrete)
if i == nil {
stk = append(stk, curr, frame{t: concrete, from: curr.t, up: &curr})
continue
}
index.Set(curr.t, i)
}
continue
case pv.IsProvider():
p := pv.Provider()
src := set.srcMap.At(curr.t).(*providerSetSrc)
@@ -192,7 +182,7 @@ dfs:
}
args[i] = v.(int)
}
index.Set(curr.t, len(given)+len(calls))
index.Set(curr.t, given.Len()+len(calls))
kind := funcProviderCall
if p.IsStruct {
kind = structProvider
@@ -222,7 +212,7 @@ dfs:
}
src := set.srcMap.At(curr.t).(*providerSetSrc)
used = append(used, src)
index.Set(curr.t, len(given)+len(calls))
index.Set(curr.t, given.Len()+len(calls))
calls = append(calls, call{
kind: valueExpr,
out: curr.t,
@@ -308,8 +298,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
srcMap := new(typeutil.Map) // to *providerSetSrc
srcMap.SetHasher(hasher)
// Process imports first, verifying that there are no conflicts between sets.
ec := new(errorCollector)
// Process injector arguments.
if set.InjectorArgs != nil {
givens := set.InjectorArgs.Tuple
for i := 0; i < givens.Len(); i++ {
typ := givens.At(i).Type()
arg := &InjectorArg{Args: set.InjectorArgs, Index: i}
src := &providerSetSrc{InjectorArg: arg}
if prevSrc := srcMap.At(typ); prevSrc != nil {
ec.add(bindingConflictError(fset, typ, set, src, prevSrc.(*providerSetSrc)))
continue
}
providerMap.Set(typ, &ProvidedType{t: typ, a: arg})
srcMap.Set(typ, src)
}
}
// Process imports, verifying that there are no conflicts between sets.
for _, imp := range set.Imports {
src := &providerSetSrc{Import: imp}
imp.providerMap.Iterate(func(k types.Type, v interface{}) {
@@ -407,6 +412,10 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
// Leaf: values do not have dependencies.
continue
}
if pt.IsArg() {
// Injector arguments do not have dependencies.
continue
}
if !pt.IsProvider() {
panic("invalid provider map value")
}