add FieldsOf to inject fields of a struct directly (#138)

This commit is contained in:
shantuo
2019-03-01 13:52:07 -08:00
committed by GitHub
parent 58e5de342a
commit 327f42724c
37 changed files with 869 additions and 56 deletions

View File

@@ -32,6 +32,7 @@ const (
funcProviderCall callKind = iota
structProvider
valueExpr
selectorExpr
)
// A call represents a step of an injector function. It may be either a
@@ -44,17 +45,21 @@ type call struct {
// out is the type this step produces.
out types.Type
// pkg and name identify the provider to call for kind ==
// funcProviderCall or the type to construct for kind ==
// structProvider.
// pkg and name identify one of the following:
// 1) the provider to call for kind == funcProviderCall;
// 2) the type to construct for kind == structProvider;
// 3) the name to select for kind == selectorExpr.
pkg *types.Package
name string
// args is a list of arguments to call the provider with. Each element is:
// a) one of the givens (args[i] < len(given)), or
// args is a list of arguments to call the provider with. Each element is:
// a) one of the givens (args[i] < len(given)),
// b) the result of a previous provider call (args[i] >= len(given))
//
// This will be nil for kind == valueExpr.
//
// If kind == selectorExpr, then the length of this slice will be 1 and the
// "argument" will be the value to access fields from.
args []int
// varargs is true if the provider function is variadic.
@@ -207,6 +212,29 @@ dfs:
valueExpr: v.expr,
valueTypeInfo: v.info,
})
case pv.IsField():
f := pv.Field()
if index.At(f.Parent) == nil {
// Fields have one dependency which is the parent struct. Make
// sure to visit it first if it is not already visited.
stk = append(stk, curr, frame{t: f.Parent, from: curr.t, up: &curr})
continue
}
index.Set(curr.t, given.Len()+len(calls))
v := index.At(f.Parent)
if v == errAbort {
index.Set(curr.t, errAbort)
continue dfs
}
// Use the args[0] to store the position of the parent struct.
args := []int{v.(int)}
calls = append(calls, call{
kind: selectorExpr,
pkg: f.Pkg,
name: f.Name,
out: curr.t,
args: args,
})
default:
panic("unknown return value from ProviderSet.For")
}
@@ -275,11 +303,24 @@ func verifyArgsUsed(set *ProviderSet, used []*providerSetSrc) []error {
errs = append(errs, fmt.Errorf("unused interface binding to type %s", types.TypeString(b.Iface, nil)))
}
}
for _, f := range set.Fields {
found := false
for _, u := range used {
if u.Field == f {
found = true
break
}
}
if !found {
errs = append(errs, fmt.Errorf("unused field %q.%s", f.Parent, f.Name))
}
}
return errs
}
// buildProviderMap creates the providerMap and srcMap fields for a given provider set.
// The given provider set's providerMap and srcMap fields are ignored.
// buildProviderMap creates the providerMap and srcMap fields for a given
// provider set. The given provider set's providerMap and srcMap fields are
// ignored.
func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *ProviderSet) (*typeutil.Map, *typeutil.Map, []error) {
providerMap := new(typeutil.Map)
providerMap.SetHasher(hasher)
@@ -339,6 +380,15 @@ func buildProviderMap(fset *token.FileSet, hasher typeutil.Hasher, set *Provider
providerMap.Set(v.Out, &ProvidedType{t: v.Out, v: v})
srcMap.Set(v.Out, src)
}
for _, f := range set.Fields {
src := &providerSetSrc{Field: f}
if prevSrc := srcMap.At(f.Out); prevSrc != nil {
ec.add(bindingConflictError(fset, f.Out, set, src, prevSrc.(*providerSetSrc)))
continue
}
providerMap.Set(f.Out, &ProvidedType{t: f.Out, f: f})
srcMap.Set(f.Out, src)
}
if len(ec.errors) > 0 {
return nil, nil, ec.errors
}
@@ -398,38 +448,49 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
continue
}
pt := x.(*ProvidedType)
if pt.IsValue() {
switch {
case pt.IsValue():
// Leaf: values do not have dependencies.
continue
}
if pt.IsArg() {
case pt.IsArg():
// Injector arguments do not have dependencies.
continue
}
if !pt.IsProvider() {
panic("invalid provider map value")
}
for _, arg := range pt.Provider().Args {
a := arg.Type
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
p := providerMap.At(curr[j]).(*ProvidedType).Provider()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
case pt.IsProvider() || pt.IsField():
var args []types.Type
if pt.IsProvider() {
for _, arg := range pt.Provider().Args {
args = append(args, arg.Type)
}
} else {
args = append(args, pt.Field().Parent)
}
for _, a := range args {
hasCycle := false
for i, b := range curr {
if types.Identical(a, b) {
sb := new(strings.Builder)
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
for j := i; j < len(curr); j++ {
t := providerMap.At(curr[j]).(*ProvidedType)
if t.IsProvider() {
p := t.Provider()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
} else {
p := t.Field()
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
}
}
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
}
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
ec.add(errors.New(sb.String()))
hasCycle = true
break
}
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
}
}
if !hasCycle {
next := append(append([]types.Type(nil), curr...), a)
stk = append(stk, next)
}
default:
panic("invalid provider map value")
}
}
}