add FieldsOf to inject fields of a struct directly (#138)
This commit is contained in:
@@ -32,6 +32,7 @@ const (
|
||||
funcProviderCall callKind = iota
|
||||
structProvider
|
||||
valueExpr
|
||||
selectorExpr
|
||||
)
|
||||
|
||||
// A call represents a step of an injector function. It may be either a
|
||||
@@ -44,17 +45,21 @@ type call struct {
|
||||
// out is the type this step produces.
|
||||
out types.Type
|
||||
|
||||
// pkg and name identify the provider to call for kind ==
|
||||
// funcProviderCall or the type to construct for kind ==
|
||||
// structProvider.
|
||||
// pkg and name identify one of the following:
|
||||
// 1) the provider to call for kind == funcProviderCall;
|
||||
// 2) the type to construct for kind == structProvider;
|
||||
// 3) the name to select for kind == selectorExpr.
|
||||
pkg *types.Package
|
||||
name string
|
||||
|
||||
// args is a list of arguments to call the provider with. Each element is:
|
||||
// a) one of the givens (args[i] < len(given)), or
|
||||
// args is a list of arguments to call the provider with. Each element is:
|
||||
// a) one of the givens (args[i] < len(given)),
|
||||
// b) the result of a previous provider call (args[i] >= len(given))
|
||||
//
|
||||
// This will be nil for kind == valueExpr.
|
||||
//
|
||||
// If kind == selectorExpr, then the length of this slice will be 1 and the
|
||||
// "argument" will be the value to access fields from.
|
||||
args []int
|
||||
|
||||
// varargs is true if the provider function is variadic.
|
||||
@@ -207,6 +212,29 @@ dfs:
|
||||
valueExpr: v.expr,
|
||||
valueTypeInfo: v.info,
|
||||
})
|
||||
case pv.IsField():
|
||||
f := pv.Field()
|
||||
if index.At(f.Parent) == nil {
|
||||
// Fields have one dependency which is the parent struct. Make
|
||||
// sure to visit it first if it is not already visited.
|
||||
stk = append(stk, curr, frame{t: f.Parent, from: curr.t, up: &curr})
|
||||
continue
|
||||
}
|
||||
index.Set(curr.t, given.Len()+len(calls))
|
||||
v := index.At(f.Parent)
|
||||
if v == errAbort {
|
||||
index.Set(curr.t, errAbort)
|
||||
continue dfs
|
||||
}
|
||||
// Use the args[0] to store the position of the parent struct.
|
||||
args := []int{v.(int)}
|
||||
calls = append(calls, call{
|
||||
kind: selectorExpr,
|
||||
pkg: f.Pkg,
|
||||
name: f.Name,
|
||||
out: curr.t,
|
||||
args: args,
|
||||
})
|
||||
default:
|
||||
panic("unknown return value from ProviderSet.For")
|
||||
}
|
||||
@@ -275,11 +303,24 @@ func verifyArgsUsed(set *ProviderSet, used []*providerSetSrc) []error {
|
||||
errs = append(errs, fmt.Errorf("unused interface binding to type %s", types.TypeString(b.Iface, nil)))
|
||||
}
|
||||
}
|
||||
for _, f := range set.Fields {
|
||||
found := false
|
||||
for _, u := range used {
|
||||
if u.Field == f {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
errs = append(errs, fmt.Errorf("unused field %q.%s", f.Parent, f.Name))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// buildProviderMap creates the providerMap and srcMap fields for a given provider set.
|
||||
// The given provider set's providerMap and srcMap fields are ignored.
|
||||
// buildProviderMap creates the providerMap and srcMap fields for a given
|
||||
// provider set. The given provider set's providerMap and srcMap fields are
|
||||
// ignored.
|
||||
func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, *typeutil.Map, []error) {
|
||||
providerMap := new(typeutil.Map)
|
||||
providerMap.SetHasher(hasher)
|
||||
@@ -339,6 +380,15 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
|
||||
providerMap.Set(v.Out, &ProvidedType{t: v.Out, v: v})
|
||||
srcMap.Set(v.Out, src)
|
||||
}
|
||||
for _, f := range set.Fields {
|
||||
src := &providerSetSrc{Field: f}
|
||||
if prevSrc := srcMap.At(f.Out); prevSrc != nil {
|
||||
ec.add(bindingConflictError(fset, f.Out, set, src, prevSrc.(*providerSetSrc)))
|
||||
continue
|
||||
}
|
||||
providerMap.Set(f.Out, &ProvidedType{t: f.Out, f: f})
|
||||
srcMap.Set(f.Out, src)
|
||||
}
|
||||
if len(ec.errors) > 0 {
|
||||
return nil, nil, ec.errors
|
||||
}
|
||||
@@ -398,38 +448,49 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
|
||||
continue
|
||||
}
|
||||
pt := x.(*ProvidedType)
|
||||
if pt.IsValue() {
|
||||
switch {
|
||||
case pt.IsValue():
|
||||
// Leaf: values do not have dependencies.
|
||||
continue
|
||||
}
|
||||
if pt.IsArg() {
|
||||
case pt.IsArg():
|
||||
// Injector arguments do not have dependencies.
|
||||
continue
|
||||
}
|
||||
if !pt.IsProvider() {
|
||||
panic("invalid provider map value")
|
||||
}
|
||||
for _, arg := range pt.Provider().Args {
|
||||
a := arg.Type
|
||||
hasCycle := false
|
||||
for i, b := range curr {
|
||||
if types.Identical(a, b) {
|
||||
sb := new(strings.Builder)
|
||||
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
|
||||
for j := i; j < len(curr); j++ {
|
||||
p := providerMap.At(curr[j]).(*ProvidedType).Provider()
|
||||
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
|
||||
case pt.IsProvider() || pt.IsField():
|
||||
var args []types.Type
|
||||
if pt.IsProvider() {
|
||||
for _, arg := range pt.Provider().Args {
|
||||
args = append(args, arg.Type)
|
||||
}
|
||||
} else {
|
||||
args = append(args, pt.Field().Parent)
|
||||
}
|
||||
for _, a := range args {
|
||||
hasCycle := false
|
||||
for i, b := range curr {
|
||||
if types.Identical(a, b) {
|
||||
sb := new(strings.Builder)
|
||||
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
|
||||
for j := i; j < len(curr); j++ {
|
||||
t := providerMap.At(curr[j]).(*ProvidedType)
|
||||
if t.IsProvider() {
|
||||
p := t.Provider()
|
||||
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
|
||||
} else {
|
||||
p := t.Field()
|
||||
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
|
||||
ec.add(errors.New(sb.String()))
|
||||
hasCycle = true
|
||||
break
|
||||
}
|
||||
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
|
||||
ec.add(errors.New(sb.String()))
|
||||
hasCycle = true
|
||||
break
|
||||
}
|
||||
if !hasCycle {
|
||||
next := append(append([]types.Type(nil), curr...), a)
|
||||
stk = append(stk, next)
|
||||
}
|
||||
}
|
||||
if !hasCycle {
|
||||
next := append(append([]types.Type(nil), curr...), a)
|
||||
stk = append(stk, next)
|
||||
}
|
||||
default:
|
||||
panic("invalid provider map value")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user