From 479a501c08a6119e044c12e43ff00ce688acce4c Mon Sep 17 00:00:00 2001 From: Ross Light Date: Fri, 30 Mar 2018 21:34:08 -0700 Subject: [PATCH] goose: add optional provider inputs Reviewed-by: Tuo Shan --- README.md | 18 + internal/goose/analyze.go | 36 +- internal/goose/goose.go | 9 +- internal/goose/parse.go | 395 ++++++++++++------ .../goose/testdata/OptionalMissing/foo/foo.go | 16 + .../testdata/OptionalMissing/foo/foo_goose.go | 7 + .../goose/testdata/OptionalMissing/out.txt | 1 + internal/goose/testdata/OptionalMissing/pkg | 1 + .../goose/testdata/OptionalPresent/foo/foo.go | 16 + .../testdata/OptionalPresent/foo/foo_goose.go | 7 + .../goose/testdata/OptionalPresent/out.txt | 1 + internal/goose/testdata/OptionalPresent/pkg | 1 + 12 files changed, 370 insertions(+), 138 deletions(-) create mode 100644 internal/goose/testdata/OptionalMissing/foo/foo.go create mode 100644 internal/goose/testdata/OptionalMissing/foo/foo_goose.go create mode 100644 internal/goose/testdata/OptionalMissing/out.txt create mode 100644 internal/goose/testdata/OptionalMissing/pkg create mode 100644 internal/goose/testdata/OptionalPresent/foo/foo.go create mode 100644 internal/goose/testdata/OptionalPresent/foo/foo_goose.go create mode 100644 internal/goose/testdata/OptionalPresent/out.txt create mode 100644 internal/goose/testdata/OptionalPresent/pkg diff --git a/README.md b/README.md index 66f7d54..b5432ac 100644 --- a/README.md +++ b/README.md @@ -206,6 +206,24 @@ through the dependency graph, you would create a wrapping type: type MySQLConnectionString string ``` +## Advanced Features + +### Optional Inputs + +A provider input can be marked optional using `goose:optional`: + +```go +//goose:provide Bar +//goose:optional foo + +func provideBar(foo Foo) Bar { + // ... +} +``` + +If used as part of an injector that does not bring in the `Foo` dependency, then +the injector will pass the provider the zero value as the `foo` argument. + ## Future Work - Support for map bindings. diff --git a/internal/goose/analyze.go b/internal/goose/analyze.go index f9d4872..7e3a5a4 100644 --- a/internal/goose/analyze.go +++ b/internal/goose/analyze.go @@ -14,14 +14,16 @@ type call struct { importPath string funcName string - // args is a list of arguments to call the provider with. Each element is either: - // a) one of the givens (args[i] < len(given)) or - // b) the result of a previous provider call (args[i] >= len(given)). + // args is a list of arguments to call the provider with. Each element is: + // a) one of the givens (args[i] < len(given)), + // b) the result of a previous provider call (args[i] >= len(given)), or + // c) the zero value for the type (args[i] == -1). args []int + // ins is the list of types this call receives as arguments. + ins []types.Type // out is the type produced by this provider call. out types.Type - // hasErr is true if the provider call returns an error. hasErr bool } @@ -56,14 +58,14 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov // using a depth-first search. The graph may contain cycles, which // should trigger an error. var calls []call - var visit func(trail []types.Type) error - visit = func(trail []types.Type) error { - typ := trail[len(trail)-1] + var visit func(trail []providerInput) error + visit = func(trail []providerInput) error { + typ := trail[len(trail)-1].typ if index.At(typ) != nil { return nil } - for _, t := range trail[:len(trail)-1] { - if types.Identical(typ, t) { + for _, in := range trail[:len(trail)-1] { + if types.Identical(typ, in.typ) { // TODO(light): describe cycle return fmt.Errorf("cycle for %s", types.TypeString(typ, nil)) } @@ -71,11 +73,14 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov p, _ := providers.At(typ).(*providerInfo) if p == nil { + if trail[len(trail)-1].optional { + return nil + } if len(trail) == 1 { return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) } // TODO(light): give name of provider - return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil)) + return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil)) } for _, a := range p.args { // TODO(light): this will discard grown trail arrays. @@ -84,20 +89,27 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []prov } } args := make([]int, len(p.args)) + ins := make([]types.Type, len(p.args)) for i := range p.args { - args[i] = index.At(p.args[i]).(int) + ins[i] = p.args[i].typ + if x := index.At(p.args[i].typ); x != nil { + args[i] = x.(int) + } else { + args[i] = -1 + } } index.Set(typ, len(given)+len(calls)) calls = append(calls, call{ importPath: p.importPath, funcName: p.funcName, args: args, + ins: ins, out: typ, hasErr: p.hasErr, }) return nil } - if err := visit([]types.Type{out}); err != nil { + if err := visit([]providerInput{{typ: out}}); err != nil { return nil, err } return calls, nil diff --git a/internal/goose/goose.go b/internal/goose/goose.go index 7119a89..554212e 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -174,6 +174,11 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se } for _, c := range calls { g.qualifyImport(c.importPath) + for i := range c.args { + if c.args[i] == -1 { + zeroValue(c.ins[i], g.qualifyPkg) + } + } } outTypeString := types.TypeString(outType, g.qualifyPkg) zv := zeroValue(outType, g.qualifyPkg) @@ -236,7 +241,9 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se if j > 0 { g.p(", ") } - if a < params.Len() { + if a == -1 { + g.p("%s", zeroValue(c.ins[j], g.qualifyPkg)) + } else if a < params.Len() { g.p("%s", paramNames[a]) } else { g.p("%s", localNames[a-params.Len()]) diff --git a/internal/goose/parse.go b/internal/goose/parse.go index 0262b53..a38e3e0 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -30,140 +30,41 @@ type providerInfo struct { importPath string funcName string pos token.Pos - args []types.Type + args []providerInput out types.Type hasErr bool } +type providerInput struct { + typ types.Type + optional bool +} + +type findContext struct { + fset *token.FileSet + pkg *types.Package + typeInfo *types.Info + r *importResolver +} + // findProviderSets processes a package and extracts the provider sets declared in it. -func findProviderSets(fset *token.FileSet, pkg *types.Package, r *importResolver, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) { +func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) { sets := make(map[string]*providerSet) - var directives []directive for _, f := range files { - fileScope := typeInfo.Scopes[f] - for _, c := range f.Comments { - directives = extractDirectives(directives[:0], c) - for _, d := range directives { - switch d.kind { - case "provide", "use": - // handled later - case "import": - if fileScope == nil { - return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name()) - } - i := strings.IndexByte(d.line, ' ') - // TODO(light): allow multiple imports in one line - if i == -1 { - return nil, fmt.Errorf("%s: invalid import: expected TARGET SETREF", fset.Position(d.pos)) - } - name, spec := d.line[:i], d.line[i+1:] - ref, err := parseProviderSetRef(r, spec, fileScope, pkg.Path(), d.pos) - if err != nil { - return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err) - } - if ref.importPath != pkg.Path() { - imported := false - for _, imp := range pkg.Imports() { - if ref.importPath == imp.Path() { - imported = true - break - } - } - if !imported { - return nil, fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath) - } - } - if mod := sets[name]; mod != nil { - found := false - for _, other := range mod.imports { - if ref == other.providerSetRef { - found = true - break - } - } - if !found { - mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos}) - } - } else { - sets[name] = &providerSet{ - imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}}, - } - } - default: - return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind) - } - } + fileScope := fctx.typeInfo.Scopes[f] + if fileScope == nil { + return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fctx.fset.File(f.Pos()).Name()) } - cmap := ast.NewCommentMap(fset, f, f.Comments) - for _, decl := range f.Decls { - directives = directives[:0] - for _, cg := range cmap[decl] { - directives = extractDirectives(directives, cg) - } - fn, isFunction := decl.(*ast.FuncDecl) - var providerSetName string - for _, d := range directives { - if d.kind != "provide" { - continue + for _, dg := range parseFile(fctx.fset, f) { + if dg.decl != nil { + if err := processDeclDirectives(fctx, sets, fileScope, dg); err != nil { + return nil, err } - if providerSetName != "" { - return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name) - } - if !isFunction { - return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos)) - } - providerSetName = fn.Name.Name - if d.line != "" { - // TODO(light): validate identifier - providerSetName = d.line - } - } - if providerSetName == "" { - continue - } - fpos := fn.Pos() - sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature) - r := sig.Results() - var hasErr bool - switch r.Len() { - case 1: - hasErr = false - case 2: - if t := r.At(1).Type(); !types.Identical(t, errorType) { - return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name) - } - hasErr = true - default: - return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name) - } - out := r.At(0).Type() - p := sig.Params() - provider := &providerInfo{ - importPath: pkg.Path(), - funcName: fn.Name.Name, - pos: fn.Pos(), - args: make([]types.Type, p.Len()), - out: out, - hasErr: hasErr, - } - for i := 0; i < p.Len(); i++ { - provider.args[i] = p.At(i).Type() - for j := 0; j < i; j++ { - if types.Identical(provider.args[i], provider.args[j]) { - return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil)) - } - } - } - if mod := sets[providerSetName]; mod != nil { - for _, other := range mod.providers { - if types.Identical(other.out, provider.out) { - return nil, fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerSetName, types.TypeString(provider.out, nil), fset.Position(other.pos)) - } - } - mod.providers = append(mod.providers, provider) } else { - sets[providerSetName] = &providerSet{ - providers: []*providerInfo{provider}, + for _, d := range dg.dirs { + if err := processUnassociatedDirective(fctx, sets, fileScope, d); err != nil { + return nil, err + } } } } @@ -171,6 +72,147 @@ func findProviderSets(fset *token.FileSet, pkg *types.Package, r *importResolver return sets, nil } +// processUnassociatedDirective handles any directive that was not associated with a top-level declaration. +func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error { + switch d.kind { + case "provide", "optional": + return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) + case "use": + // Ignore, picked up by injector flow. + case "import": + i := strings.IndexByte(d.line, ' ') + // TODO(light): allow multiple imports in one line + if i == -1 { + return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos)) + } + name, spec := d.line[:i], d.line[i+1:] + ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) + if err != nil { + return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) + } + if ref.importPath != fctx.pkg.Path() { + imported := false + for _, imp := range fctx.pkg.Imports() { + if ref.importPath == imp.Path() { + imported = true + break + } + } + if !imported { + return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath) + } + } + if mod := sets[name]; mod != nil { + found := false + for _, other := range mod.imports { + if ref == other.providerSetRef { + found = true + break + } + } + if !found { + mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos}) + } + } else { + sets[name] = &providerSet{ + imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}}, + } + } + default: + return fmt.Errorf("%v: unknown directive %s", fctx.fset.Position(d.pos), d.kind) + } + return nil +} + +// processDeclDirectives processes the directives associated with a top-level declaration. +func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error { + p, err := dg.single(fctx.fset, "provide") + if err != nil { + return err + } + if !p.isValid() { + for _, d := range dg.dirs { + if d.kind == "optional" { + return fmt.Errorf("%v: cannot use goose:%s directive on non-provider", fctx.fset.Position(d.pos), d.kind) + } + } + return nil + } + fn, ok := dg.decl.(*ast.FuncDecl) + if !ok { + return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(p.pos)) + } + sig := fctx.typeInfo.ObjectOf(fn.Name).Type().(*types.Signature) + + optionals := make([]bool, sig.Params().Len()) + for _, d := range dg.dirs { + if d.kind == "optional" { + // Marking the given argument names as optional inputs. + for _, arg := range strings.Fields(d.line) { + pi := paramIndex(sig.Params(), arg) + if pi == -1 { + return fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(d.pos), arg, fn.Name.Name) + } + optionals[pi] = true + } + } + } + + fpos := fn.Pos() + r := sig.Results() + var hasErr bool + switch r.Len() { + case 1: + hasErr = false + case 2: + if t := r.At(1).Type(); !types.Identical(t, errorType) { + return fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fctx.fset.Position(fpos), fn.Name.Name) + } + hasErr = true + default: + return fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name.Name) + } + out := r.At(0).Type() + params := sig.Params() + provider := &providerInfo{ + importPath: fctx.pkg.Path(), + funcName: fn.Name.Name, + pos: fn.Pos(), + args: make([]providerInput, params.Len()), + out: out, + hasErr: hasErr, + } + for i := 0; i < params.Len(); i++ { + provider.args[i] = providerInput{ + typ: params.At(i).Type(), + optional: optionals[i], + } + for j := 0; j < i; j++ { + if types.Identical(provider.args[i].typ, provider.args[j].typ) { + return fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil)) + } + } + } + providerSetName := fn.Name.Name + if p.line != "" { + // TODO(light): validate identifier + providerSetName = p.line + } + if mod := sets[providerSetName]; mod != nil { + for _, other := range mod.providers { + if types.Identical(other.out, provider.out) { + return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) + } + } + mod.providers = append(mod.providers, provider) + } else { + sets[providerSetName] = &providerSet{ + providers: []*providerInfo{provider}, + } + } + return nil +} + // providerSetCache is a lazily evaluated index of provider sets. type providerSetCache struct { sets map[string]map[string]*providerSet @@ -199,7 +241,12 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) { mc.sets = make(map[string]map[string]*providerSet) } pkg := mc.prog.Package(ref.importPath) - mods, err := findProviderSets(mc.fset, pkg.Pkg, mc.r, &pkg.Info, pkg.Files) + mods, err := findProviderSets(findContext{ + fset: mc.fset, + pkg: pkg.Pkg, + typeInfo: &pkg.Info, + r: mc.r, + }, pkg.Files) if err != nil { mc.sets[ref.importPath] = nil return nil, err @@ -282,12 +329,68 @@ func (r *importResolver) resolve(pos token.Pos, path string) (string, error) { return pkg.ImportPath, nil } +// A directive is a parsed goose comment. type directive struct { pos token.Pos kind string line string } +// A directiveGroup is a set of directives associated with a particular +// declaration. +type directiveGroup struct { + decl ast.Decl + dirs []directive +} + +// parseFile extracts the directives from a file, grouped by declaration. +func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup { + cmap := ast.NewCommentMap(fset, f, f.Comments) + // Reserve first group for directives that don't associate with a + // declaration, like import. + groups := make([]directiveGroup, 1, len(f.Decls)+1) + // Walk declarations and add to groups. + for _, decl := range f.Decls { + grp := directiveGroup{decl: decl} + ast.Inspect(decl, func(node ast.Node) bool { + if g := cmap[node]; len(g) > 0 { + for _, cg := range g { + start := len(grp.dirs) + grp.dirs = extractDirectives(grp.dirs, cg) + + // Move directives that don't associate into the unassociated group. + n := 0 + for i := start; i < len(grp.dirs); i++ { + if k := grp.dirs[i].kind; k == "provide" || k == "optional" || k == "use" { + grp.dirs[start+n] = grp.dirs[i] + n++ + } else { + groups[0].dirs = append(groups[0].dirs, grp.dirs[i]) + } + } + grp.dirs = grp.dirs[:start+n] + } + delete(cmap, node) + } + return true + }) + if len(grp.dirs) > 0 { + groups = append(groups, grp) + } + } + // Place remaining directives into the unassociated group. + unassoc := &groups[0] + for _, g := range cmap { + for _, cg := range g { + unassoc.dirs = extractDirectives(unassoc.dirs, cg) + } + } + if len(unassoc.dirs) == 0 { + return groups[1:] + } + return groups +} + func extractDirectives(d []directive, cg *ast.CommentGroup) []directive { const prefix = "goose:" text := cg.Text() @@ -318,6 +421,37 @@ func extractDirectives(d []directive, cg *ast.CommentGroup) []directive { return d } +// single finds at most one directive that matches the given kind. +func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) { + var found directive + ok := false + for _, d := range dg.dirs { + if d.kind != kind { + continue + } + if ok { + switch decl := dg.decl.(type) { + case *ast.FuncDecl: + return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, decl.Name.Name) + case *ast.GenDecl: + if decl.Tok == token.TYPE && len(decl.Specs) == 1 { + name := decl.Specs[0].(*ast.TypeSpec).Name.Name + return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, name) + } + return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind) + default: + return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind) + } + } + found, ok = d, true + } + return found, nil +} + +func (d directive) isValid() bool { + return d.kind != "" +} + // isInjectFile reports whether a given file is an injection template. func isInjectFile(f *ast.File) bool { // TODO(light): better determination @@ -329,3 +463,14 @@ func isInjectFile(f *ast.File) bool { } return false } + +// paramIndex returns the index of the parameter with the given name, or +// -1 if no such parameter exists. +func paramIndex(params *types.Tuple, name string) int { + for i := 0; i < params.Len(); i++ { + if params.At(i).Name() == name { + return i + } + } + return -1 +} diff --git a/internal/goose/testdata/OptionalMissing/foo/foo.go b/internal/goose/testdata/OptionalMissing/foo/foo.go new file mode 100644 index 0000000..e7e839a --- /dev/null +++ b/internal/goose/testdata/OptionalMissing/foo/foo.go @@ -0,0 +1,16 @@ +package main + +import "fmt" + +func main() { + fmt.Println(injectBar()) +} + +type foo int +type bar int + +//goose:provide +//goose:optional f +func provideBar(f foo) bar { + return bar(f) +} diff --git a/internal/goose/testdata/OptionalMissing/foo/foo_goose.go b/internal/goose/testdata/OptionalMissing/foo/foo_goose.go new file mode 100644 index 0000000..0d53f4d --- /dev/null +++ b/internal/goose/testdata/OptionalMissing/foo/foo_goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use provideBar + +func injectBar() bar diff --git a/internal/goose/testdata/OptionalMissing/out.txt b/internal/goose/testdata/OptionalMissing/out.txt new file mode 100644 index 0000000..573541a --- /dev/null +++ b/internal/goose/testdata/OptionalMissing/out.txt @@ -0,0 +1 @@ +0 diff --git a/internal/goose/testdata/OptionalMissing/pkg b/internal/goose/testdata/OptionalMissing/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/OptionalMissing/pkg @@ -0,0 +1 @@ +foo diff --git a/internal/goose/testdata/OptionalPresent/foo/foo.go b/internal/goose/testdata/OptionalPresent/foo/foo.go new file mode 100644 index 0000000..c0a9fa4 --- /dev/null +++ b/internal/goose/testdata/OptionalPresent/foo/foo.go @@ -0,0 +1,16 @@ +package main + +import "fmt" + +func main() { + fmt.Println(injectBar(42)) +} + +type foo int +type bar int + +//goose:provide +//goose:optional f +func provideBar(f foo) bar { + return bar(f) +} diff --git a/internal/goose/testdata/OptionalPresent/foo/foo_goose.go b/internal/goose/testdata/OptionalPresent/foo/foo_goose.go new file mode 100644 index 0000000..c2665b4 --- /dev/null +++ b/internal/goose/testdata/OptionalPresent/foo/foo_goose.go @@ -0,0 +1,7 @@ +//+build gooseinject + +package main + +//goose:use provideBar + +func injectBar(foo) bar diff --git a/internal/goose/testdata/OptionalPresent/out.txt b/internal/goose/testdata/OptionalPresent/out.txt new file mode 100644 index 0000000..d81cc07 --- /dev/null +++ b/internal/goose/testdata/OptionalPresent/out.txt @@ -0,0 +1 @@ +42 diff --git a/internal/goose/testdata/OptionalPresent/pkg b/internal/goose/testdata/OptionalPresent/pkg new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/internal/goose/testdata/OptionalPresent/pkg @@ -0,0 +1 @@ +foo