goose: add struct field injection
This makes options structs and application structs much simpler to inject. Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
@@ -27,9 +27,14 @@ type providerSet struct {
|
||||
//
|
||||
// provided is always a type that is assignable to iface.
|
||||
type ifaceBinding struct {
|
||||
iface types.Type
|
||||
// iface is the interface type, which is what can be injected.
|
||||
iface types.Type
|
||||
|
||||
// provided is always a type that is assignable to Iface.
|
||||
provided types.Type
|
||||
pos token.Pos
|
||||
|
||||
// pos is the position where the binding was declared.
|
||||
pos token.Pos
|
||||
}
|
||||
|
||||
type providerSetImport struct {
|
||||
@@ -37,15 +42,39 @@ type providerSetImport struct {
|
||||
pos token.Pos
|
||||
}
|
||||
|
||||
// providerInfo records the signature of a provider function.
|
||||
// providerInfo records the signature of a provider.
|
||||
type providerInfo struct {
|
||||
// importPath is the package path that the Go object resides in.
|
||||
importPath string
|
||||
funcName string
|
||||
pos token.Pos // provider function definition
|
||||
args []providerInput
|
||||
out types.Type
|
||||
|
||||
// name is the name of the Go object.
|
||||
name string
|
||||
|
||||
// pos is the source position of the func keyword or type spec
|
||||
// defining this provider.
|
||||
pos token.Pos
|
||||
|
||||
// args is the list of data dependencies this provider has.
|
||||
args []providerInput
|
||||
|
||||
// isStruct is true if this provider is a named struct type.
|
||||
// Otherwise it's a function.
|
||||
isStruct bool
|
||||
|
||||
// fields lists the field names to populate. This will map 1:1 with
|
||||
// elements in Args.
|
||||
fields []string
|
||||
|
||||
// out is the type this provider produces.
|
||||
out types.Type
|
||||
|
||||
// hasCleanup reports whether the provider function returns a cleanup
|
||||
// function. (Always false for structs.)
|
||||
hasCleanup bool
|
||||
hasErr bool
|
||||
|
||||
// hasErr reports whether the provider function can return an error.
|
||||
// (Always false for structs.)
|
||||
hasErr bool
|
||||
}
|
||||
|
||||
type providerInput struct {
|
||||
@@ -203,25 +232,97 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
||||
}
|
||||
return nil
|
||||
}
|
||||
fn, ok := dg.decl.(*ast.FuncDecl)
|
||||
if !ok {
|
||||
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos))
|
||||
var providerSetName string
|
||||
if args := p.args(); len(args) == 1 {
|
||||
// TODO(light): validate identifier
|
||||
providerSetName = args[0]
|
||||
} else if len(args) > 1 {
|
||||
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(p.pos))
|
||||
}
|
||||
sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||
|
||||
optionals := make([]bool, sig.Params().Len())
|
||||
optionals := make(map[string]token.Pos)
|
||||
for _, d := range dg.dirs {
|
||||
if d.kind == "optional" {
|
||||
// Marking the given argument names as optional inputs.
|
||||
for _, arg := range d.args() {
|
||||
pi := paramIndex(sig.Params(), arg)
|
||||
if pi == -1 {
|
||||
return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name)
|
||||
}
|
||||
optionals[pi] = true
|
||||
optionals[arg] = d.pos
|
||||
}
|
||||
}
|
||||
}
|
||||
switch decl := dg.decl.(type) {
|
||||
case *ast.FuncDecl:
|
||||
fn := fctx.typeInfo.ObjectOf(decl.Name).(*types.Func)
|
||||
provider, err := processFuncProvider(fctx, fn, optionals)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if providerSetName == "" {
|
||||
providerSetName = fn.Name()
|
||||
}
|
||||
if mod := sets[providerSetName]; mod != nil {
|
||||
for _, other := range mod.providers {
|
||||
if types.Identical(other.out, provider.out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
|
||||
}
|
||||
}
|
||||
mod.providers = append(mod.providers, provider)
|
||||
} else {
|
||||
sets[providerSetName] = &providerSet{
|
||||
providers: []*providerInfo{provider},
|
||||
}
|
||||
}
|
||||
case *ast.GenDecl:
|
||||
if decl.Tok != token.TYPE {
|
||||
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
|
||||
}
|
||||
if len(decl.Specs) != 1 {
|
||||
// TODO(light): tighten directive extraction to associate with particular specs.
|
||||
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
|
||||
}
|
||||
typeName := fctx.typeInfo.ObjectOf(decl.Specs[0].(*ast.TypeSpec).Name).(*types.TypeName)
|
||||
if _, ok := typeName.Type().(*types.Named).Underlying().(*types.Struct); !ok {
|
||||
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
|
||||
}
|
||||
provider, err := processStructProvider(fctx, typeName, optionals)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if providerSetName == "" {
|
||||
providerSetName = typeName.Name()
|
||||
}
|
||||
ptrProvider := new(providerInfo)
|
||||
*ptrProvider = *provider
|
||||
ptrProvider.out = types.NewPointer(provider.out)
|
||||
if mod := sets[providerSetName]; mod != nil {
|
||||
for _, other := range mod.providers {
|
||||
if types.Identical(other.out, provider.out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
|
||||
}
|
||||
if types.Identical(other.out, ptrProvider.out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.out, nil), fctx.fset.Position(other.pos))
|
||||
}
|
||||
}
|
||||
mod.providers = append(mod.providers, provider, ptrProvider)
|
||||
} else {
|
||||
sets[providerSetName] = &providerSet{
|
||||
providers: []*providerInfo{provider, ptrProvider},
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*providerInfo, error) {
|
||||
sig := fn.Type().(*types.Signature)
|
||||
|
||||
optionals := make([]bool, sig.Params().Len())
|
||||
for arg, dpos := range optionalArgs {
|
||||
pi := paramIndex(sig.Params(), arg)
|
||||
if pi == -1 {
|
||||
return nil, fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(dpos), arg, fn.Name())
|
||||
}
|
||||
optionals[pi] = true
|
||||
}
|
||||
|
||||
fpos := fn.Pos()
|
||||
r := sig.Results()
|
||||
@@ -236,24 +337,24 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
||||
case types.Identical(t, cleanupType):
|
||||
hasCleanup, hasErr = true, false
|
||||
default:
|
||||
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name.Name)
|
||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name())
|
||||
}
|
||||
case 3:
|
||||
if t := r.At(1).Type(); !types.Identical(t, cleanupType) {
|
||||
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name.Name)
|
||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name())
|
||||
}
|
||||
if t := r.At(2).Type(); !types.Identical(t, errorType) {
|
||||
return fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name.Name)
|
||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name())
|
||||
}
|
||||
hasCleanup, hasErr = true, true
|
||||
default:
|
||||
return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name)
|
||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name())
|
||||
}
|
||||
out := r.At(0).Type()
|
||||
params := sig.Params()
|
||||
provider := &providerInfo{
|
||||
importPath: fctx.pkg.Path(),
|
||||
funcName: fn.Name.Name,
|
||||
name: fn.Name(),
|
||||
pos: fn.Pos(),
|
||||
args: make([]providerInput, params.Len()),
|
||||
out: out,
|
||||
@@ -267,30 +368,54 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
||||
}
|
||||
for j := 0; j < i; j++ {
|
||||
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
|
||||
return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
|
||||
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
|
||||
}
|
||||
}
|
||||
}
|
||||
providerSetName := fn.Name.Name
|
||||
if args := p.args(); len(args) == 1 {
|
||||
// TODO(light): validate identifier
|
||||
providerSetName = args[0]
|
||||
} else if len(args) > 1 {
|
||||
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos))
|
||||
}
|
||||
if mod := sets[providerSetName]; mod != nil {
|
||||
for _, other := range mod.providers {
|
||||
if types.Identical(other.out, provider.out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*providerInfo, error) {
|
||||
out := typeName.Type()
|
||||
st := out.Underlying().(*types.Struct)
|
||||
for arg, dpos := range optionals {
|
||||
found := false
|
||||
for i := 0; i < st.NumFields(); i++ {
|
||||
if st.Field(i).Name() == arg {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
mod.providers = append(mod.providers, provider)
|
||||
} else {
|
||||
sets[providerSetName] = &providerSet{
|
||||
providers: []*providerInfo{provider},
|
||||
if !found {
|
||||
return nil, fmt.Errorf("%v: %s is not a field of struct %s", fctx.fset.Position(dpos), arg, types.TypeString(st, nil))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
pos := typeName.Pos()
|
||||
provider := &providerInfo{
|
||||
importPath: fctx.pkg.Path(),
|
||||
name: typeName.Name(),
|
||||
pos: pos,
|
||||
args: make([]providerInput, st.NumFields()),
|
||||
fields: make([]string, st.NumFields()),
|
||||
isStruct: true,
|
||||
out: out,
|
||||
}
|
||||
for i := 0; i < st.NumFields(); i++ {
|
||||
f := st.Field(i)
|
||||
_, optional := optionals[f.Name()]
|
||||
provider.args[i] = providerInput{
|
||||
typ: f.Type(),
|
||||
optional: optional,
|
||||
}
|
||||
provider.fields[i] = f.Name()
|
||||
for j := 0; j < i; j++ {
|
||||
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
|
||||
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.args[j].typ, nil))
|
||||
}
|
||||
}
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// providerSetCache is a lazily evaluated index of provider sets.
|
||||
|
||||
Reference in New Issue
Block a user