goose: refactor *gen.inject
The function had grown too long. Several related cleanups: - Factored out the function return value logic, which had been duplicated between providers and injectors. - Moved code generation for different provider call types into separate functions. This moves injector-specific state to a new type injectorGen to keep the parameter count down. - Since it's infeasible to keep the "shadow pass" collecting import identifiers in sync the spread out logic, the injector code generation is just run twice, with initial output discarded. - Removed the zero value logic left over from Optional. Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
@@ -48,9 +48,8 @@ type call struct {
|
|||||||
name string
|
name string
|
||||||
|
|
||||||
// args is a list of arguments to call the provider with. Each element is:
|
// args is a list of arguments to call the provider with. Each element is:
|
||||||
// a) one of the givens (args[i] < len(given)),
|
// a) one of the givens (args[i] < len(given)), or
|
||||||
// b) the result of a previous provider call (args[i] >= len(given)), or
|
// b) the result of a previous provider call (args[i] >= len(given))
|
||||||
// c) the zero value for the type (args[i] == -1).
|
|
||||||
//
|
//
|
||||||
// This will be nil for kind == valueExpr.
|
// This will be nil for kind == valueExpr.
|
||||||
args []int
|
args []int
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*as
|
|||||||
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
|
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
|
||||||
}
|
}
|
||||||
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||||
if err := g.inject(g.prog.Fset, fn.Name.Name, sig, set); err != nil {
|
if err := g.inject(fn.Name.Name, sig, set); err != nil {
|
||||||
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
|
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -140,18 +140,18 @@ func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) {
|
|||||||
first = false
|
first = false
|
||||||
}
|
}
|
||||||
// TODO(light): Add line number at top of each declaration.
|
// TODO(light): Add line number at top of each declaration.
|
||||||
g.writeAST(g.prog.Fset, info, decl)
|
g.writeAST(info, decl)
|
||||||
g.p("\n\n")
|
g.p("\n\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// gen is the generator state.
|
// gen is the file-wide generator state.
|
||||||
type gen struct {
|
type gen struct {
|
||||||
currPackage string
|
currPackage string
|
||||||
buf bytes.Buffer
|
buf bytes.Buffer
|
||||||
imports map[string]string
|
imports map[string]string
|
||||||
prog *loader.Program // for determining package names
|
prog *loader.Program // for positions and determining package names
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGen(prog *loader.Program, pkg string) *gen {
|
func newGen(prog *loader.Program, pkg string) *gen {
|
||||||
@@ -190,237 +190,54 @@ func (g *gen) frame() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// inject emits the code for an injector.
|
// inject emits the code for an injector.
|
||||||
func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set *ProviderSet) error {
|
func (g *gen) inject(name string, sig *types.Signature, set *ProviderSet) error {
|
||||||
results := sig.Results()
|
injectSig, err := funcOutput(sig)
|
||||||
var returnsCleanup, returnsErr bool
|
if err != nil {
|
||||||
switch results.Len() {
|
return fmt.Errorf("inject %s: %v", name, err)
|
||||||
case 0:
|
|
||||||
return fmt.Errorf("inject %s: no return values", name)
|
|
||||||
case 1:
|
|
||||||
returnsCleanup, returnsErr = false, false
|
|
||||||
case 2:
|
|
||||||
switch t := results.At(1).Type(); {
|
|
||||||
case types.Identical(t, errorType):
|
|
||||||
returnsCleanup, returnsErr = false, true
|
|
||||||
case types.Identical(t, cleanupType):
|
|
||||||
returnsCleanup, returnsErr = true, false
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("inject %s: second return type is %s; must be error or func()", name, types.TypeString(t, nil))
|
|
||||||
}
|
}
|
||||||
case 3:
|
|
||||||
if t := results.At(1).Type(); !types.Identical(t, cleanupType) {
|
|
||||||
return fmt.Errorf("inject %s: second return type is %s; must be func()", name, types.TypeString(t, nil))
|
|
||||||
}
|
|
||||||
if t := results.At(2).Type(); !types.Identical(t, errorType) {
|
|
||||||
return fmt.Errorf("inject %s: third return type is %s; must be error", name, types.TypeString(t, nil))
|
|
||||||
}
|
|
||||||
returnsCleanup, returnsErr = true, true
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("inject %s: too many return values", name)
|
|
||||||
}
|
|
||||||
outType := results.At(0).Type()
|
|
||||||
params := sig.Params()
|
params := sig.Params()
|
||||||
given := make([]types.Type, params.Len())
|
given := make([]types.Type, params.Len())
|
||||||
for i := 0; i < params.Len(); i++ {
|
for i := 0; i < params.Len(); i++ {
|
||||||
given[i] = params.At(i).Type()
|
given[i] = params.At(i).Type()
|
||||||
}
|
}
|
||||||
calls, err := solve(fset, outType, given, set)
|
calls, err := solve(g.prog.Fset, injectSig.out, given, set)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i := range calls {
|
for i := range calls {
|
||||||
if calls[i].hasCleanup && !returnsCleanup {
|
c := &calls[i]
|
||||||
return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(calls[i].out, nil))
|
if c.hasCleanup && !injectSig.cleanup {
|
||||||
|
return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(c.out, nil))
|
||||||
}
|
}
|
||||||
if calls[i].hasErr && !returnsErr {
|
if c.hasErr && !injectSig.err {
|
||||||
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
|
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(c.out, nil))
|
||||||
}
|
}
|
||||||
}
|
if c.kind == valueExpr {
|
||||||
|
|
||||||
// Prequalify all types. Since import disambiguation ignores local
|
|
||||||
// variables, it takes precedence.
|
|
||||||
paramTypes := make([]string, params.Len())
|
|
||||||
for i := 0; i < params.Len(); i++ {
|
|
||||||
paramTypes[i] = types.TypeString(params.At(i).Type(), g.qualifyPkg)
|
|
||||||
}
|
|
||||||
for _, c := range calls {
|
|
||||||
switch c.kind {
|
|
||||||
case funcProviderCall:
|
|
||||||
g.qualifyImport(c.importPath)
|
|
||||||
for i := range c.args {
|
|
||||||
if c.args[i] == -1 {
|
|
||||||
zeroValue(c.ins[i], g.qualifyPkg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case structProvider:
|
|
||||||
g.qualifyImport(c.importPath)
|
|
||||||
case valueExpr:
|
|
||||||
if err := accessibleFrom(c.valueTypeInfo, c.valueExpr, g.currPackage); err != nil {
|
if err := accessibleFrom(c.valueTypeInfo, c.valueExpr, g.currPackage); err != nil {
|
||||||
// TODO(light): Display line number of value expression.
|
// TODO(light): Display line number of value expression.
|
||||||
ts := types.TypeString(c.out, nil)
|
ts := types.TypeString(c.out, nil)
|
||||||
return fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err)
|
return fmt.Errorf("inject %s: value %s can't be used: %v", name, ts, err)
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
panic("unknown kind")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outTypeString := types.TypeString(outType, g.qualifyPkg)
|
|
||||||
zv := zeroValue(outType, g.qualifyPkg)
|
|
||||||
// Set up local variables.
|
|
||||||
paramNames := make([]string, params.Len())
|
|
||||||
localNames := make([]string, len(calls))
|
|
||||||
cleanupNames := make([]string, len(calls))
|
|
||||||
errVar := disambiguate("err", g.nameInFileScope)
|
|
||||||
collides := func(v string) bool {
|
|
||||||
if v == errVar {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for _, a := range paramNames {
|
|
||||||
if a == v {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, l := range localNames {
|
|
||||||
if l == v {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, l := range cleanupNames {
|
|
||||||
if l == v {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return g.nameInFileScope(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
g.p("func %s(", name)
|
// Perform one pass to collect all imports, followed by the real pass.
|
||||||
for i := 0; i < params.Len(); i++ {
|
injectPass(name, params, injectSig, calls, &injectorGen{
|
||||||
if i > 0 {
|
g: g,
|
||||||
g.p(", ")
|
errVar: disambiguate("err", g.nameInFileScope),
|
||||||
}
|
discard: true,
|
||||||
pi := params.At(i)
|
})
|
||||||
a := pi.Name()
|
injectPass(name, params, injectSig, calls, &injectorGen{
|
||||||
if a == "" || a == "_" {
|
g: g,
|
||||||
a = typeVariableName(pi.Type())
|
errVar: disambiguate("err", g.nameInFileScope),
|
||||||
if a == "" {
|
discard: false,
|
||||||
a = "arg"
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
paramNames[i] = disambiguate(a, collides)
|
|
||||||
g.p("%s %s", paramNames[i], paramTypes[i])
|
|
||||||
}
|
|
||||||
if returnsCleanup && returnsErr {
|
|
||||||
g.p(") (%s, func(), error) {\n", outTypeString)
|
|
||||||
} else if returnsCleanup {
|
|
||||||
g.p(") (%s, func()) {\n", outTypeString)
|
|
||||||
} else if returnsErr {
|
|
||||||
g.p(") (%s, error) {\n", outTypeString)
|
|
||||||
} else {
|
|
||||||
g.p(") %s {\n", outTypeString)
|
|
||||||
}
|
|
||||||
for i := range calls {
|
|
||||||
c := &calls[i]
|
|
||||||
lname := typeVariableName(c.out)
|
|
||||||
if lname == "" {
|
|
||||||
lname = "v"
|
|
||||||
}
|
|
||||||
lname = disambiguate(lname, collides)
|
|
||||||
localNames[i] = lname
|
|
||||||
g.p("\t%s", lname)
|
|
||||||
if c.hasCleanup {
|
|
||||||
cleanupNames[i] = disambiguate("cleanup", collides)
|
|
||||||
g.p(", %s", cleanupNames[i])
|
|
||||||
}
|
|
||||||
if c.hasErr {
|
|
||||||
g.p(", %s", errVar)
|
|
||||||
}
|
|
||||||
g.p(" := ")
|
|
||||||
switch c.kind {
|
|
||||||
case structProvider:
|
|
||||||
if _, ok := c.out.(*types.Pointer); ok {
|
|
||||||
g.p("&")
|
|
||||||
}
|
|
||||||
g.p("%s{\n", g.qualifiedID(c.importPath, c.name))
|
|
||||||
for j, a := range c.args {
|
|
||||||
if a == -1 {
|
|
||||||
// Omit zero value fields from composite literal.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
g.p("\t\t%s: ", c.fieldNames[j])
|
|
||||||
if a < params.Len() {
|
|
||||||
g.p("%s", paramNames[a])
|
|
||||||
} else {
|
|
||||||
g.p("%s", localNames[a-params.Len()])
|
|
||||||
}
|
|
||||||
g.p(",\n")
|
|
||||||
}
|
|
||||||
g.p("\t}\n")
|
|
||||||
case funcProviderCall:
|
|
||||||
g.p("%s(", g.qualifiedID(c.importPath, c.name))
|
|
||||||
for j, a := range c.args {
|
|
||||||
if j > 0 {
|
|
||||||
g.p(", ")
|
|
||||||
}
|
|
||||||
if a == -1 {
|
|
||||||
g.p("%s", zeroValue(c.ins[j], g.qualifyPkg))
|
|
||||||
} else if a < params.Len() {
|
|
||||||
g.p("%s", paramNames[a])
|
|
||||||
} else {
|
|
||||||
g.p("%s", localNames[a-params.Len()])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
g.p(")\n")
|
|
||||||
case valueExpr:
|
|
||||||
g.writeAST(fset, c.valueTypeInfo, c.valueExpr)
|
|
||||||
g.p("\n")
|
|
||||||
default:
|
|
||||||
panic("unknown kind")
|
|
||||||
}
|
|
||||||
if c.hasErr {
|
|
||||||
g.p("\tif %s != nil {\n", errVar)
|
|
||||||
for j := i - 1; j >= 0; j-- {
|
|
||||||
if calls[j].hasCleanup {
|
|
||||||
g.p("\t\t%s()\n", cleanupNames[j])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
g.p("\t\treturn %s", zv)
|
|
||||||
if returnsCleanup {
|
|
||||||
g.p(", nil")
|
|
||||||
}
|
|
||||||
// TODO(light): Give information about failing provider.
|
|
||||||
g.p(", err\n")
|
|
||||||
g.p("\t}\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(calls) == 0 {
|
|
||||||
for i := range given {
|
|
||||||
if types.Identical(outType, given[i]) {
|
|
||||||
g.p("\treturn %s", paramNames[i])
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
g.p("\treturn %s", localNames[len(calls)-1])
|
|
||||||
}
|
|
||||||
if returnsCleanup {
|
|
||||||
g.p(", func() {\n")
|
|
||||||
for i := len(calls) - 1; i >= 0; i-- {
|
|
||||||
if calls[i].hasCleanup {
|
|
||||||
g.p("\t\t%s()\n", cleanupNames[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
g.p("\t}")
|
|
||||||
}
|
|
||||||
if returnsErr {
|
|
||||||
g.p(", nil")
|
|
||||||
}
|
|
||||||
g.p("\n}\n\n")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeAST prints an AST node into the generated output, rewriting any
|
// rewritePkgRefs rewrites any package references in an AST into references for the
|
||||||
// package references it encounters.
|
// generated package.
|
||||||
func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
|
func (g *gen) rewritePkgRefs(info *types.Info, node ast.Node) ast.Node {
|
||||||
start, end := node.Pos(), node.End()
|
start, end := node.Pos(), node.End()
|
||||||
node = copyAST(node)
|
node = copyAST(node)
|
||||||
// First, rewrite all package names. This lets us know all the
|
// First, rewrite all package names. This lets us know all the
|
||||||
@@ -500,7 +317,7 @@ func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rename any symbols defined within writeAST's node that conflict
|
// Rename any symbols defined within rewritePkgRefs's node that conflict
|
||||||
// with any symbols in the generated file.
|
// with any symbols in the generated file.
|
||||||
objName := obj.Name()
|
objName := obj.Name()
|
||||||
if pos := obj.Pos(); pos < start || end <= pos || !(g.nameInFileScope(objName) || inNewNames(objName)) {
|
if pos := obj.Pos(); pos < start || end <= pos || !(g.nameInFileScope(objName) || inNewNames(objName)) {
|
||||||
@@ -530,7 +347,14 @@ func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if err := printer.Fprint(&g.buf, fset, node); err != nil {
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeAST prints an AST node into the generated output, rewriting any
|
||||||
|
// package references it encounters.
|
||||||
|
func (g *gen) writeAST(info *types.Info, node ast.Node) {
|
||||||
|
node = g.rewritePkgRefs(info, node)
|
||||||
|
if err := printer.Fprint(&g.buf, g.prog.Fset, node); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -583,6 +407,196 @@ func (g *gen) p(format string, args ...interface{}) {
|
|||||||
fmt.Fprintf(&g.buf, format, args...)
|
fmt.Fprintf(&g.buf, format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// injectorGen is the per-injector pass generator state.
|
||||||
|
type injectorGen struct {
|
||||||
|
g *gen
|
||||||
|
|
||||||
|
paramNames []string
|
||||||
|
localNames []string
|
||||||
|
cleanupNames []string
|
||||||
|
errVar string
|
||||||
|
|
||||||
|
// discard causes ig.p and ig.writeAST to no-op. Useful to run
|
||||||
|
// generation for side-effects like filling in g.imports.
|
||||||
|
discard bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectPass generates an injector given the output from analysis.
|
||||||
|
func injectPass(name string, params *types.Tuple, injectSig outputSignature, calls []call, ig *injectorGen) {
|
||||||
|
ig.p("func %s(", name)
|
||||||
|
for i := 0; i < params.Len(); i++ {
|
||||||
|
if i > 0 {
|
||||||
|
ig.p(", ")
|
||||||
|
}
|
||||||
|
pi := params.At(i)
|
||||||
|
a := pi.Name()
|
||||||
|
if a == "" || a == "_" {
|
||||||
|
a = typeVariableName(pi.Type())
|
||||||
|
if a == "" {
|
||||||
|
a = "arg"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ig.paramNames = append(ig.paramNames, disambiguate(a, ig.nameInInjector))
|
||||||
|
ig.p("%s %s", ig.paramNames[i], types.TypeString(pi.Type(), ig.g.qualifyPkg))
|
||||||
|
}
|
||||||
|
outTypeString := types.TypeString(injectSig.out, ig.g.qualifyPkg)
|
||||||
|
if injectSig.cleanup && injectSig.err {
|
||||||
|
ig.p(") (%s, func(), error) {\n", outTypeString)
|
||||||
|
} else if injectSig.cleanup {
|
||||||
|
ig.p(") (%s, func()) {\n", outTypeString)
|
||||||
|
} else if injectSig.err {
|
||||||
|
ig.p(") (%s, error) {\n", outTypeString)
|
||||||
|
} else {
|
||||||
|
ig.p(") %s {\n", outTypeString)
|
||||||
|
}
|
||||||
|
for i := range calls {
|
||||||
|
c := &calls[i]
|
||||||
|
lname := typeVariableName(c.out)
|
||||||
|
if lname == "" {
|
||||||
|
lname = "v"
|
||||||
|
}
|
||||||
|
lname = disambiguate(lname, ig.nameInInjector)
|
||||||
|
ig.localNames = append(ig.localNames, lname)
|
||||||
|
switch c.kind {
|
||||||
|
case structProvider:
|
||||||
|
ig.structProviderCall(lname, c)
|
||||||
|
case funcProviderCall:
|
||||||
|
ig.funcProviderCall(lname, c, injectSig)
|
||||||
|
case valueExpr:
|
||||||
|
ig.valueExpr(lname, c)
|
||||||
|
default:
|
||||||
|
panic("unknown kind")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(calls) == 0 {
|
||||||
|
for i := 0; i < params.Len(); i++ {
|
||||||
|
if types.Identical(injectSig.out, params.At(i).Type()) {
|
||||||
|
ig.p("\treturn %s", ig.paramNames[i])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ig.p("\treturn %s", ig.localNames[len(calls)-1])
|
||||||
|
}
|
||||||
|
if injectSig.cleanup {
|
||||||
|
ig.p(", func() {\n")
|
||||||
|
for i := len(ig.cleanupNames) - 1; i >= 0; i-- {
|
||||||
|
ig.p("\t\t%s()\n", ig.cleanupNames[i])
|
||||||
|
}
|
||||||
|
ig.p("\t}")
|
||||||
|
}
|
||||||
|
if injectSig.err {
|
||||||
|
ig.p(", nil")
|
||||||
|
}
|
||||||
|
ig.p("\n}\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputSignature) {
|
||||||
|
ig.p("\t%s", lname)
|
||||||
|
prevCleanup := len(ig.cleanupNames)
|
||||||
|
if c.hasCleanup {
|
||||||
|
cname := disambiguate("cleanup", ig.nameInInjector)
|
||||||
|
ig.cleanupNames = append(ig.cleanupNames, cname)
|
||||||
|
ig.p(", %s", cname)
|
||||||
|
}
|
||||||
|
if c.hasErr {
|
||||||
|
ig.p(", %s", ig.errVar)
|
||||||
|
}
|
||||||
|
ig.p(" := ")
|
||||||
|
ig.p("%s(", ig.g.qualifiedID(c.importPath, c.name))
|
||||||
|
for i, a := range c.args {
|
||||||
|
if i > 0 {
|
||||||
|
ig.p(", ")
|
||||||
|
}
|
||||||
|
if a < len(ig.paramNames) {
|
||||||
|
ig.p("%s", ig.paramNames[a])
|
||||||
|
} else {
|
||||||
|
ig.p("%s", ig.localNames[a-len(ig.paramNames)])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ig.p(")\n")
|
||||||
|
if c.hasErr {
|
||||||
|
ig.p("\tif %s != nil {\n", ig.errVar)
|
||||||
|
for i := prevCleanup - 1; i >= 0; i-- {
|
||||||
|
ig.p("\t\t%s()\n", ig.cleanupNames[i])
|
||||||
|
}
|
||||||
|
ig.p("\t\treturn %s", zeroValue(injectSig.out, ig.g.qualifyPkg))
|
||||||
|
if injectSig.cleanup {
|
||||||
|
ig.p(", nil")
|
||||||
|
}
|
||||||
|
// TODO(light): Give information about failing provider.
|
||||||
|
ig.p(", err\n")
|
||||||
|
ig.p("\t}\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ig *injectorGen) structProviderCall(lname string, c *call) {
|
||||||
|
ig.p("\t%s", lname)
|
||||||
|
ig.p(" := ")
|
||||||
|
if _, ok := c.out.(*types.Pointer); ok {
|
||||||
|
ig.p("&")
|
||||||
|
}
|
||||||
|
ig.p("%s{\n", ig.g.qualifiedID(c.importPath, c.name))
|
||||||
|
for i, a := range c.args {
|
||||||
|
ig.p("\t\t%s: ", c.fieldNames[i])
|
||||||
|
if a < len(ig.paramNames) {
|
||||||
|
ig.p("%s", ig.paramNames[a])
|
||||||
|
} else {
|
||||||
|
ig.p("%s", ig.localNames[a-len(ig.paramNames)])
|
||||||
|
}
|
||||||
|
ig.p(",\n")
|
||||||
|
}
|
||||||
|
ig.p("\t}\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ig *injectorGen) valueExpr(lname string, c *call) {
|
||||||
|
ig.p("\t%s", lname)
|
||||||
|
ig.p(" := ")
|
||||||
|
ig.writeAST(c.valueTypeInfo, c.valueExpr)
|
||||||
|
ig.p("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// nameInInjector reports whether name collides with any other identifier
|
||||||
|
// in the current injector.
|
||||||
|
func (ig *injectorGen) nameInInjector(name string) bool {
|
||||||
|
if name == ig.errVar {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, a := range ig.paramNames {
|
||||||
|
if a == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, l := range ig.localNames {
|
||||||
|
if l == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, l := range ig.cleanupNames {
|
||||||
|
if l == name {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ig.g.nameInFileScope(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ig *injectorGen) p(format string, args ...interface{}) {
|
||||||
|
if ig.discard {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ig.g.p(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ig *injectorGen) writeAST(info *types.Info, node ast.Node) {
|
||||||
|
node = ig.g.rewritePkgRefs(info, node)
|
||||||
|
if ig.discard {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := printer.Fprint(&ig.g.buf, ig.g.prog.Fset, node); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// zeroValue returns the shortest expression that evaluates to the zero
|
// zeroValue returns the shortest expression that evaluates to the zero
|
||||||
// value for the given type.
|
// value for the given type.
|
||||||
func zeroValue(t types.Type, qf types.Qualifier) string {
|
func zeroValue(t types.Type, qf types.Qualifier) string {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
package goose
|
package goose
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"go/build"
|
"go/build"
|
||||||
@@ -384,43 +385,20 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object {
|
|||||||
// processFuncProvider creates a provider for a function declaration.
|
// processFuncProvider creates a provider for a function declaration.
|
||||||
func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) {
|
func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) {
|
||||||
sig := fn.Type().(*types.Signature)
|
sig := fn.Type().(*types.Signature)
|
||||||
|
|
||||||
fpos := fn.Pos()
|
fpos := fn.Pos()
|
||||||
r := sig.Results()
|
providerSig, err := funcOutput(sig)
|
||||||
var hasCleanup, hasErr bool
|
if err != nil {
|
||||||
switch r.Len() {
|
return nil, fmt.Errorf("%v: wrong signature for provider %s: %v", fset.Position(fpos), fn.Name(), err)
|
||||||
case 1:
|
|
||||||
hasCleanup, hasErr = false, false
|
|
||||||
case 2:
|
|
||||||
switch t := r.At(1).Type(); {
|
|
||||||
case types.Identical(t, errorType):
|
|
||||||
hasCleanup, hasErr = false, true
|
|
||||||
case types.Identical(t, cleanupType):
|
|
||||||
hasCleanup, hasErr = true, false
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fset.Position(fpos), fn.Name())
|
|
||||||
}
|
}
|
||||||
case 3:
|
|
||||||
if t := r.At(1).Type(); !types.Identical(t, cleanupType) {
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fset.Position(fpos), fn.Name())
|
|
||||||
}
|
|
||||||
if t := r.At(2).Type(); !types.Identical(t, errorType) {
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fset.Position(fpos), fn.Name())
|
|
||||||
}
|
|
||||||
hasCleanup, hasErr = true, true
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name())
|
|
||||||
}
|
|
||||||
out := r.At(0).Type()
|
|
||||||
params := sig.Params()
|
params := sig.Params()
|
||||||
provider := &Provider{
|
provider := &Provider{
|
||||||
ImportPath: fn.Pkg().Path(),
|
ImportPath: fn.Pkg().Path(),
|
||||||
Name: fn.Name(),
|
Name: fn.Name(),
|
||||||
Pos: fn.Pos(),
|
Pos: fn.Pos(),
|
||||||
Args: make([]ProviderInput, params.Len()),
|
Args: make([]ProviderInput, params.Len()),
|
||||||
Out: out,
|
Out: providerSig.out,
|
||||||
HasCleanup: hasCleanup,
|
HasCleanup: providerSig.cleanup,
|
||||||
HasErr: hasErr,
|
HasErr: providerSig.err,
|
||||||
}
|
}
|
||||||
for i := 0; i < params.Len(); i++ {
|
for i := 0; i < params.Len(); i++ {
|
||||||
provider.Args[i] = ProviderInput{
|
provider.Args[i] = ProviderInput{
|
||||||
@@ -435,6 +413,47 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error)
|
|||||||
return provider, nil
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type outputSignature struct {
|
||||||
|
out types.Type
|
||||||
|
cleanup bool
|
||||||
|
err bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// funcOutput validates an injector or provider function's return signature.
|
||||||
|
func funcOutput(sig *types.Signature) (outputSignature, error) {
|
||||||
|
results := sig.Results()
|
||||||
|
switch results.Len() {
|
||||||
|
case 0:
|
||||||
|
return outputSignature{}, errors.New("no return values")
|
||||||
|
case 1:
|
||||||
|
return outputSignature{out: results.At(0).Type()}, nil
|
||||||
|
case 2:
|
||||||
|
out := results.At(0).Type()
|
||||||
|
switch t := results.At(1).Type(); {
|
||||||
|
case types.Identical(t, errorType):
|
||||||
|
return outputSignature{out: out, err: true}, nil
|
||||||
|
case types.Identical(t, cleanupType):
|
||||||
|
return outputSignature{out: out, cleanup: true}, nil
|
||||||
|
default:
|
||||||
|
return outputSignature{}, fmt.Errorf("second return type is %s; must be error or func()", types.TypeString(t, nil))
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
if t := results.At(1).Type(); !types.Identical(t, cleanupType) {
|
||||||
|
return outputSignature{}, fmt.Errorf("second return type is %s; must be func()", types.TypeString(t, nil))
|
||||||
|
}
|
||||||
|
if t := results.At(2).Type(); !types.Identical(t, errorType) {
|
||||||
|
return outputSignature{}, fmt.Errorf("third return type is %s; must be error", types.TypeString(t, nil))
|
||||||
|
}
|
||||||
|
return outputSignature{
|
||||||
|
out: results.At(0).Type(),
|
||||||
|
cleanup: true,
|
||||||
|
err: true,
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return outputSignature{}, errors.New("too many return values")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// processStructProvider creates a provider for a named struct type.
|
// processStructProvider creates a provider for a named struct type.
|
||||||
// It only produces the non-pointer variant.
|
// It only produces the non-pointer variant.
|
||||||
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, error) {
|
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user