diff --git a/README.md b/README.md index 5a2df00..072a63c 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,56 @@ func provideBar(foo Foo) Bar { If used as part of an injector that does not bring in the `Foo` dependency, then the injector will pass the provider the zero value as the `foo` argument. +### Struct Providers + +Structs can also be marked as providers. Instead of calling a function, an +injector will fill in each field using the corresponding provider. For a given +struct type `S`, this would provide both `S` and `*S`. For example, given the +following providers: + +```go +type Foo int +type Bar int + +//goose:provide Foo + +func provideFoo() Foo { + // ... +} + +//goose:provide Bar + +func provideBar() Bar { + // ... +} + +//goose:provide + +type FooBar struct { + Foo Foo + Bar Bar +} +``` + +A generated injector for `FooBar` would look like this: + +```go +func injectFooBar() FooBar { + foo := provideFoo() + bar := provideBar() + fooBar := FooBar{ + Foo: foo, + Bar: bar, + } + return fooBar +} +``` + +And similarly if the injector needed a `*FooBar`. + +Like function providers, you can mark dependencies of a struct provider optional +by using the `goose:optional` directive with the field names. + ### Cleanup functions If a provider creates a value that needs to be cleaned up (e.g. closing a file), @@ -292,8 +342,6 @@ of the provider's inputs and must have the signature `func()`. - Support for map bindings. - Support for multiple provider outputs. -- Support for field binding: declare a struct as a provider and have it be - filled in by the corresponding bindings from the graph. - Tighter validation for a provider set (cycles in unused providers goes unreported currently) - Visualization for provider sets diff --git a/internal/goose/analyze.go b/internal/goose/analyze.go index 35af3bc..3b066c5 100644 --- a/internal/goose/analyze.go +++ b/internal/goose/analyze.go @@ -8,11 +8,13 @@ import ( "golang.org/x/tools/go/types/typeutil" ) -// A call represents a step of an injector function. +// A call represents a step of an injector function. It may be either a +// function call or a composite struct literal, depending on the value +// of isStruct. type call struct { - // importPath and funcName identify the provider function to call. + // importPath and name identify the provider to call. importPath string - funcName string + name string // args is a list of arguments to call the provider with. Each element is: // a) one of the givens (args[i] < len(given)), @@ -20,6 +22,14 @@ type call struct { // c) the zero value for the type (args[i] == -1). args []int + // isStruct indicates whether this should generate a struct composite + // literal instead of a function call. + isStruct bool + + // fieldNames maps the arguments to struct field names. + // This will only be set if isStruct is true. + fieldNames []string + // ins is the list of types this call receives as arguments. ins []types.Type // out is the type produced by this provider call. @@ -51,7 +61,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr for i, g := range given { if p := providers.At(g); p != nil { pp := p.(*providerInfo) - return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos)) + return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.name, mc.fset.Position(pp.pos)) } index.Set(g, i) } @@ -111,8 +121,10 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr index.Set(typ, len(given)+len(calls)) calls = append(calls, call{ importPath: p.importPath, - funcName: p.funcName, + name: p.name, args: args, + isStruct: p.isStruct, + fieldNames: p.fields, ins: ins, out: typ, hasCleanup: p.hasCleanup, @@ -189,7 +201,7 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error if prev := pm.At(b.iface); prev != nil { pos := mc.fset.Position(b.pos) typ := types.TypeString(b.iface, nil) - // TODO(light): error message for conflicting with another interface binding will point at provider function instead of binding. + // TODO(light): error message for conflicting with another interface binding will point at provider instead of binding. prevPos := mc.fset.Position(prev.(*providerInfo).pos) if b.from.importPath == "" { // Provider set is imported directly by injector. diff --git a/internal/goose/goose.go b/internal/goose/goose.go index 4de0733..366ae33 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -198,6 +198,10 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se } for _, c := range calls { g.qualifyImport(c.importPath) + if !c.isStruct { + // Struct providers just omit zero-valued fields. + continue + } for i := range c.args { if c.args[i] == -1 { zeroValue(c.ins[i], g.qualifyPkg) @@ -274,20 +278,42 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se if c.hasErr { g.p(", %s", errVar) } - g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName)) - for j, a := range c.args { - if j > 0 { - g.p(", ") + g.p(" := ") + if c.isStruct { + if _, ok := c.out.(*types.Pointer); ok { + g.p("&") } - if a == -1 { - g.p("%s", zeroValue(c.ins[j], g.qualifyPkg)) - } else if a < params.Len() { - g.p("%s", paramNames[a]) - } else { - g.p("%s", localNames[a-params.Len()]) + g.p("%s{\n", g.qualifiedID(c.importPath, c.name)) + for j, a := range c.args { + if a == -1 { + // Omit zero value fields from composite literal. + continue + } + g.p("\t\t%s: ", c.fieldNames[j]) + if a < params.Len() { + g.p("%s", paramNames[a]) + } else { + g.p("%s", localNames[a-params.Len()]) + } + g.p(",\n") } + g.p("\t}\n") + } else { + g.p("%s(", g.qualifiedID(c.importPath, c.name)) + for j, a := range c.args { + if j > 0 { + g.p(", ") + } + if a == -1 { + g.p("%s", zeroValue(c.ins[j], g.qualifyPkg)) + } else if a < params.Len() { + g.p("%s", paramNames[a]) + } else { + g.p("%s", localNames[a-params.Len()]) + } + } + g.p(")\n") } - g.p(")\n") if c.hasErr { g.p("\tif %s != nil {\n", errVar) for j := i - 1; j >= 0; j-- { diff --git a/internal/goose/parse.go b/internal/goose/parse.go index aa2c159..18f8cf3 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -27,9 +27,14 @@ type providerSet struct { // // provided is always a type that is assignable to iface. type ifaceBinding struct { - iface types.Type + // iface is the interface type, which is what can be injected. + iface types.Type + + // provided is always a type that is assignable to Iface. provided types.Type - pos token.Pos + + // pos is the position where the binding was declared. + pos token.Pos } type providerSetImport struct { @@ -37,15 +42,39 @@ type providerSetImport struct { pos token.Pos } -// providerInfo records the signature of a provider function. +// providerInfo records the signature of a provider. type providerInfo struct { + // importPath is the package path that the Go object resides in. importPath string - funcName string - pos token.Pos // provider function definition - args []providerInput - out types.Type + + // name is the name of the Go object. + name string + + // pos is the source position of the func keyword or type spec + // defining this provider. + pos token.Pos + + // args is the list of data dependencies this provider has. + args []providerInput + + // isStruct is true if this provider is a named struct type. + // Otherwise it's a function. + isStruct bool + + // fields lists the field names to populate. This will map 1:1 with + // elements in Args. + fields []string + + // out is the type this provider produces. + out types.Type + + // hasCleanup reports whether the provider function returns a cleanup + // function. (Always false for structs.) hasCleanup bool - hasErr bool + + // hasErr reports whether the provider function can return an error. + // (Always false for structs.) + hasErr bool } type providerInput struct { @@ -203,25 +232,97 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope } return nil } - fn, ok := dg.decl.(*ast.FuncDecl) - if !ok { - return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos)) + var providerSetName string + if args := p.args(); len(args) == 1 { + // TODO(light): validate identifier + providerSetName = args[0] + } else if len(args) > 1 { + return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(p.pos)) } - sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature) - - optionals := make([]bool, sig.Params().Len()) + optionals := make(map[string]token.Pos) for _, d := range dg.dirs { if d.kind == "optional" { - // Marking the given argument names as optional inputs. for _, arg := range d.args() { - pi := paramIndex(sig.Params(), arg) - if pi == -1 { - return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name) - } - optionals[pi] = true + optionals[arg] = d.pos } } } + switch decl := dg.decl.(type) { + case *ast.FuncDecl: + fn := fctx.typeInfo.ObjectOf(decl.Name).(*types.Func) + provider, err := processFuncProvider(fctx, fn, optionals) + if err != nil { + return err + } + if providerSetName == "" { + providerSetName = fn.Name() + } + if mod := sets[providerSetName]; mod != nil { + for _, other := range mod.providers { + if types.Identical(other.out, provider.out) { + return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) + } + } + mod.providers = append(mod.providers, provider) + } else { + sets[providerSetName] = &providerSet{ + providers: []*providerInfo{provider}, + } + } + case *ast.GenDecl: + if decl.Tok != token.TYPE { + return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) + } + if len(decl.Specs) != 1 { + // TODO(light): tighten directive extraction to associate with particular specs. + return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) + } + typeName := fctx.typeInfo.ObjectOf(decl.Specs[0].(*ast.TypeSpec).Name).(*types.TypeName) + if _, ok := typeName.Type().(*types.Named).Underlying().(*types.Struct); !ok { + return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) + } + provider, err := processStructProvider(fctx, typeName, optionals) + if err != nil { + return err + } + if providerSetName == "" { + providerSetName = typeName.Name() + } + ptrProvider := new(providerInfo) + *ptrProvider = *provider + ptrProvider.out = types.NewPointer(provider.out) + if mod := sets[providerSetName]; mod != nil { + for _, other := range mod.providers { + if types.Identical(other.out, provider.out) { + return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) + } + if types.Identical(other.out, ptrProvider.out) { + return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.out, nil), fctx.fset.Position(other.pos)) + } + } + mod.providers = append(mod.providers, provider, ptrProvider) + } else { + sets[providerSetName] = &providerSet{ + providers: []*providerInfo{provider, ptrProvider}, + } + } + default: + return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) + } + return nil +} + +func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*providerInfo, error) { + sig := fn.Type().(*types.Signature) + + optionals := make([]bool, sig.Params().Len()) + for arg, dpos := range optionalArgs { + pi := paramIndex(sig.Params(), arg) + if pi == -1 { + return nil, fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(dpos), arg, fn.Name()) + } + optionals[pi] = true + } fpos := fn.Pos() r := sig.Results() @@ -236,24 +337,24 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope case types.Identical(t, cleanupType): hasCleanup, hasErr = true, false default: - return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name.Name) + return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name()) } case 3: if t := r.At(1).Type(); !types.Identical(t, cleanupType) { - return fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name.Name) + return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name()) } if t := r.At(2).Type(); !types.Identical(t, errorType) { - return fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name.Name) + return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name()) } hasCleanup, hasErr = true, true default: - return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name) + return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name()) } out := r.At(0).Type() params := sig.Params() provider := &providerInfo{ importPath: fctx.pkg.Path(), - funcName: fn.Name.Name, + name: fn.Name(), pos: fn.Pos(), args: make([]providerInput, params.Len()), out: out, @@ -267,30 +368,54 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope } for j := 0; j < i; j++ { if types.Identical(provider.args[i].typ, provider.args[j].typ) { - return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil)) + return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil)) } } } - providerSetName := fn.Name.Name - if args := p.args(); len(args) == 1 { - // TODO(light): validate identifier - providerSetName = args[0] - } else if len(args) > 1 { - return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(fpos)) - } - if mod := sets[providerSetName]; mod != nil { - for _, other := range mod.providers { - if types.Identical(other.out, provider.out) { - return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) + return provider, nil +} + +func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*providerInfo, error) { + out := typeName.Type() + st := out.Underlying().(*types.Struct) + for arg, dpos := range optionals { + found := false + for i := 0; i < st.NumFields(); i++ { + if st.Field(i).Name() == arg { + found = true + break } } - mod.providers = append(mod.providers, provider) - } else { - sets[providerSetName] = &providerSet{ - providers: []*providerInfo{provider}, + if !found { + return nil, fmt.Errorf("%v: %s is not a field of struct %s", fctx.fset.Position(dpos), arg, types.TypeString(st, nil)) } } - return nil + + pos := typeName.Pos() + provider := &providerInfo{ + importPath: fctx.pkg.Path(), + name: typeName.Name(), + pos: pos, + args: make([]providerInput, st.NumFields()), + fields: make([]string, st.NumFields()), + isStruct: true, + out: out, + } + for i := 0; i < st.NumFields(); i++ { + f := st.Field(i) + _, optional := optionals[f.Name()] + provider.args[i] = providerInput{ + typ: f.Type(), + optional: optional, + } + provider.fields[i] = f.Name() + for j := 0; j < i; j++ { + if types.Identical(provider.args[i].typ, provider.args[j].typ) { + return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.args[j].typ, nil)) + } + } + } + return provider, nil } // providerSetCache is a lazily evaluated index of provider sets. diff --git a/internal/goose/testdata/Struct/foo/foo.go b/internal/goose/testdata/Struct/foo/foo.go new file mode 100644 index 0000000..d53f1c4 --- /dev/null +++ b/internal/goose/testdata/Struct/foo/foo.go @@ -0,0 +1,27 @@ +package main + +import "fmt" + +func main() { + fb := injectFooBar() + fmt.Println(fb.Foo, fb.Bar) +} + +type Foo int +type Bar int + +//goose:provide Set +type FooBar struct { + Foo Foo + Bar Bar +} + +//goose:provide Set +func provideFoo() Foo { + return 41 +} + +//goose:provide Set +func provideBar() Bar { + return 1 +} diff --git a/internal/goose/testdata/Struct/foo/goose.go b/internal/goose/testdata/Struct/foo/goose.go new file mode 100644 index 0000000..73f5093 --- /dev/null +++ b/internal/goose/testdata/Struct/foo/goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use Set + +func injectFooBar() FooBar diff --git a/internal/goose/testdata/Struct/out.txt b/internal/goose/testdata/Struct/out.txt new file mode 100644 index 0000000..b1ae43f --- /dev/null +++ b/internal/goose/testdata/Struct/out.txt @@ -0,0 +1 @@ +41 1 diff --git a/internal/goose/testdata/Struct/pkg b/internal/goose/testdata/Struct/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/Struct/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/StructOptionalField/foo/foo.go b/internal/goose/testdata/StructOptionalField/foo/foo.go new file mode 100644 index 0000000..f37938a --- /dev/null +++ b/internal/goose/testdata/StructOptionalField/foo/foo.go @@ -0,0 +1,23 @@ +package main + +import "fmt" + +func main() { + fb := injectFooBar() + fmt.Println(fb.Foo, fb.OptionalBar) +} + +type Foo int +type Bar int + +//goose:provide Set +//goose:optional OptionalBar +type FooBar struct { + Foo Foo + OptionalBar Bar +} + +//goose:provide Set +func provideFoo() Foo { + return 42 +} diff --git a/internal/goose/testdata/StructOptionalField/foo/goose.go b/internal/goose/testdata/StructOptionalField/foo/goose.go new file mode 100644 index 0000000..73f5093 --- /dev/null +++ b/internal/goose/testdata/StructOptionalField/foo/goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use Set + +func injectFooBar() FooBar diff --git a/internal/goose/testdata/StructOptionalField/out.txt b/internal/goose/testdata/StructOptionalField/out.txt new file mode 100644 index 0000000..426756f --- /dev/null +++ b/internal/goose/testdata/StructOptionalField/out.txt @@ -0,0 +1 @@ +42 0 diff --git a/internal/goose/testdata/StructOptionalField/pkg b/internal/goose/testdata/StructOptionalField/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/StructOptionalField/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/StructOptionalFieldPresent/foo/foo.go b/internal/goose/testdata/StructOptionalFieldPresent/foo/foo.go new file mode 100644 index 0000000..b612fa3 --- /dev/null +++ b/internal/goose/testdata/StructOptionalFieldPresent/foo/foo.go @@ -0,0 +1,28 @@ +package main + +import "fmt" + +func main() { + fb := injectFooBar() + fmt.Println(fb.Foo, fb.Bar) +} + +type Foo int +type Bar int + +//goose:provide Set +//goose:optional Bar +type FooBar struct { + Foo Foo + Bar Bar +} + +//goose:provide Set +func provideFoo() Foo { + return 41 +} + +//goose:provide Set +func provideBar() Bar { + return 1 +} diff --git a/internal/goose/testdata/StructOptionalFieldPresent/foo/goose.go b/internal/goose/testdata/StructOptionalFieldPresent/foo/goose.go new file mode 100644 index 0000000..73f5093 --- /dev/null +++ b/internal/goose/testdata/StructOptionalFieldPresent/foo/goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use Set + +func injectFooBar() FooBar diff --git a/internal/goose/testdata/StructOptionalFieldPresent/out.txt b/internal/goose/testdata/StructOptionalFieldPresent/out.txt new file mode 100644 index 0000000..b1ae43f --- /dev/null +++ b/internal/goose/testdata/StructOptionalFieldPresent/out.txt @@ -0,0 +1 @@ +41 1 diff --git a/internal/goose/testdata/StructOptionalFieldPresent/pkg b/internal/goose/testdata/StructOptionalFieldPresent/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/StructOptionalFieldPresent/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/StructPointer/foo/foo.go b/internal/goose/testdata/StructPointer/foo/foo.go new file mode 100644 index 0000000..d53f1c4 --- /dev/null +++ b/internal/goose/testdata/StructPointer/foo/foo.go @@ -0,0 +1,27 @@ +package main + +import "fmt" + +func main() { + fb := injectFooBar() + fmt.Println(fb.Foo, fb.Bar) +} + +type Foo int +type Bar int + +//goose:provide Set +type FooBar struct { + Foo Foo + Bar Bar +} + +//goose:provide Set +func provideFoo() Foo { + return 41 +} + +//goose:provide Set +func provideBar() Bar { + return 1 +} diff --git a/internal/goose/testdata/StructPointer/foo/goose.go b/internal/goose/testdata/StructPointer/foo/goose.go new file mode 100644 index 0000000..d99308a --- /dev/null +++ b/internal/goose/testdata/StructPointer/foo/goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use Set + +func injectFooBar() *FooBar diff --git a/internal/goose/testdata/StructPointer/out.txt b/internal/goose/testdata/StructPointer/out.txt new file mode 100644 index 0000000..b1ae43f --- /dev/null +++ b/internal/goose/testdata/StructPointer/out.txt @@ -0,0 +1 @@ +41 1 diff --git a/internal/goose/testdata/StructPointer/pkg b/internal/goose/testdata/StructPointer/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/StructPointer/pkg @@ -0,0 +1 @@ +foo