wire: give wire.Bind access to the arguments to the injector function (google/go-cloud#715)
This commit is contained in:
committed by
Ross Light
parent
67170e739d
commit
6ea381b3fe
@@ -80,37 +80,14 @@ type call struct {
|
||||
|
||||
// solve finds the sequence of calls required to produce an output type
|
||||
// with an optional set of provided inputs.
|
||||
func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, []error) {
|
||||
func solve(fset *token.FileSet, out types.Type, given *types.Tuple, set *ProviderSet) ([]call, []error) {
|
||||
ec := new(errorCollector)
|
||||
for i, g := range given {
|
||||
for _, h := range given[:i] {
|
||||
if types.Identical(g, h) {
|
||||
ec.add(fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start building the mapping of type to local variable of the given type.
|
||||
// The first len(given) local variables are the given types.
|
||||
index := new(typeutil.Map)
|
||||
for i, g := range given {
|
||||
if pv := set.For(g); !pv.IsNil() {
|
||||
switch {
|
||||
case pv.IsProvider():
|
||||
ec.add(fmt.Errorf("input of %s conflicts with provider %s at %s",
|
||||
types.TypeString(g, nil), pv.Provider().Name, fset.Position(pv.Provider().Pos)))
|
||||
case pv.IsValue():
|
||||
ec.add(fmt.Errorf("input of %s conflicts with value at %s",
|
||||
types.TypeString(g, nil), fset.Position(pv.Value().Pos)))
|
||||
default:
|
||||
panic("unknown return value from ProviderSet.For")
|
||||
}
|
||||
} else {
|
||||
index.Set(g, i)
|
||||
}
|
||||
}
|
||||
if len(ec.errors) > 0 {
|
||||
return nil, ec.errors
|
||||
for i := 0; i < given.Len(); i++ {
|
||||
index.Set(given.At(i).Type(), i)
|
||||
}
|
||||
|
||||
// Topological sort of the directed graph defined by the providers
|
||||
@@ -149,6 +126,19 @@ dfs:
|
||||
ec.add(errors.New(sb.String()))
|
||||
index.Set(curr.t, errAbort)
|
||||
continue
|
||||
case pv.IsArg():
|
||||
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
||||
used = append(used, src)
|
||||
if concrete := pv.ConcreteType(); !types.Identical(concrete, curr.t) {
|
||||
// Interface binding.
|
||||
i := index.At(concrete)
|
||||
if i == nil {
|
||||
stk = append(stk, curr, frame{t: concrete, from: curr.t, up: &curr})
|
||||
continue
|
||||
}
|
||||
index.Set(curr.t, i)
|
||||
}
|
||||
continue
|
||||
case pv.IsProvider():
|
||||
p := pv.Provider()
|
||||
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
||||
@@ -192,7 +182,7 @@ dfs:
|
||||
}
|
||||
args[i] = v.(int)
|
||||
}
|
||||
index.Set(curr.t, len(given)+len(calls))
|
||||
index.Set(curr.t, given.Len()+len(calls))
|
||||
kind := funcProviderCall
|
||||
if p.IsStruct {
|
||||
kind = structProvider
|
||||
@@ -222,7 +212,7 @@ dfs:
|
||||
}
|
||||
src := set.srcMap.At(curr.t).(*providerSetSrc)
|
||||
used = append(used, src)
|
||||
index.Set(curr.t, len(given)+len(calls))
|
||||
index.Set(curr.t, given.Len()+len(calls))
|
||||
calls = append(calls, call{
|
||||
kind: valueExpr,
|
||||
out: curr.t,
|
||||
@@ -308,8 +298,23 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
|
||||
srcMap := new(typeutil.Map) // to *providerSetSrc
|
||||
srcMap.SetHasher(hasher)
|
||||
|
||||
// Process imports first, verifying that there are no conflicts between sets.
|
||||
ec := new(errorCollector)
|
||||
// Process injector arguments.
|
||||
if set.InjectorArgs != nil {
|
||||
givens := set.InjectorArgs.Tuple
|
||||
for i := 0; i < givens.Len(); i++ {
|
||||
typ := givens.At(i).Type()
|
||||
arg := &InjectorArg{Args: set.InjectorArgs, Index: i}
|
||||
src := &providerSetSrc{InjectorArg: arg}
|
||||
if prevSrc := srcMap.At(typ); prevSrc != nil {
|
||||
ec.add(bindingConflictError(fset, typ, set, src, prevSrc.(*providerSetSrc)))
|
||||
continue
|
||||
}
|
||||
providerMap.Set(typ, &ProvidedType{t: typ, a: arg})
|
||||
srcMap.Set(typ, src)
|
||||
}
|
||||
}
|
||||
// Process imports, verifying that there are no conflicts between sets.
|
||||
for _, imp := range set.Imports {
|
||||
src := &providerSetSrc{Import: imp}
|
||||
imp.providerMap.Iterate(func(k types.Type, v interface{}) {
|
||||
@@ -407,6 +412,10 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
|
||||
// Leaf: values do not have dependencies.
|
||||
continue
|
||||
}
|
||||
if pt.IsArg() {
|
||||
// Injector arguments do not have dependencies.
|
||||
continue
|
||||
}
|
||||
if !pt.IsProvider() {
|
||||
panic("invalid provider map value")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user