Files
wire/internal/goose/goose.go
Ross Light f8e446fa17 goose: use marker functions instead of comments
To avoid making this CL too large, I did not migrate the existing goose
comments throughout the repository. This will be addressed in a subsequent
CL.

Reviewed-by: Tuo Shan <shantuo@google.com>
2018-11-12 14:09:56 -08:00

481 lines
12 KiB
Go

// Package goose provides compile-time dependency injection logic as a
// Go library.
package goose
import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/token"
"go/types"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/tools/go/loader"
)
// Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code.
//
// pkg is resolved relative to wd using bctx and loaded with the extra
// "gooseinject" build tag. An injector implementation is generated for
// every function declaration in pkg that isInjector recognizes. Errors
// for an individual injector are prefixed with its source position. If
// the generated code fails to format, the unformatted source is returned
// together with the formatting error.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
	found, err := bctx.Import(pkg, wd, build.FindOnly)
	if err != nil {
		return nil, fmt.Errorf("load: %v", err)
	}

	// Work on a copy of the caller's build context. The three-index slice
	// caps BuildTags so that append must allocate instead of writing the
	// extra tag into bctx.BuildTags' backing array.
	buildCtx := new(build.Context)
	*buildCtx = *bctx
	tags := buildCtx.BuildTags
	// TODO(light): Only apply gooseinject build tag on main package.
	buildCtx.BuildTags = append(tags[:len(tags):len(tags)], "gooseinject")

	// TODO(light): Stop errors from printing to stderr.
	conf := &loader.Config{
		Build: buildCtx,
		Cwd:   wd,
		// Only the target package has its function bodies type-checked.
		TypeCheckFuncBodies: func(path string) bool {
			return path == found.ImportPath
		},
	}
	conf.Import(pkg)
	prog, err := conf.Load()
	if err != nil {
		return nil, fmt.Errorf("load: %v", err)
	}
	initial := prog.InitialPackages()
	if len(initial) != 1 {
		// This is more of a violated precondition than anything else.
		return nil, fmt.Errorf("load: got %d packages", len(initial))
	}
	target := initial[0]

	g := newGen(prog, target.Pkg.Path())
	oc := newObjectCache(prog)
	for _, file := range target.Files {
		for _, decl := range file.Decls {
			fn, isFunc := decl.(*ast.FuncDecl)
			if !isFunc {
				continue
			}
			useCall := isInjector(&target.Info, fn)
			if useCall == nil {
				continue
			}
			pos := prog.Fset.Position(fn.Pos())
			set, err := oc.processNewSet(target, useCall)
			if err != nil {
				return nil, fmt.Errorf("%v: %v", pos, err)
			}
			sig := target.ObjectOf(fn.Name).Type().(*types.Signature)
			if err := g.inject(prog.Fset, fn.Name.Name, sig, set); err != nil {
				return nil, fmt.Errorf("%v: %v", pos, err)
			}
		}
	}

	src := g.frame()
	formatted, err := format.Source(src)
	if err != nil {
		// This is likely a bug from a poorly generated source file.
		// Return an error and the unformatted source.
		return src, err
	}
	return formatted, nil
}
// gen is the generator state.
type gen struct {
	// currPackage is the import path of the package being generated.
	currPackage string
	// buf accumulates the generated declarations; frame emits them after
	// the package clause and import block.
	buf bytes.Buffer
	// imports maps each import path (with any vendor prefix stripped) to
	// the identifier it is imported as in the generated file.
	imports map[string]string
	prog    *loader.Program // for determining package names
}
// newGen returns an empty generator for the package with import path pkg.
// prog is consulted to look up package names.
func newGen(prog *loader.Program, pkg string) *gen {
	g := new(gen)
	g.currPackage = pkg
	g.imports = map[string]string{}
	g.prog = prog
	return g
}
// frame bakes the built up source body into an unformatted Go source file.
//
// The file starts with the generated-code header and a build constraint
// that excludes it from gooseinject builds. Then come the package clause
// and an import block listing every recorded import in sorted path order,
// followed by the accumulated body. frame returns nil if no code has been
// generated.
func (g *gen) frame() []byte {
	if g.buf.Len() == 0 {
		return nil
	}
	paths := make([]string, 0, len(g.imports))
	for p := range g.imports {
		paths = append(paths, p)
	}
	// Sort so the output does not depend on map iteration order.
	sort.Strings(paths)

	var out bytes.Buffer
	out.WriteString("// Code generated by goose. DO NOT EDIT.\n\n//+build !gooseinject\n\npackage ")
	out.WriteString(g.prog.Package(g.currPackage).Pkg.Name())
	out.WriteString("\n\n")
	if len(paths) != 0 {
		out.WriteString("import (\n")
		for _, p := range paths {
			// TODO(light): Omit the local package identifier if it matches
			// the package name.
			fmt.Fprintf(&out, "\t%s %q\n", g.imports[p], p)
		}
		out.WriteString(")\n\n")
	}
	out.Write(g.buf.Bytes())
	return out.Bytes()
}
// inject emits the code for an injector.
//
// name is the injector's function name and sig is its declared signature.
// The first result of sig is the type to build. It may be followed by a
// cleanup func() and/or an error, in that order. solve chooses the
// provider calls from set that turn sig's parameters into that type. The
// generated function is appended to g's body.
//
// An error is returned if sig's results do not have one of the accepted
// shapes, if solve fails, or if a chosen provider returns a cleanup
// function or an error that the injector's signature cannot pass on.
func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set *ProviderSet) error {
	results := sig.Results()
	var returnsCleanup, returnsErr bool
	switch results.Len() {
	case 0:
		return fmt.Errorf("inject %s: no return values", name)
	case 1:
		returnsCleanup, returnsErr = false, false
	case 2:
		switch t := results.At(1).Type(); {
		case types.Identical(t, errorType):
			returnsCleanup, returnsErr = false, true
		case types.Identical(t, cleanupType):
			returnsCleanup, returnsErr = true, false
		default:
			return fmt.Errorf("inject %s: second return type is %s; must be error or func()", name, types.TypeString(t, nil))
		}
	case 3:
		if t := results.At(1).Type(); !types.Identical(t, cleanupType) {
			return fmt.Errorf("inject %s: second return type is %s; must be func()", name, types.TypeString(t, nil))
		}
		if t := results.At(2).Type(); !types.Identical(t, errorType) {
			return fmt.Errorf("inject %s: third return type is %s; must be error", name, types.TypeString(t, nil))
		}
		returnsCleanup, returnsErr = true, true
	default:
		return fmt.Errorf("inject %s: too many return values", name)
	}
	outType := results.At(0).Type()
	params := sig.Params()
	given := make([]types.Type, params.Len())
	for i := 0; i < params.Len(); i++ {
		given[i] = params.At(i).Type()
	}
	calls, err := solve(fset, outType, given, set)
	if err != nil {
		return err
	}
	// Reject providers whose cleanup or failure the injector cannot pass on.
	for i := range calls {
		if calls[i].hasCleanup && !returnsCleanup {
			return fmt.Errorf("inject %s: provider for %s returns cleanup but injection does not return cleanup function", name, types.TypeString(calls[i].out, nil))
		}
		if calls[i].hasErr && !returnsErr {
			return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
		}
	}

	// Prequalify every type and package the generated code will mention,
	// before any local names are chosen. qualifyImport picks import names
	// by consulting only the file scope, not local variables. Local names
	// below are checked against the imports recorded so far. So every
	// import must be recorded first, or a local could shadow an import.
	paramTypes := make([]string, params.Len())
	for i := 0; i < params.Len(); i++ {
		paramTypes[i] = types.TypeString(params.At(i).Type(), g.qualifyPkg)
	}
	for _, c := range calls {
		g.qualifyImport(c.importPath)
		if c.isStruct {
			// Struct providers omit zero-valued fields from the composite
			// literal, so their zero values are never emitted. Qualifying
			// them would only add imports the generated file never uses.
			continue
		}
		for i := range c.args {
			if c.args[i] == -1 {
				// Function providers pass an explicit zero value for this
				// argument, so its type must be qualified now.
				zeroValue(c.ins[i], g.qualifyPkg)
			}
		}
	}
	outTypeString := types.TypeString(outType, g.qualifyPkg)
	zv := zeroValue(outType, g.qualifyPkg)

	// Set up local variables.
	paramNames := make([]string, params.Len())
	localNames := make([]string, len(calls))
	cleanupNames := make([]string, len(calls))
	errVar := disambiguate("err", g.nameInFileScope)
	collides := func(v string) bool {
		if v == errVar {
			return true
		}
		for _, a := range paramNames {
			if a == v {
				return true
			}
		}
		for _, l := range localNames {
			if l == v {
				return true
			}
		}
		for _, l := range cleanupNames {
			if l == v {
				return true
			}
		}
		return g.nameInFileScope(v)
	}

	// Emit the signature. Unnamed and blank parameters get invented names
	// because the body may need to refer to them.
	g.p("func %s(", name)
	for i := 0; i < params.Len(); i++ {
		if i > 0 {
			g.p(", ")
		}
		pi := params.At(i)
		a := pi.Name()
		if a == "" || a == "_" {
			a = typeVariableName(pi.Type())
			if a == "" {
				a = "arg"
			}
		}
		paramNames[i] = disambiguate(a, collides)
		g.p("%s %s", paramNames[i], paramTypes[i])
	}
	switch {
	case returnsCleanup && returnsErr:
		g.p(") (%s, func(), error) {\n", outTypeString)
	case returnsCleanup:
		g.p(") (%s, func()) {\n", outTypeString)
	case returnsErr:
		g.p(") (%s, error) {\n", outTypeString)
	default:
		g.p(") %s {\n", outTypeString)
	}

	// Emit one assignment per provider call, in the order solve returned.
	// An argument index a refers to injector parameter a when
	// a < params.Len(), otherwise to the result of call a-params.Len().
	// An index of -1 means the argument is a zero value.
	for i := range calls {
		c := &calls[i]
		lname := typeVariableName(c.out)
		if lname == "" {
			lname = "v"
		}
		lname = disambiguate(lname, collides)
		localNames[i] = lname
		g.p("\t%s", lname)
		if c.hasCleanup {
			cleanupNames[i] = disambiguate("cleanup", collides)
			g.p(", %s", cleanupNames[i])
		}
		if c.hasErr {
			g.p(", %s", errVar)
		}
		g.p(" := ")
		if c.isStruct {
			if _, ok := c.out.(*types.Pointer); ok {
				g.p("&")
			}
			g.p("%s{\n", g.qualifiedID(c.importPath, c.name))
			for j, a := range c.args {
				if a == -1 {
					// Omit zero value fields from composite literal.
					continue
				}
				g.p("\t\t%s: ", c.fieldNames[j])
				if a < params.Len() {
					g.p("%s", paramNames[a])
				} else {
					g.p("%s", localNames[a-params.Len()])
				}
				g.p(",\n")
			}
			g.p("\t}\n")
		} else {
			g.p("%s(", g.qualifiedID(c.importPath, c.name))
			for j, a := range c.args {
				if j > 0 {
					g.p(", ")
				}
				if a == -1 {
					g.p("%s", zeroValue(c.ins[j], g.qualifyPkg))
				} else if a < params.Len() {
					g.p("%s", paramNames[a])
				} else {
					g.p("%s", localNames[a-params.Len()])
				}
			}
			g.p(")\n")
		}
		if c.hasErr {
			g.p("\tif %s != nil {\n", errVar)
			// Run the cleanups of earlier providers, most recent first.
			for j := i - 1; j >= 0; j-- {
				if calls[j].hasCleanup {
					g.p("\t\t%s()\n", cleanupNames[j])
				}
			}
			g.p("\t\treturn %s", zv)
			if returnsCleanup {
				g.p(", nil")
			}
			// TODO(light): Give information about failing provider.
			// Use errVar, not a literal "err": errVar was renamed if "err"
			// is already bound in the file scope.
			g.p(", %s\n", errVar)
			g.p("\t}\n")
		}
	}

	// Emit the final return. With no provider calls, the output is one of
	// the injector's own parameters.
	if len(calls) == 0 {
		for i := range given {
			if types.Identical(outType, given[i]) {
				g.p("\treturn %s", paramNames[i])
				break
			}
		}
	} else {
		g.p("\treturn %s", localNames[len(calls)-1])
	}
	if returnsCleanup {
		// The combined cleanup runs each provider cleanup in reverse order.
		g.p(", func() {\n")
		for i := len(calls) - 1; i >= 0; i-- {
			if calls[i].hasCleanup {
				g.p("\t\t%s()\n", cleanupNames[i])
			}
		}
		g.p("\t}")
	}
	if returnsErr {
		g.p(", nil")
	}
	g.p("\n}\n")
	return nil
}
// qualifiedID returns the expression the generated file uses to refer to
// the symbol sym declared in the package with import path path. As a side
// effect, qualifyImport records an import for that package. When
// qualifyImport yields no name, as it does for the current package, sym
// is returned unqualified.
func (g *gen) qualifiedID(path, sym string) string {
	if pkgName := g.qualifyImport(path); pkgName != "" {
		return pkgName + "." + sym
	}
	return sym
}
// qualifyImport returns the identifier the generated file uses for the
// package with import path path, recording an import for it on first use.
// It returns "" for the current package.
//
// Recorded imports are keyed by path with any "vendor/" prefix stripped;
// frame emits those keys as the import paths. A new import name starts
// from the package's declared name and is disambiguated against the file
// scope. The name "err" is always avoided.
func (g *gen) qualifyImport(path string) string {
	if path == g.currPackage {
		return ""
	}
	// TODO(light): This is depending on details of the current loader.
	const vendorPart = "vendor/"
	key := path
	if idx := strings.LastIndex(path, vendorPart); idx == 0 || (idx > 0 && path[idx-1] == '/') {
		key = path[idx+len(vendorPart):]
	}
	if existing := g.imports[key]; len(existing) > 0 {
		return existing
	}
	// TODO(light): Use parts of import path to disambiguate.
	taken := func(candidate string) bool {
		// Don't let an import take the "err" name. That's annoying.
		return candidate == "err" || g.nameInFileScope(candidate)
	}
	chosen := disambiguate(g.prog.Package(path).Pkg.Name(), taken)
	g.imports[key] = chosen
	return chosen
}
// nameInFileScope reports whether name is already taken at the top level
// of the generated file. A name is taken if it is one of the import names
// recorded so far, or if a lookup from the current package's scope (which
// continues into enclosing scopes) finds an object for it.
func (g *gen) nameInFileScope(name string) bool {
	for _, imported := range g.imports {
		if name == imported {
			return true
		}
	}
	scope := g.prog.Package(g.currPackage).Pkg.Scope()
	_, found := scope.LookupParent(name, 0)
	return found != nil
}
// qualifyPkg is a types.Qualifier. It returns the name the generated file
// uses for pkg and records an import for pkg if needed.
func (g *gen) qualifyPkg(pkg *types.Package) string {
	importPath := pkg.Path()
	return g.qualifyImport(importPath)
}
// p appends text, formatted as by fmt.Sprintf, to the generated body.
func (g *gen) p(format string, args ...interface{}) {
	g.buf.WriteString(fmt.Sprintf(format, args...))
}
// zeroValue returns the shortest expression that evaluates to the zero
// value for the given type. Type names in the expression are qualified
// with qf.
//
// It panics if t's underlying type is not one it knows how to write a
// zero value for (for example, an untyped nil).
func zeroValue(t types.Type, qf types.Qualifier) string {
	switch u := t.Underlying().(type) {
	case *types.Array, *types.Struct:
		// An empty composite literal of the (possibly named) type.
		return types.TypeString(t, qf) + "{}"
	case *types.Basic:
		info := u.Info()
		switch {
		case info&types.IsBoolean != 0:
			return "false"
		case info&(types.IsInteger|types.IsFloat|types.IsComplex) != 0:
			// The untyped constant 0 is assignable to any numeric type,
			// named or not.
			return "0"
		case info&types.IsString != 0:
			return `""`
		case u.Kind() == types.UnsafePointer:
			// unsafe.Pointer has no BasicInfo flags, but its zero value is nil.
			return "nil"
		default:
			panic("unreachable: no zero value for " + t.String())
		}
	case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
		return "nil"
	default:
		panic("unreachable: no zero value for " + t.String())
	}
}
// typeVariableName invents a variable name derived from the type name
// or returns the empty string if one could not be found.
//
// One level of pointer indirection is looked through, so T and *T both
// yield the unexported form of T's name. Unnamed types yield "".
func typeVariableName(t types.Type) string {
	base := t
	if ptr, isPtr := t.(*types.Pointer); isPtr {
		base = ptr.Elem()
	}
	named, isNamed := base.(*types.Named)
	if !isNamed {
		return ""
	}
	// TODO(light): Include package name when appropriate.
	return unexport(named.Obj().Name())
}
// unexport converts a name that is potentially exported to an unexported
// name by lower-casing its leading run of upper-case letters:
//
//	foo        -> foo
//	Foo        -> foo
//	URL        -> url
//	HTTPServer -> httpServer
//
// If that run is longer than one letter and is followed by a lower-case
// letter, the run's last letter starts the next word and keeps its case.
func unexport(name string) string {
	first, size := utf8.DecodeRuneInString(name)
	if !unicode.IsUpper(first) {
		return name
	}
	// end is the byte offset just past the prefix to lower-case. The first
	// rune always belongs to that prefix.
	end := size
	for end < len(name) {
		r, n := utf8.DecodeRuneInString(name[end:])
		if !unicode.IsUpper(r) {
			break
		}
		if next, _ := utf8.DecodeRuneInString(name[end+n:]); unicode.IsLower(next) {
			break
		}
		end += n
	}
	return strings.ToLower(name[:end]) + name[end:]
}
// disambiguate picks a unique name, preferring name if it is already unique.
//
// A candidate is unusable if collides reports true for it or if it is a
// Go keyword. For example, a type named Type would otherwise become a
// local variable named "type", which does not compile. Alternatives are
// formed by appending an increasing number starting at 2. When name
// already ends in a digit, a "_" is inserted first (so "v1" becomes
// "v1_2"), keeping the suffix visually separate from the original name.
func disambiguate(name string, collides func(string) bool) string {
	if !token.Lookup(name).IsKeyword() && !collides(name) {
		return name
	}
	buf := []byte(name)
	if len(buf) > 0 && buf[len(buf)-1] >= '0' && buf[len(buf)-1] <= '9' {
		buf = append(buf, '_')
	}
	base := len(buf)
	// A keyword followed by digits is never a keyword, so only collides
	// needs checking from here on.
	for n := 2; ; n++ {
		buf = strconv.AppendInt(buf[:base], int64(n), 10)
		sbuf := string(buf)
		if !collides(sbuf) {
			return sbuf
		}
	}
}
// errorType is the predeclared error type; inject uses it to recognize an
// injector's error result.
var errorType = types.Universe.Lookup("error").Type()

// cleanupType is the signature func(); inject uses it to recognize an
// injector's cleanup-function result.
var cleanupType = types.NewSignature(nil, nil, nil, false)