goose: add struct field injection

This makes options structs and application structs much simpler to
inject.

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-03 21:11:53 -07:00
parent ccf63fec5d
commit 2044e2213b
20 changed files with 413 additions and 61 deletions

View File

@@ -261,6 +261,56 @@ func provideBar(foo Foo) Bar {
If used as part of an injector that does not bring in the `Foo` dependency, then If used as part of an injector that does not bring in the `Foo` dependency, then
the injector will pass the provider the zero value as the `foo` argument. the injector will pass the provider the zero value as the `foo` argument.
### Struct Providers
Structs can also be marked as providers. Instead of calling a function, an
injector will fill in each field using the corresponding provider. For a given
struct type `S`, this would provide both `S` and `*S`. For example, given the
following providers:
```go
type Foo int
type Bar int
//goose:provide Foo
func provideFoo() Foo {
// ...
}
//goose:provide Bar
func provideBar() Bar {
// ...
}
//goose:provide
type FooBar struct {
Foo Foo
Bar Bar
}
```
A generated injector for `FooBar` would look like this:
```go
func injectFooBar() FooBar {
foo := provideFoo()
bar := provideBar()
fooBar := FooBar{
Foo: foo,
Bar: bar,
}
return fooBar
}
```
And similarly if the injector needed a `*FooBar`.
Like function providers, you can mark dependencies of a struct provider optional
by using the `goose:optional` directive with the field names.
### Cleanup functions ### Cleanup functions
If a provider creates a value that needs to be cleaned up (e.g. closing a file), If a provider creates a value that needs to be cleaned up (e.g. closing a file),
@@ -292,8 +342,6 @@ of the provider's inputs and must have the signature `func()`.
- Support for map bindings. - Support for map bindings.
- Support for multiple provider outputs. - Support for multiple provider outputs.
- Support for field binding: declare a struct as a provider and have it be
filled in by the corresponding bindings from the graph.
- Tighter validation for a provider set (cycles in unused providers goes - Tighter validation for a provider set (cycles in unused providers goes
unreported currently) unreported currently)
- Visualization for provider sets - Visualization for provider sets

View File

@@ -8,11 +8,13 @@ import (
"golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/go/types/typeutil"
) )
// A call represents a step of an injector function. // A call represents a step of an injector function. It may be either a
// function call or a composite struct literal, depending on the value
// of isStruct.
type call struct { type call struct {
// importPath and funcName identify the provider function to call. // importPath and name identify the provider to call.
importPath string importPath string
funcName string name string
// args is a list of arguments to call the provider with. Each element is: // args is a list of arguments to call the provider with. Each element is:
// a) one of the givens (args[i] < len(given)), // a) one of the givens (args[i] < len(given)),
@@ -20,6 +22,14 @@ type call struct {
// c) the zero value for the type (args[i] == -1). // c) the zero value for the type (args[i] == -1).
args []int args []int
// isStruct indicates whether this should generate a struct composite
// literal instead of a function call.
isStruct bool
// fieldNames maps the arguments to struct field names.
// This will only be set if isStruct is true.
fieldNames []string
// ins is the list of types this call receives as arguments. // ins is the list of types this call receives as arguments.
ins []types.Type ins []types.Type
// out is the type produced by this provider call. // out is the type produced by this provider call.
@@ -51,7 +61,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
for i, g := range given { for i, g := range given {
if p := providers.At(g); p != nil { if p := providers.At(g); p != nil {
pp := p.(*providerInfo) pp := p.(*providerInfo)
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos)) return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.name, mc.fset.Position(pp.pos))
} }
index.Set(g, i) index.Set(g, i)
} }
@@ -111,8 +121,10 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
index.Set(typ, len(given)+len(calls)) index.Set(typ, len(given)+len(calls))
calls = append(calls, call{ calls = append(calls, call{
importPath: p.importPath, importPath: p.importPath,
funcName: p.funcName, name: p.name,
args: args, args: args,
isStruct: p.isStruct,
fieldNames: p.fields,
ins: ins, ins: ins,
out: typ, out: typ,
hasCleanup: p.hasCleanup, hasCleanup: p.hasCleanup,
@@ -189,7 +201,7 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error
if prev := pm.At(b.iface); prev != nil { if prev := pm.At(b.iface); prev != nil {
pos := mc.fset.Position(b.pos) pos := mc.fset.Position(b.pos)
typ := types.TypeString(b.iface, nil) typ := types.TypeString(b.iface, nil)
// TODO(light): error message for conflicting with another interface binding will point at provider function instead of binding. // TODO(light): error message for conflicting with another interface binding will point at provider instead of binding.
prevPos := mc.fset.Position(prev.(*providerInfo).pos) prevPos := mc.fset.Position(prev.(*providerInfo).pos)
if b.from.importPath == "" { if b.from.importPath == "" {
// Provider set is imported directly by injector. // Provider set is imported directly by injector.

View File

@@ -198,6 +198,10 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
} }
for _, c := range calls { for _, c := range calls {
g.qualifyImport(c.importPath) g.qualifyImport(c.importPath)
if !c.isStruct {
// Struct providers just omit zero-valued fields.
continue
}
for i := range c.args { for i := range c.args {
if c.args[i] == -1 { if c.args[i] == -1 {
zeroValue(c.ins[i], g.qualifyPkg) zeroValue(c.ins[i], g.qualifyPkg)
@@ -274,7 +278,28 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
if c.hasErr { if c.hasErr {
g.p(", %s", errVar) g.p(", %s", errVar)
} }
g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName)) g.p(" := ")
if c.isStruct {
if _, ok := c.out.(*types.Pointer); ok {
g.p("&")
}
g.p("%s{\n", g.qualifiedID(c.importPath, c.name))
for j, a := range c.args {
if a == -1 {
// Omit zero value fields from composite literal.
continue
}
g.p("\t\t%s: ", c.fieldNames[j])
if a < params.Len() {
g.p("%s", paramNames[a])
} else {
g.p("%s", localNames[a-params.Len()])
}
g.p(",\n")
}
g.p("\t}\n")
} else {
g.p("%s(", g.qualifiedID(c.importPath, c.name))
for j, a := range c.args { for j, a := range c.args {
if j > 0 { if j > 0 {
g.p(", ") g.p(", ")
@@ -288,6 +313,7 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
} }
} }
g.p(")\n") g.p(")\n")
}
if c.hasErr { if c.hasErr {
g.p("\tif %s != nil {\n", errVar) g.p("\tif %s != nil {\n", errVar)
for j := i - 1; j >= 0; j-- { for j := i - 1; j >= 0; j-- {

View File

@@ -27,8 +27,13 @@ type providerSet struct {
// //
// provided is always a type that is assignable to iface. // provided is always a type that is assignable to iface.
type ifaceBinding struct { type ifaceBinding struct {
// iface is the interface type, which is what can be injected.
iface types.Type iface types.Type
// provided is always a type that is assignable to Iface.
provided types.Type provided types.Type
// pos is the position where the binding was declared.
pos token.Pos pos token.Pos
} }
@@ -37,14 +42,38 @@ type providerSetImport struct {
pos token.Pos pos token.Pos
} }
// providerInfo records the signature of a provider function. // providerInfo records the signature of a provider.
type providerInfo struct { type providerInfo struct {
// importPath is the package path that the Go object resides in.
importPath string importPath string
funcName string
pos token.Pos // provider function definition // name is the name of the Go object.
name string
// pos is the source position of the func keyword or type spec
// defining this provider.
pos token.Pos
// args is the list of data dependencies this provider has.
args []providerInput args []providerInput
// isStruct is true if this provider is a named struct type.
// Otherwise it's a function.
isStruct bool
// fields lists the field names to populate. This will map 1:1 with
// elements in Args.
fields []string
// out is the type this provider produces.
out types.Type out types.Type
// hasCleanup reports whether the provider function returns a cleanup
// function. (Always false for structs.)
hasCleanup bool hasCleanup bool
// hasErr reports whether the provider function can return an error.
// (Always false for structs.)
hasErr bool hasErr bool
} }
@@ -203,25 +232,97 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
} }
return nil return nil
} }
fn, ok := dg.decl.(*ast.FuncDecl) var providerSetName string
if !ok { if args := p.args(); len(args) == 1 {
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos)) // TODO(light): validate identifier
providerSetName = args[0]
} else if len(args) > 1 {
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(p.pos))
} }
sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature) optionals := make(map[string]token.Pos)
optionals := make([]bool, sig.Params().Len())
for _, d := range dg.dirs { for _, d := range dg.dirs {
if d.kind == "optional" { if d.kind == "optional" {
// Marking the given argument names as optional inputs.
for _, arg := range d.args() { for _, arg := range d.args() {
optionals[arg] = d.pos
}
}
}
switch decl := dg.decl.(type) {
case *ast.FuncDecl:
fn := fctx.typeInfo.ObjectOf(decl.Name).(*types.Func)
provider, err := processFuncProvider(fctx, fn, optionals)
if err != nil {
return err
}
if providerSetName == "" {
providerSetName = fn.Name()
}
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
}
}
mod.providers = append(mod.providers, provider)
} else {
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider},
}
}
case *ast.GenDecl:
if decl.Tok != token.TYPE {
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
}
if len(decl.Specs) != 1 {
// TODO(light): tighten directive extraction to associate with particular specs.
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
}
typeName := fctx.typeInfo.ObjectOf(decl.Specs[0].(*ast.TypeSpec).Name).(*types.TypeName)
if _, ok := typeName.Type().(*types.Named).Underlying().(*types.Struct); !ok {
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
}
provider, err := processStructProvider(fctx, typeName, optionals)
if err != nil {
return err
}
if providerSetName == "" {
providerSetName = typeName.Name()
}
ptrProvider := new(providerInfo)
*ptrProvider = *provider
ptrProvider.out = types.NewPointer(provider.out)
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
}
if types.Identical(other.out, ptrProvider.out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.out, nil), fctx.fset.Position(other.pos))
}
}
mod.providers = append(mod.providers, provider, ptrProvider)
} else {
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider, ptrProvider},
}
}
default:
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
}
return nil
}
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*providerInfo, error) {
sig := fn.Type().(*types.Signature)
optionals := make([]bool, sig.Params().Len())
for arg, dpos := range optionalArgs {
pi := paramIndex(sig.Params(), arg) pi := paramIndex(sig.Params(), arg)
if pi == -1 { if pi == -1 {
return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name) return nil, fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(dpos), arg, fn.Name())
} }
optionals[pi] = true optionals[pi] = true
} }
}
}
fpos := fn.Pos() fpos := fn.Pos()
r := sig.Results() r := sig.Results()
@@ -236,24 +337,24 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
case types.Identical(t, cleanupType): case types.Identical(t, cleanupType):
hasCleanup, hasErr = true, false hasCleanup, hasErr = true, false
default: default:
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name.Name) return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name())
} }
case 3: case 3:
if t := r.At(1).Type(); !types.Identical(t, cleanupType) { if t := r.At(1).Type(); !types.Identical(t, cleanupType) {
return fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name.Name) return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name())
} }
if t := r.At(2).Type(); !types.Identical(t, errorType) { if t := r.At(2).Type(); !types.Identical(t, errorType) {
return fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name.Name) return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name())
} }
hasCleanup, hasErr = true, true hasCleanup, hasErr = true, true
default: default:
return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name) return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name())
} }
out := r.At(0).Type() out := r.At(0).Type()
params := sig.Params() params := sig.Params()
provider := &providerInfo{ provider := &providerInfo{
importPath: fctx.pkg.Path(), importPath: fctx.pkg.Path(),
funcName: fn.Name.Name, name: fn.Name(),
pos: fn.Pos(), pos: fn.Pos(),
args: make([]providerInput, params.Len()), args: make([]providerInput, params.Len()),
out: out, out: out,
@@ -267,30 +368,54 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
} }
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if types.Identical(provider.args[i].typ, provider.args[j].typ) { if types.Identical(provider.args[i].typ, provider.args[j].typ) {
return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil)) return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
} }
} }
} }
providerSetName := fn.Name.Name return provider, nil
if args := p.args(); len(args) == 1 { }
// TODO(light): validate identifier
providerSetName = args[0] func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*providerInfo, error) {
} else if len(args) > 1 { out := typeName.Type()
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos)) st := out.Underlying().(*types.Struct)
} for arg, dpos := range optionals {
if mod := sets[providerSetName]; mod != nil { found := false
for _, other := range mod.providers { for i := 0; i < st.NumFields(); i++ {
if types.Identical(other.out, provider.out) { if st.Field(i).Name() == arg {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) found = true
break
} }
} }
mod.providers = append(mod.providers, provider) if !found {
} else { return nil, fmt.Errorf("%v: %s is not a field of struct %s", fctx.fset.Position(dpos), arg, types.TypeString(st, nil))
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider},
} }
} }
return nil
pos := typeName.Pos()
provider := &providerInfo{
importPath: fctx.pkg.Path(),
name: typeName.Name(),
pos: pos,
args: make([]providerInput, st.NumFields()),
fields: make([]string, st.NumFields()),
isStruct: true,
out: out,
}
for i := 0; i < st.NumFields(); i++ {
f := st.Field(i)
_, optional := optionals[f.Name()]
provider.args[i] = providerInput{
typ: f.Type(),
optional: optional,
}
provider.fields[i] = f.Name()
for j := 0; j < i; j++ {
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.args[j].typ, nil))
}
}
}
return provider, nil
} }
// providerSetCache is a lazily evaluated index of provider sets. // providerSetCache is a lazily evaluated index of provider sets.

View File

@@ -0,0 +1,27 @@
package main
import "fmt"
func main() {
fb := injectFooBar()
fmt.Println(fb.Foo, fb.Bar)
}
type Foo int
type Bar int
//goose:provide Set
type FooBar struct {
Foo Foo
Bar Bar
}
//goose:provide Set
func provideFoo() Foo {
return 41
}
//goose:provide Set
func provideBar() Bar {
return 1
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
41 1

1
internal/goose/testdata/Struct/pkg vendored Normal file
View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,23 @@
package main
import "fmt"
func main() {
fb := injectFooBar()
fmt.Println(fb.Foo, fb.OptionalBar)
}
type Foo int
type Bar int
//goose:provide Set
//goose:optional OptionalBar
type FooBar struct {
Foo Foo
OptionalBar Bar
}
//goose:provide Set
func provideFoo() Foo {
return 42
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
42 0

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,28 @@
package main
import "fmt"
func main() {
fb := injectFooBar()
fmt.Println(fb.Foo, fb.Bar)
}
type Foo int
type Bar int
//goose:provide Set
//goose:optional Bar
type FooBar struct {
Foo Foo
Bar Bar
}
//goose:provide Set
func provideFoo() Foo {
return 41
}
//goose:provide Set
func provideBar() Bar {
return 1
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -0,0 +1 @@
41 1

View File

@@ -0,0 +1 @@
foo

View File

@@ -0,0 +1,27 @@
package main
import "fmt"
func main() {
fb := injectFooBar()
fmt.Println(fb.Foo, fb.Bar)
}
type Foo int
type Bar int
//goose:provide Set
type FooBar struct {
Foo Foo
Bar Bar
}
//goose:provide Set
func provideFoo() Foo {
return 41
}
//goose:provide Set
func provideBar() Bar {
return 1
}

View File

@@ -0,0 +1,7 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() *FooBar

View File

@@ -0,0 +1 @@
41 1

View File

@@ -0,0 +1 @@
foo