goose: refactor *gen.inject

The function had grown too long. Several related cleanups:

- Factored out the function return value logic, which had been
  duplicated between providers and injectors.
- Moved code generation for different provider call types into separate
  functions. This moves injector-specific state to a new type
  injectorGen to keep the parameter count down.
- Since it's infeasible to keep the "shadow pass" collecting import
  identifiers in sync the spread out logic, the injector code
  generation is just run twice, with initial output discarded.
- Removed the zero value logic left over from Optional.

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-05-08 16:26:38 -04:00
parent 10676a814b
commit 3c0eaf830e
3 changed files with 278 additions and 246 deletions

View File

@@ -48,9 +48,8 @@ type call struct {
name string
// args is a list of arguments to call the provider with. Each element is:
// a) one of the givens (args[i] < len(given)),
// b) the result of a previous provider call (args[i] >= len(given)), or
// c) the zero value for the type (args[i] == -1).
// a) one of the givens (args[i] < len(given)), or
// b) the result of a previous provider call (args[i] >= len(given))
//
// This will be nil for kind == valueExpr.
args []int

View File

@@ -108,7 +108,7 @@ func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*as
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
}
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
if err := g.inject(g.prog.Fset, fn.Name.Name, sig, set); err != nil {
if err := g.inject(fn.Name.Name, sig, set); err != nil {
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
}
}
@@ -140,18 +140,18 @@ func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) {
first = false
}
// TODO(light): Add line number at top of each declaration.
g.writeAST(g.prog.Fset, info, decl)
g.writeAST(info, decl)
g.p("\n\n")
}
}
}
// gen is the generator state.
// gen is the file-wide generator state.
type gen struct {
currPackage string
buf bytes.Buffer
imports map[string]string
prog *loader.Program // for determining package names
prog *loader.Program // for positions and determining package names
}
func newGen(prog *loader.Program, pkg string) *gen {
@@ -190,237 +190,54 @@ func (g *gen) frame() []byte {
}
// inject emits the code for an injector.
func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set *ProviderSet) error {
results := sig.Results()
var returnsCleanup, returnsErr bool
switch results.Len() {
case 0:
return fmt.Errorf("inject %s: no return values", name)
case 1:
returnsCleanup, returnsErr = false, false
case 2:
switch t := results.At(1).Type(); {
case types.Identical(t, errorType):
returnsCleanup, returnsErr = false, true
case types.Identical(t, cleanupType):
returnsCleanup, returnsErr = true, false
default:
return fmt.Errorf("inject %s: second return type is %s; must be error or func()", name, types.TypeString(t, nil))
func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error {
injectSig, err := funcOutput(sig)
if err != nil {
return fmt.Errorf("inject %s: %v", name, err)
}
case 3:
if t := results.At(1).Type(); !types.Identical(t, cleanupType) {
return fmt.Errorf("inject %s: second return type is %s; must be func()", name, types.TypeString(t, nil))
}
if t := results.At(2).Type(); !types.Identical(t, errorType) {
return fmt.Errorf("inject %s: third return type is %s; must be error", name, types.TypeString(t, nil))
}
returnsCleanup, returnsErr = true, true
default:
return fmt.Errorf("inject %s: too many return values", name)
}
outType := results.At(0).Type()
params := sig.Params()
given := make([]types.Type, params.Len())
for i := 0; i < params.Len(); i++ {
given[i] = params.At(i).Type()
}
calls, err := solve(fset, outType, given, set)
calls, err := solve(g.prog.Fset, injectSig.out, given, set)
if err != nil {
return err
}
for i := range calls {
if calls[i].hasCleanup && !returnsCleanup {
return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(calls[i].out, nil))
c := &calls[i]
if c.hasCleanup && !injectSig.cleanup {
return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(c.out, nil))
}
if calls[i].hasErr && !returnsErr {
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
if c.hasErr && !injectSig.err {
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(c.out, nil))
}
}
// Prequalify all types. Since import disambiguation ignores local
// variables, it takes precedence.
paramTypes := make([]string, params.Len())
for i := 0; i < params.Len(); i++ {
paramTypes[i] = types.TypeString(params.At(i).Type(), g.qualifyPkg)
}
for _, c := range calls {
switch c.kind {
case funcProviderCall:
g.qualifyImport(c.importPath)
for i := range c.args {
if c.args[i] == -1 {
zeroValue(c.ins[i], g.qualifyPkg)
}
}
case structProvider:
g.qualifyImport(c.importPath)
case valueExpr:
if c.kind == valueExpr {
if err := accessibleFrom(c.valueTypeInfo, c.valueExpr, g.currPackage); err != nil {
// TODO(light): Display line number of value expression.
ts := types.TypeString(c.out, nil)
return fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err)
}
default:
panic("unknown kind")
}
}
outTypeString := types.TypeString(outType, g.qualifyPkg)
zv := zeroValue(outType, g.qualifyPkg)
// Set up local variables.
paramNames := make([]string, params.Len())
localNames := make([]string, len(calls))
cleanupNames := make([]string, len(calls))
errVar := disambiguate("err", g.nameInFileScope)
collides := func(v string) bool {
if v == errVar {
return true
}
for _, a := range paramNames {
if a == v {
return true
}
}
for _, l := range localNames {
if l == v {
return true
}
}
for _, l := range cleanupNames {
if l == v {
return true
}
}
return g.nameInFileScope(v)
}
g.p("func %s(", name)
for i := 0; i < params.Len(); i++ {
if i > 0 {
g.p(", ")
}
pi := params.At(i)
a := pi.Name()
if a == "" || a == "_" {
a = typeVariableName(pi.Type())
if a == "" {
a = "arg"
}
}
paramNames[i] = disambiguate(a, collides)
g.p("%s %s", paramNames[i], paramTypes[i])
}
if returnsCleanup && returnsErr {
g.p(") (%s, func(), error) {\n", outTypeString)
} else if returnsCleanup {
g.p(") (%s, func()) {\n", outTypeString)
} else if returnsErr {
g.p(") (%s, error) {\n", outTypeString)
} else {
g.p(") %s {\n", outTypeString)
}
for i := range calls {
c := &calls[i]
lname := typeVariableName(c.out)
if lname == "" {
lname = "v"
}
lname = disambiguate(lname, collides)
localNames[i] = lname
g.p("\t%s", lname)
if c.hasCleanup {
cleanupNames[i] = disambiguate("cleanup", collides)
g.p(", %s", cleanupNames[i])
}
if c.hasErr {
g.p(", %s", errVar)
}
g.p(" := ")
switch c.kind {
case structProvider:
if _, ok := c.out.(*types.Pointer); ok {
g.p("&")
}
g.p("%s{\n", g.qualifiedID(c.importPath, c.name))
for j, a := range c.args {
if a == -1 {
// Omit zero value fields from composite literal.
continue
}
g.p("\t\t%s: ", c.fieldNames[j])
if a < params.Len() {
g.p("%s", paramNames[a])
} else {
g.p("%s", localNames[a-params.Len()])
}
g.p(",\n")
}
g.p("\t}\n")
case funcProviderCall:
g.p("%s(", g.qualifiedID(c.importPath, c.name))
for j, a := range c.args {
if j > 0 {
g.p(", ")
}
if a == -1 {
g.p("%s", zeroValue(c.ins[j], g.qualifyPkg))
} else if a < params.Len() {
g.p("%s", paramNames[a])
} else {
g.p("%s", localNames[a-params.Len()])
}
}
g.p(")\n")
case valueExpr:
g.writeAST(fset, c.valueTypeInfo, c.valueExpr)
g.p("\n")
default:
panic("unknown kind")
}
if c.hasErr {
g.p("\tif %s != nil {\n", errVar)
for j := i - 1; j >= 0; j-- {
if calls[j].hasCleanup {
g.p("\t\t%s()\n", cleanupNames[j])
}
}
g.p("\t\treturn %s", zv)
if returnsCleanup {
g.p(", nil")
}
// TODO(light): Give information about failing provider.
g.p(", err\n")
g.p("\t}\n")
}
}
if len(calls) == 0 {
for i := range given {
if types.Identical(outType, given[i]) {
g.p("\treturn %s", paramNames[i])
break
}
}
} else {
g.p("\treturn %s", localNames[len(calls)-1])
}
if returnsCleanup {
g.p(", func() {\n")
for i := len(calls) - 1; i >= 0; i-- {
if calls[i].hasCleanup {
g.p("\t\t%s()\n", cleanupNames[i])
}
}
g.p("\t}")
}
if returnsErr {
g.p(", nil")
}
g.p("\n}\n\n")
// Perform one pass to collect all imports, followed by the real pass.
injectPass(name, params, injectSig, calls, &injectorGen{
g: g,
errVar: disambiguate("err", g.nameInFileScope),
discard: true,
})
injectPass(name, params, injectSig, calls, &injectorGen{
g: g,
errVar: disambiguate("err", g.nameInFileScope),
discard: false,
})
return nil
}
// writeAST prints an AST node into the generated output, rewriting any
// package references it encounters.
func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
// rewritePkgRefs rewrites any package references in an AST into references for the
// generated package.
func (g *gen) rewritePkgRefs(info *types.Info, node ast.Node) ast.Node {
start, end := node.Pos(), node.End()
node = copyAST(node)
// First, rewrite all package names. This lets us know all the
@@ -500,7 +317,7 @@ func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
return true
}
// Rename any symbols defined within writeAST's node that conflict
// Rename any symbols defined within rewritePkgRefs's node that conflict
// with any symbols in the generated file.
objName := obj.Name()
if pos := obj.Pos(); pos < start || end <= pos || !(g.nameInFileScope(objName) || inNewNames(objName)) {
@@ -530,7 +347,14 @@ func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
}
return true
})
if err := printer.Fprint(&g.buf, fset, node); err != nil {
return node
}
// writeAST prints an AST node into the generated output, rewriting any
// package references it encounters.
func (g *gen) writeAST(info *types.Info, node ast.Node) {
node = g.rewritePkgRefs(info, node)
if err := printer.Fprint(&g.buf, g.prog.Fset, node); err != nil {
panic(err)
}
}
@@ -583,6 +407,196 @@ func (g *gen) p(format string, args ...interface{}) {
fmt.Fprintf(&g.buf, format, args...)
}
// injectorGen is the per-injector pass generator state.
type injectorGen struct {
g *gen
paramNames []string
localNames []string
cleanupNames []string
errVar string
// discard causes ig.p and ig.writeAST to no-op. Useful to run
// generation for side-effects like filling in g.imports.
discard bool
}
// injectPass generates an injector given the output from analysis.
func injectPass(name string, params *types.Tuple, injectSig outputSignature, calls []call, ig *injectorGen) {
ig.p("func %s(", name)
for i := 0; i < params.Len(); i++ {
if i > 0 {
ig.p(", ")
}
pi := params.At(i)
a := pi.Name()
if a == "" || a == "_" {
a = typeVariableName(pi.Type())
if a == "" {
a = "arg"
}
}
ig.paramNames = append(ig.paramNames, disambiguate(a, ig.nameInInjector))
ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg))
}
outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg)
if injectSig.cleanup && injectSig.err {
ig.p(") (%s, func(), error) {\n", outTypeString)
} else if injectSig.cleanup {
ig.p(") (%s, func()) {\n", outTypeString)
} else if injectSig.err {
ig.p(") (%s, error) {\n", outTypeString)
} else {
ig.p(") %s {\n", outTypeString)
}
for i := range calls {
c := &calls[i]
lname := typeVariableName(c.out)
if lname == "" {
lname = "v"
}
lname = disambiguate(lname, ig.nameInInjector)
ig.localNames = append(ig.localNames, lname)
switch c.kind {
case structProvider:
ig.structProviderCall(lname, c)
case funcProviderCall:
ig.funcProviderCall(lname, c, injectSig)
case valueExpr:
ig.valueExpr(lname, c)
default:
panic("unknown kind")
}
}
if len(calls) == 0 {
for i := 0; i < params.Len(); i++ {
if types.Identical(injectSig.out, params.At(i).Type()) {
ig.p("\treturn %s", ig.paramNames[i])
break
}
}
} else {
ig.p("\treturn %s", ig.localNames[len(calls)-1])
}
if injectSig.cleanup {
ig.p(", func() {\n")
for i := len(ig.cleanupNames) - 1; i >= 0; i-- {
ig.p("\t\t%s()\n", ig.cleanupNames[i])
}
ig.p("\t}")
}
if injectSig.err {
ig.p(", nil")
}
ig.p("\n}\n\n")
}
func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputSignature) {
ig.p("\t%s", lname)
prevCleanup := len(ig.cleanupNames)
if c.hasCleanup {
cname := disambiguate("cleanup", ig.nameInInjector)
ig.cleanupNames = append(ig.cleanupNames, cname)
ig.p(", %s", cname)
}
if c.hasErr {
ig.p(", %s", ig.errVar)
}
ig.p(" := ")
ig.p("%s(", ig.g.qualifiedID(c.importPath, c.name))
for i, a := range c.args {
if i > 0 {
ig.p(", ")
}
if a < len(ig.paramNames) {
ig.p("%s", ig.paramNames[a])
} else {
ig.p("%s", ig.localNames[a-len(ig.paramNames)])
}
}
ig.p(")\n")
if c.hasErr {
ig.p("\tif %s != nil {\n", ig.errVar)
for i := prevCleanup - 1; i >= 0; i-- {
ig.p("\t\t%s()\n", ig.cleanupNames[i])
}
ig.p("\t\treturn %s", zeroValue(injectSig.out, ig.g.qualifyPkg))
if injectSig.cleanup {
ig.p(", nil")
}
// TODO(light): Give information about failing provider.
ig.p(", err\n")
ig.p("\t}\n")
}
}
func (ig *injectorGen) structProviderCall(lname string, c *call) {
ig.p("\t%s", lname)
ig.p(" := ")
if _, ok := c.out.(*types.Pointer); ok {
ig.p("&")
}
ig.p("%s{\n", ig.g.qualifiedID(c.importPath, c.name))
for i, a := range c.args {
ig.p("\t\t%s: ", c.fieldNames[i])
if a < len(ig.paramNames) {
ig.p("%s", ig.paramNames[a])
} else {
ig.p("%s", ig.localNames[a-len(ig.paramNames)])
}
ig.p(",\n")
}
ig.p("\t}\n")
}
func (ig *injectorGen) valueExpr(lname string, c *call) {
ig.p("\t%s", lname)
ig.p(" := ")
ig.writeAST(c.valueTypeInfo, c.valueExpr)
ig.p("\n")
}
// nameInInjector reports whether name collides with any other identifier
// in the current injector.
func (ig *injectorGen) nameInInjector(name string) bool {
if name == ig.errVar {
return true
}
for _, a := range ig.paramNames {
if a == name {
return true
}
}
for _, l := range ig.localNames {
if l == name {
return true
}
}
for _, l := range ig.cleanupNames {
if l == name {
return true
}
}
return ig.g.nameInFileScope(name)
}
func (ig *injectorGen) p(format string, args ...interface{}) {
if ig.discard {
return
}
ig.g.p(format, args...)
}
func (ig *injectorGen) writeAST(info *types.Info, node ast.Node) {
node = ig.g.rewritePkgRefs(info, node)
if ig.discard {
return
}
if err := printer.Fprint(&ig.g.buf, ig.g.prog.Fset, node); err != nil {
panic(err)
}
}
// zeroValue returns the shortest expression that evaluates to the zero
// value for the given type.
func zeroValue(t types.Type, qf types.Qualifier) string {

View File

@@ -15,6 +15,7 @@
package goose
import (
"errors"
"fmt"
"go/ast"
"go/build"
@@ -384,43 +385,20 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object {
// processFuncProvider creates a provider for a function declaration.
func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) {
sig := fn.Type().(*types.Signature)
fpos := fn.Pos()
r := sig.Results()
var hasCleanup, hasErr bool
switch r.Len() {
case 1:
hasCleanup, hasErr = false, false
case 2:
switch t := r.At(1).Type(); {
case types.Identical(t, errorType):
hasCleanup, hasErr = false, true
case types.Identical(t, cleanupType):
hasCleanup, hasErr = true, false
default:
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fset.Position(fpos), fn.Name())
providerSig, err := funcOutput(sig)
if err != nil {
return nil, fmt.Errorf("%v: wrong signature for provider %s: %v", fset.Position(fpos), fn.Name(), err)
}
case 3:
if t := r.At(1).Type(); !types.Identical(t, cleanupType) {
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fset.Position(fpos), fn.Name())
}
if t := r.At(2).Type(); !types.Identical(t, errorType) {
return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fset.Position(fpos), fn.Name())
}
hasCleanup, hasErr = true, true
default:
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name())
}
out := r.At(0).Type()
params := sig.Params()
provider := &Provider{
ImportPath: fn.Pkg().Path(),
Name: fn.Name(),
Pos: fn.Pos(),
Args: make([]ProviderInput, params.Len()),
Out: out,
HasCleanup: hasCleanup,
HasErr: hasErr,
Out: providerSig.out,
HasCleanup: providerSig.cleanup,
HasErr: providerSig.err,
}
for i := 0; i < params.Len(); i++ {
provider.Args[i] = ProviderInput{
@@ -435,6 +413,47 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error)
return provider, nil
}
type outputSignature struct {
out types.Type
cleanup bool
err bool
}
// funcOutput validates an injector or provider function's return signature.
func funcOutput(sig *types.Signature) (outputSignature, error) {
results := sig.Results()
switch results.Len() {
case 0:
return outputSignature{}, errors.New("no return values")
case 1:
return outputSignature{out: results.At(0).Type()}, nil
case 2:
out := results.At(0).Type()
switch t := results.At(1).Type(); {
case types.Identical(t, errorType):
return outputSignature{out: out, err: true}, nil
case types.Identical(t, cleanupType):
return outputSignature{out: out, cleanup: true}, nil
default:
return outputSignature{}, fmt.Errorf("second return type is %s; must be error or func()", types.TypeString(t, nil))
}
case 3:
if t := results.At(1).Type(); !types.Identical(t, cleanupType) {
return outputSignature{}, fmt.Errorf("second return type is %s; must be func()", types.TypeString(t, nil))
}
if t := results.At(2).Type(); !types.Identical(t, errorType) {
return outputSignature{}, fmt.Errorf("third return type is %s; must be error", types.TypeString(t, nil))
}
return outputSignature{
out: results.At(0).Type(),
cleanup: true,
err: true,
}, nil
default:
return outputSignature{}, errors.New("too many return values")
}
}
// processStructProvider creates a provider for a named struct type.
// It only produces the non-pointer variant.
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, error) {