diff --git a/internal/goose/analyze.go b/internal/goose/analyze.go index 4c4bc73..b2d32ed 100644 --- a/internal/goose/analyze.go +++ b/internal/goose/analyze.go @@ -48,9 +48,8 @@ type call struct { name string // args is a list of arguments to call the provider with. Each element is: - // a) one of the givens (args[i] < len(given)), - // b) the result of a previous provider call (args[i] >= len(given)), or - // c) the zero value for the type (args[i] == -1). + // a) one of the givens (args[i] < len(given)), or + // b) the result of a previous provider call (args[i] >= len(given)) // // This will be nil for kind == valueExpr. args []int diff --git a/internal/goose/goose.go b/internal/goose/goose.go index 5cb9ae8..07984f3 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -108,7 +108,7 @@ func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*as return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err) } sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) - if err := g.inject(g.prog.Fset, fn.Name.Name, sig, set); err != nil { + if err := g.inject(fn.Name.Name, sig, set); err != nil { return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err) } } @@ -140,18 +140,18 @@ func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) { first = false } // TODO(light): Add line number at top of each declaration. - g.writeAST(g.prog.Fset, info, decl) + g.writeAST(info, decl) g.p("\n\n") } } } -// gen is the generator state. +// gen is the file-wide generator state. type gen struct { currPackage string buf bytes.Buffer imports map[string]string - prog *loader.Program // for determining package names + prog *loader.Program // for positions and determining package names } func newGen(prog *loader.Program, pkg string) *gen { @@ -190,237 +190,54 @@ func (g *gen) frame() []byte { } // inject emits the code for an injector. -func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set *ProviderSet) error { - results := sig.Results() - var returnsCleanup, returnsErr bool - switch results.Len() { - case 0: - return fmt.Errorf("inject %s: no return values", name) - case 1: - returnsCleanup, returnsErr = false, false - case 2: - switch t := results.At(1).Type(); { - case types.Identical(t, errorType): - returnsCleanup, returnsErr = false, true - case types.Identical(t, cleanupType): - returnsCleanup, returnsErr = true, false - default: - return fmt.Errorf("inject %s: second return type is %s; must be error or func()", name, types.TypeString(t, nil)) - } - case 3: - if t := results.At(1).Type(); !types.Identical(t, cleanupType) { - return fmt.Errorf("inject %s: second return type is %s; must be func()", name, types.TypeString(t, nil)) - } - if t := results.At(2).Type(); !types.Identical(t, errorType) { - return fmt.Errorf("inject %s: third return type is %s; must be error", name, types.TypeString(t, nil)) - } - returnsCleanup, returnsErr = true, true - default: - return fmt.Errorf("inject %s: too many return values", name) +func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error { + injectSig, err := funcOutput(sig) + if err != nil { + return fmt.Errorf("inject %s: %v", name, err) } - outType := results.At(0).Type() params := sig.Params() given := make([]types.Type, params.Len()) for i := 0; i < params.Len(); i++ { given[i] = params.At(i).Type() } - calls, err := solve(fset, outType, given, set) + calls, err := solve(g.prog.Fset, injectSig.out, given, set) if err != nil { return err } for i := range calls { - if calls[i].hasCleanup && !returnsCleanup { - return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(calls[i].out, nil)) + c := &calls[i] + if c.hasCleanup && !injectSig.cleanup { + return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(c.out, nil)) } - if calls[i].hasErr && !returnsErr { - return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil)) + if c.hasErr && !injectSig.err { + return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(c.out, nil)) } - } - - // Prequalify all types. Since import disambiguation ignores local - // variables, it takes precedence. - paramTypes := make([]string, params.Len()) - for i := 0; i < params.Len(); i++ { - paramTypes[i] = types.TypeString(params.At(i).Type(), g.qualifyPkg) - } - for _, c := range calls { - switch c.kind { - case funcProviderCall: - g.qualifyImport(c.importPath) - for i := range c.args { - if c.args[i] == -1 { - zeroValue(c.ins[i], g.qualifyPkg) - } - } - case structProvider: - g.qualifyImport(c.importPath) - case valueExpr: + if c.kind == valueExpr { if err := accessibleFrom(c.valueTypeInfo, c.valueExpr, g.currPackage); err != nil { // TODO(light): Display line number of value expression. ts := types.TypeString(c.out, nil) return fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err) } - default: - panic("unknown kind") } } - outTypeString := types.TypeString(outType, g.qualifyPkg) - zv := zeroValue(outType, g.qualifyPkg) - // Set up local variables. - paramNames := make([]string, params.Len()) - localNames := make([]string, len(calls)) - cleanupNames := make([]string, len(calls)) - errVar := disambiguate("err", g.nameInFileScope) - collides := func(v string) bool { - if v == errVar { - return true - } - for _, a := range paramNames { - if a == v { - return true - } - } - for _, l := range localNames { - if l == v { - return true - } - } - for _, l := range cleanupNames { - if l == v { - return true - } - } - return g.nameInFileScope(v) - } - g.p("func %s(", name) - for i := 0; i < params.Len(); i++ { - if i > 0 { - g.p(", ") - } - pi := params.At(i) - a := pi.Name() - if a == "" || a == "_" { - a = typeVariableName(pi.Type()) - if a == "" { - a = "arg" - } - } - paramNames[i] = disambiguate(a, collides) - g.p("%s %s", paramNames[i], paramTypes[i]) - } - if returnsCleanup && returnsErr { - g.p(") (%s, func(), error) {\n", outTypeString) - } else if returnsCleanup { - g.p(") (%s, func()) {\n", outTypeString) - } else if returnsErr { - g.p(") (%s, error) {\n", outTypeString) - } else { - g.p(") %s {\n", outTypeString) - } - for i := range calls { - c := &calls[i] - lname := typeVariableName(c.out) - if lname == "" { - lname = "v" - } - lname = disambiguate(lname, collides) - localNames[i] = lname - g.p("\t%s", lname) - if c.hasCleanup { - cleanupNames[i] = disambiguate("cleanup", collides) - g.p(", %s", cleanupNames[i]) - } - if c.hasErr { - g.p(", %s", errVar) - } - g.p(" := ") - switch c.kind { - case structProvider: - if _, ok := c.out.(*types.Pointer); ok { - g.p("&") - } - g.p("%s{\n", g.qualifiedID(c.importPath, c.name)) - for j, a := range c.args { - if a == -1 { - // Omit zero value fields from composite literal. - continue - } - g.p("\t\t%s: ", c.fieldNames[j]) - if a < params.Len() { - g.p("%s", paramNames[a]) - } else { - g.p("%s", localNames[a-params.Len()]) - } - g.p(",\n") - } - g.p("\t}\n") - case funcProviderCall: - g.p("%s(", g.qualifiedID(c.importPath, c.name)) - for j, a := range c.args { - if j > 0 { - g.p(", ") - } - if a == -1 { - g.p("%s", zeroValue(c.ins[j], g.qualifyPkg)) - } else if a < params.Len() { - g.p("%s", paramNames[a]) - } else { - g.p("%s", localNames[a-params.Len()]) - } - } - g.p(")\n") - case valueExpr: - g.writeAST(fset, c.valueTypeInfo, c.valueExpr) - g.p("\n") - default: - panic("unknown kind") - } - if c.hasErr { - g.p("\tif %s != nil {\n", errVar) - for j := i - 1; j >= 0; j-- { - if calls[j].hasCleanup { - g.p("\t\t%s()\n", cleanupNames[j]) - } - } - g.p("\t\treturn %s", zv) - if returnsCleanup { - g.p(", nil") - } - // TODO(light): Give information about failing provider. - g.p(", err\n") - g.p("\t}\n") - } - } - if len(calls) == 0 { - for i := range given { - if types.Identical(outType, given[i]) { - g.p("\treturn %s", paramNames[i]) - break - } - } - } else { - g.p("\treturn %s", localNames[len(calls)-1]) - } - if returnsCleanup { - g.p(", func() {\n") - for i := len(calls) - 1; i >= 0; i-- { - if calls[i].hasCleanup { - g.p("\t\t%s()\n", cleanupNames[i]) - } - } - g.p("\t}") - } - if returnsErr { - g.p(", nil") - } - g.p("\n}\n\n") + // Perform one pass to collect all imports, followed by the real pass. + injectPass(name, params, injectSig, calls, &injectorGen{ + g: g, + errVar: disambiguate("err", g.nameInFileScope), + discard: true, + }) + injectPass(name, params, injectSig, calls, &injectorGen{ + g: g, + errVar: disambiguate("err", g.nameInFileScope), + discard: false, + }) return nil } -// writeAST prints an AST node into the generated output, rewriting any -// package references it encounters. -func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) { +// rewritePkgRefs rewrites any package references in an AST into references for the +// generated package. +func (g *gen) rewritePkgRefs(info *types.Info, node ast.Node) ast.Node { start, end := node.Pos(), node.End() node = copyAST(node) // First, rewrite all package names. This lets us know all the @@ -500,7 +317,7 @@ func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) { return true } - // Rename any symbols defined within writeAST's node that conflict + // Rename any symbols defined within rewritePkgRefs's node that conflict // with any symbols in the generated file. objName := obj.Name() if pos := obj.Pos(); pos < start || end <= pos || !(g.nameInFileScope(objName) || inNewNames(objName)) { @@ -530,7 +347,14 @@ func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) { } return true }) - if err := printer.Fprint(&g.buf, fset, node); err != nil { + return node +} + +// writeAST prints an AST node into the generated output, rewriting any +// package references it encounters. +func (g *gen) writeAST(info *types.Info, node ast.Node) { + node = g.rewritePkgRefs(info, node) + if err := printer.Fprint(&g.buf, g.prog.Fset, node); err != nil { panic(err) } } @@ -583,6 +407,196 @@ func (g *gen) p(format string, args ...interface{}) { fmt.Fprintf(&g.buf, format, args...) } +// injectorGen is the per-injector pass generator state. +type injectorGen struct { + g *gen + + paramNames []string + localNames []string + cleanupNames []string + errVar string + + // discard causes ig.p and ig.writeAST to no-op. Useful to run + // generation for side-effects like filling in g.imports. + discard bool +} + +// injectPass generates an injector given the output from analysis. +func injectPass(name string, params *types.Tuple, injectSig outputSignature, calls []call, ig *injectorGen) { + ig.p("func %s(", name) + for i := 0; i < params.Len(); i++ { + if i > 0 { + ig.p(", ") + } + pi := params.At(i) + a := pi.Name() + if a == "" || a == "_" { + a = typeVariableName(pi.Type()) + if a == "" { + a = "arg" + } + } + ig.paramNames = append(ig.paramNames, disambiguate(a, ig.nameInInjector)) + ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg)) + } + outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg) + if injectSig.cleanup && injectSig.err { + ig.p(") (%s, func(), error) {\n", outTypeString) + } else if injectSig.cleanup { + ig.p(") (%s, func()) {\n", outTypeString) + } else if injectSig.err { + ig.p(") (%s, error) {\n", outTypeString) + } else { + ig.p(") %s {\n", outTypeString) + } + for i := range calls { + c := &calls[i] + lname := typeVariableName(c.out) + if lname == "" { + lname = "v" + } + lname = disambiguate(lname, ig.nameInInjector) + ig.localNames = append(ig.localNames, lname) + switch c.kind { + case structProvider: + ig.structProviderCall(lname, c) + case funcProviderCall: + ig.funcProviderCall(lname, c, injectSig) + case valueExpr: + ig.valueExpr(lname, c) + default: + panic("unknown kind") + } + } + if len(calls) == 0 { + for i := 0; i < params.Len(); i++ { + if types.Identical(injectSig.out, params.At(i).Type()) { + ig.p("\treturn %s", ig.paramNames[i]) + break + } + } + } else { + ig.p("\treturn %s", ig.localNames[len(calls)-1]) + } + if injectSig.cleanup { + ig.p(", func() {\n") + for i := len(ig.cleanupNames) - 1; i >= 0; i-- { + ig.p("\t\t%s()\n", ig.cleanupNames[i]) + } + ig.p("\t}") + } + if injectSig.err { + ig.p(", nil") + } + ig.p("\n}\n\n") +} + +func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputSignature) { + ig.p("\t%s", lname) + prevCleanup := len(ig.cleanupNames) + if c.hasCleanup { + cname := disambiguate("cleanup", ig.nameInInjector) + ig.cleanupNames = append(ig.cleanupNames, cname) + ig.p(", %s", cname) + } + if c.hasErr { + ig.p(", %s", ig.errVar) + } + ig.p(" := ") + ig.p("%s(", ig.g.qualifiedID(c.importPath, c.name)) + for i, a := range c.args { + if i > 0 { + ig.p(", ") + } + if a < len(ig.paramNames) { + ig.p("%s", ig.paramNames[a]) + } else { + ig.p("%s", ig.localNames[a-len(ig.paramNames)]) + } + } + ig.p(")\n") + if c.hasErr { + ig.p("\tif %s != nil {\n", ig.errVar) + for i := prevCleanup - 1; i >= 0; i-- { + ig.p("\t\t%s()\n", ig.cleanupNames[i]) + } + ig.p("\t\treturn %s", zeroValue(injectSig.out, ig.g.qualifyPkg)) + if injectSig.cleanup { + ig.p(", nil") + } + // TODO(light): Give information about failing provider. + ig.p(", err\n") + ig.p("\t}\n") + } +} + +func (ig *injectorGen) structProviderCall(lname string, c *call) { + ig.p("\t%s", lname) + ig.p(" := ") + if _, ok := c.out.(*types.Pointer); ok { + ig.p("&") + } + ig.p("%s{\n", ig.g.qualifiedID(c.importPath, c.name)) + for i, a := range c.args { + ig.p("\t\t%s: ", c.fieldNames[i]) + if a < len(ig.paramNames) { + ig.p("%s", ig.paramNames[a]) + } else { + ig.p("%s", ig.localNames[a-len(ig.paramNames)]) + } + ig.p(",\n") + } + ig.p("\t}\n") +} + +func (ig *injectorGen) valueExpr(lname string, c *call) { + ig.p("\t%s", lname) + ig.p(" := ") + ig.writeAST(c.valueTypeInfo, c.valueExpr) + ig.p("\n") +} + +// nameInInjector reports whether name collides with any other identifier +// in the current injector. +func (ig *injectorGen) nameInInjector(name string) bool { + if name == ig.errVar { + return true + } + for _, a := range ig.paramNames { + if a == name { + return true + } + } + for _, l := range ig.localNames { + if l == name { + return true + } + } + for _, l := range ig.cleanupNames { + if l == name { + return true + } + } + return ig.g.nameInFileScope(name) +} + +func (ig *injectorGen) p(format string, args ...interface{}) { + if ig.discard { + return + } + ig.g.p(format, args...) +} + +func (ig *injectorGen) writeAST(info *types.Info, node ast.Node) { + node = ig.g.rewritePkgRefs(info, node) + if ig.discard { + return + } + if err := printer.Fprint(&ig.g.buf, ig.g.prog.Fset, node); err != nil { + panic(err) + } +} + // zeroValue returns the shortest expression that evaluates to the zero // value for the given type. func zeroValue(t types.Type, qf types.Qualifier) string { diff --git a/internal/goose/parse.go b/internal/goose/parse.go index 0b29d20..e829642 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -15,6 +15,7 @@ package goose import ( + "errors" "fmt" "go/ast" "go/build" @@ -384,43 +385,20 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object { // processFuncProvider creates a provider for a function declaration. func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) { sig := fn.Type().(*types.Signature) - fpos := fn.Pos() - r := sig.Results() - var hasCleanup, hasErr bool - switch r.Len() { - case 1: - hasCleanup, hasErr = false, false - case 2: - switch t := r.At(1).Type(); { - case types.Identical(t, errorType): - hasCleanup, hasErr = false, true - case types.Identical(t, cleanupType): - hasCleanup, hasErr = true, false - default: - return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fset.Position(fpos), fn.Name()) - } - case 3: - if t := r.At(1).Type(); !types.Identical(t, cleanupType) { - return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fset.Position(fpos), fn.Name()) - } - if t := r.At(2).Type(); !types.Identical(t, errorType) { - return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fset.Position(fpos), fn.Name()) - } - hasCleanup, hasErr = true, true - default: - return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name()) + providerSig, err := funcOutput(sig) + if err != nil { + return nil, fmt.Errorf("%v: wrong signature for provider %s: %v", fset.Position(fpos), fn.Name(), err) } - out := r.At(0).Type() params := sig.Params() provider := &Provider{ ImportPath: fn.Pkg().Path(), Name: fn.Name(), Pos: fn.Pos(), Args: make([]ProviderInput, params.Len()), - Out: out, - HasCleanup: hasCleanup, - HasErr: hasErr, + Out: providerSig.out, + HasCleanup: providerSig.cleanup, + HasErr: providerSig.err, } for i := 0; i < params.Len(); i++ { provider.Args[i] = ProviderInput{ @@ -435,6 +413,47 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) return provider, nil } +type outputSignature struct { + out types.Type + cleanup bool + err bool +} + +// funcOutput validates an injector or provider function's return signature. +func funcOutput(sig *types.Signature) (outputSignature, error) { + results := sig.Results() + switch results.Len() { + case 0: + return outputSignature{}, errors.New("no return values") + case 1: + return outputSignature{out: results.At(0).Type()}, nil + case 2: + out := results.At(0).Type() + switch t := results.At(1).Type(); { + case types.Identical(t, errorType): + return outputSignature{out: out, err: true}, nil + case types.Identical(t, cleanupType): + return outputSignature{out: out, cleanup: true}, nil + default: + return outputSignature{}, fmt.Errorf("second return type is %s; must be error or func()", types.TypeString(t, nil)) + } + case 3: + if t := results.At(1).Type(); !types.Identical(t, cleanupType) { + return outputSignature{}, fmt.Errorf("second return type is %s; must be func()", types.TypeString(t, nil)) + } + if t := results.At(2).Type(); !types.Identical(t, errorType) { + return outputSignature{}, fmt.Errorf("third return type is %s; must be error", types.TypeString(t, nil)) + } + return outputSignature{ + out: results.At(0).Type(), + cleanup: true, + err: true, + }, nil + default: + return outputSignature{}, errors.New("too many return values") + } +} + // processStructProvider creates a provider for a named struct type. // It only produces the non-pointer variant. func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, error) {