Files
wire/internal/goose/goose.go
Ross Light cb1853b1af goose: use readable variable names
Names are inferred from types most of the time, but have a fallback for
a non-named type. Names are now also disambiguated from symbols in the
same scope.

Reviewed-by: Tuo Shan <shantuo@google.com>
Reviewed-by: Herbie Ong <herbie@google.com>
2018-11-12 14:09:55 -08:00

826 lines
23 KiB
Go

// Package goose provides compile-time dependency injection logic as a
// Go library.
package goose
import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"go/types"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/types/typeutil"
)
// Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
// TODO(light): allow errors
// TODO(light): stop errors from printing to stderr
conf := &loader.Config{
Build: new(build.Context),
ParserMode: parser.ParseComments,
Cwd: wd,
TypeCheckFuncBodies: func(string) bool { return false },
}
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
conf.Import(pkg)
prog, err := conf.Load()
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
if len(prog.InitialPackages()) != 1 {
// This is more of a violated precondition than anything else.
return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))
}
pkgInfo := prog.InitialPackages()[0]
g := newGen(prog, pkgInfo.Pkg.Path())
mc := newProviderSetCache(prog)
var directives []directive
for _, f := range pkgInfo.Files {
if !isInjectFile(f) {
continue
}
// TODO(light): use same directive extraction logic as provider set finding.
fileScope := pkgInfo.Scopes[f]
cmap := ast.NewCommentMap(prog.Fset, f, f.Comments)
for _, decl := range f.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
directives = directives[:0]
for _, c := range cmap[fn] {
directives = extractDirectives(directives, c)
}
sets := make([]providerSetRef, 0, len(directives))
for _, d := range directives {
if d.kind != "use" {
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
}
ref, err := parseProviderSetRef(d.line, fileScope, g.currPackage, d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
}
sets = append(sets, ref)
}
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
if err := g.inject(mc, fn.Name.Name, sig, sets); err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
}
}
}
goSrc := g.frame()
fmtSrc, err := format.Source(goSrc)
if err != nil {
// This is likely a bug from a poorly generated source file.
// Return an error and the unformatted source.
return goSrc, err
}
return fmtSrc, nil
}
// gen is the generator state.
type gen struct {
currPackage string
buf bytes.Buffer
imports map[string]string
prog *loader.Program // for determining package names
}
func newGen(prog *loader.Program, pkg string) *gen {
return &gen{
currPackage: pkg,
imports: make(map[string]string),
prog: prog,
}
}
// frame bakes the built up source body into an unformatted Go source file.
func (g *gen) frame() []byte {
if g.buf.Len() == 0 {
return nil
}
var buf bytes.Buffer
buf.WriteString("// Code generated by goose. DO NOT EDIT.\n\n//+build !gooseinject\n\npackage ")
buf.WriteString(g.prog.Package(g.currPackage).Pkg.Name())
buf.WriteString("\n\n")
if len(g.imports) > 0 {
buf.WriteString("import (\n")
imps := make([]string, 0, len(g.imports))
for path := range g.imports {
imps = append(imps, path)
}
sort.Strings(imps)
for _, path := range imps {
// TODO(light): Omit the local package identifier if it matches
// the package name.
fmt.Fprintf(&buf, "\t%s %q\n", g.imports[path], path)
}
buf.WriteString(")\n\n")
}
buf.Write(g.buf.Bytes())
return buf.Bytes()
}
// inject emits the code for an injector.
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []providerSetRef) error {
results := sig.Results()
returnsErr := false
switch results.Len() {
case 0:
return fmt.Errorf("inject %s: no return values", name)
case 1:
// nothing special
case 2:
if t := results.At(1).Type(); !types.Identical(t, errorType) {
return fmt.Errorf("inject %s: second return type is %s; must be error", name, types.TypeString(t, nil))
}
returnsErr = true
default:
return fmt.Errorf("inject %s: too many return values", name)
}
outType := results.At(0).Type()
params := sig.Params()
given := make([]types.Type, params.Len())
for i := 0; i < params.Len(); i++ {
given[i] = params.At(i).Type()
}
calls, err := solve(mc, outType, given, sets)
if err != nil {
return err
}
for i := range calls {
if calls[i].hasErr && !returnsErr {
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
}
}
// Prequalify all types. Since import disambiguation ignores local
// variables, it takes precedence.
paramTypes := make([]string, params.Len())
for i := 0; i < params.Len(); i++ {
paramTypes[i] = types.TypeString(params.At(i).Type(), g.qualifyPkg)
}
for _, c := range calls {
g.qualifyImport(c.importPath)
}
outTypeString := types.TypeString(outType, g.qualifyPkg)
zv := zeroValue(outType, g.qualifyPkg)
// Set up local variables
paramNames := make([]string, params.Len())
localNames := make([]string, len(calls))
errVar := disambiguate("err", g.nameInFileScope)
collides := func(v string) bool {
if v == errVar {
return true
}
for _, a := range paramNames {
if a == v {
return true
}
}
for _, l := range localNames {
if l == v {
return true
}
}
return g.nameInFileScope(v)
}
g.p("func %s(", name)
for i := 0; i < params.Len(); i++ {
if i > 0 {
g.p(", ")
}
pi := params.At(i)
a := pi.Name()
if a == "" || a == "_" {
a = typeVariableName(pi.Type())
if a == "" {
a = "arg"
}
}
paramNames[i] = disambiguate(a, collides)
g.p("%s %s", paramNames[i], paramTypes[i])
}
if returnsErr {
g.p(") (%s, error) {\n", outTypeString)
} else {
g.p(") %s {\n", outTypeString)
}
for i := range calls {
c := &calls[i]
lname := typeVariableName(c.out)
if lname == "" {
lname = "v"
}
lname = disambiguate(lname, collides)
localNames[i] = lname
g.p("\t%s", lname)
if c.hasErr {
g.p(", %s", errVar)
}
g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName))
for j, a := range c.args {
if j > 0 {
g.p(", ")
}
if a < params.Len() {
g.p("%s", paramNames[a])
} else {
g.p("%s", localNames[a-params.Len()])
}
}
g.p(")\n")
if c.hasErr {
g.p("\tif %s != nil {\n", errVar)
// TODO(light): give information about failing provider
g.p("\t\treturn %s, err\n", zv)
g.p("\t}\n")
}
}
if len(calls) == 0 {
for i := range given {
if types.Identical(outType, given[i]) {
g.p("\treturn %s", paramNames[i])
break
}
}
} else {
g.p("\treturn %s", localNames[len(calls)-1])
}
if returnsErr {
g.p(", nil")
}
g.p("\n}\n")
return nil
}
func (g *gen) qualifiedID(path, sym string) string {
name := g.qualifyImport(path)
if name == "" {
return sym
}
return name + "." + sym
}
func (g *gen) qualifyImport(path string) string {
if path == g.currPackage {
return ""
}
if name := g.imports[path]; name != "" {
return name
}
// TODO(light): use parts of import path to disambiguate.
name := disambiguate(g.prog.Package(path).Pkg.Name(), func(n string) bool {
// Don't let an import take the "err" name. That's annoying.
return n == "err" || g.nameInFileScope(n)
})
g.imports[path] = name
return name
}
func (g *gen) nameInFileScope(name string) bool {
for _, other := range g.imports {
if other == name {
return true
}
}
_, obj := g.prog.Package(g.currPackage).Pkg.Scope().LookupParent(name, 0)
return obj != nil
}
func (g *gen) qualifyPkg(pkg *types.Package) string {
return g.qualifyImport(pkg.Path())
}
func (g *gen) p(format string, args ...interface{}) {
fmt.Fprintf(&g.buf, format, args...)
}
// A providerSet describes a set of providers. The zero value is an empty
// providerSet.
type providerSet struct {
providers []*providerInfo
imports []providerSetImport
}
type providerSetImport struct {
providerSetRef
pos token.Pos
}
// findProviderSets processes a package and extracts the provider sets declared in it.
func findProviderSets(fset *token.FileSet, pkg *types.Package, typeInfo *types.Info, files []*ast.File) (map[string]*providerSet, error) {
sets := make(map[string]*providerSet)
var directives []directive
for _, f := range files {
fileScope := typeInfo.Scopes[f]
for _, c := range f.Comments {
directives = extractDirectives(directives[:0], c)
for _, d := range directives {
switch d.kind {
case "provide", "use":
// handled later
case "import":
if fileScope == nil {
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name())
}
i := strings.IndexByte(d.line, ' ')
// TODO(light): allow multiple imports in one line
if i == -1 {
return nil, fmt.Errorf("%s: invalid import: expected TARGET SETREF", fset.Position(d.pos))
}
name, spec := d.line[:i], d.line[i+1:]
ref, err := parseProviderSetRef(spec, fileScope, pkg.Path(), d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
}
if ref.importPath != pkg.Path() {
imported := false
for _, imp := range pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
}
}
if !imported {
return nil, fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath)
}
}
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.providerSetRef {
found = true
break
}
}
if !found {
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
}
}
default:
return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind)
}
}
}
cmap := ast.NewCommentMap(fset, f, f.Comments)
for _, decl := range f.Decls {
directives = directives[:0]
for _, cg := range cmap[decl] {
directives = extractDirectives(directives, cg)
}
fn, isFunction := decl.(*ast.FuncDecl)
var providerSetName string
for _, d := range directives {
if d.kind != "provide" {
continue
}
if providerSetName != "" {
return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name)
}
if !isFunction {
return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos))
}
providerSetName = fn.Name.Name
if d.line != "" {
// TODO(light): validate identifier
providerSetName = d.line
}
}
if providerSetName == "" {
continue
}
fpos := fn.Pos()
sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
r := sig.Results()
var hasErr bool
switch r.Len() {
case 1:
hasErr = false
case 2:
if t := r.At(1).Type(); !types.Identical(t, errorType) {
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name)
}
hasErr = true
default:
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name)
}
out := r.At(0).Type()
p := sig.Params()
provider := &providerInfo{
importPath: pkg.Path(),
funcName: fn.Name.Name,
pos: fn.Pos(),
args: make([]types.Type, p.Len()),
out: out,
hasErr: hasErr,
}
for i := 0; i < p.Len(); i++ {
provider.args[i] = p.At(i).Type()
for j := 0; j < i; j++ {
if types.Identical(provider.args[i], provider.args[j]) {
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil))
}
}
}
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return nil, fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerSetName, types.TypeString(provider.out, nil), fset.Position(other.pos))
}
}
mod.providers = append(mod.providers, provider)
} else {
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider},
}
}
}
}
return sets, nil
}
// providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct {
sets map[string]map[string]*providerSet
fset *token.FileSet
prog *loader.Program
}
func newProviderSetCache(prog *loader.Program) *providerSetCache {
return &providerSetCache{
fset: prog.Fset,
prog: prog,
}
}
func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
if mods, cached := mc.sets[ref.importPath]; cached {
mod := mods[ref.name]
if mod == nil {
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
}
return mod, nil
}
if mc.sets == nil {
mc.sets = make(map[string]map[string]*providerSet)
}
pkg := mc.prog.Package(ref.importPath)
mods, err := findProviderSets(mc.fset, pkg.Pkg, &pkg.Info, pkg.Files)
if err != nil {
mc.sets[ref.importPath] = nil
return nil, err
}
mc.sets[ref.importPath] = mods
mod := mods[ref.name]
if mod == nil {
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
}
return mod, nil
}
// solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs.
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []providerSetRef) ([]call, error) {
for i, g := range given {
for _, h := range given[:i] {
if types.Identical(g, h) {
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
}
}
}
providers, err := buildProviderMap(mc, sets)
if err != nil {
return nil, err
}
// Start building the mapping of type to local variable of the given type.
// The first len(given) local variables are the given types.
index := new(typeutil.Map)
for i, g := range given {
if p := providers.At(g); p != nil {
pp := p.(*providerInfo)
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos))
}
index.Set(g, i)
}
// Topological sort of the directed graph defined by the providers
// using a depth-first search. The graph may contain cycles, which
// should trigger an error.
var calls []call
var visit func(trail []types.Type) error
visit = func(trail []types.Type) error {
typ := trail[len(trail)-1]
if index.At(typ) != nil {
return nil
}
for _, t := range trail[:len(trail)-1] {
if types.Identical(typ, t) {
// TODO(light): describe cycle
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
}
}
p, _ := providers.At(typ).(*providerInfo)
if p == nil {
if len(trail) == 1 {
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
}
// TODO(light): give name of provider
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil))
}
for _, a := range p.args {
// TODO(light): this will discard grown trail arrays.
if err := visit(append(trail, a)); err != nil {
return err
}
}
args := make([]int, len(p.args))
for i := range p.args {
args[i] = index.At(p.args[i]).(int)
}
index.Set(typ, len(given)+len(calls))
calls = append(calls, call{
importPath: p.importPath,
funcName: p.funcName,
args: args,
out: typ,
hasErr: p.hasErr,
})
return nil
}
if err := visit([]types.Type{out}); err != nil {
return nil, err
}
return calls, nil
}
func buildProviderMap(mc *providerSetCache, sets []providerSetRef) (*typeutil.Map, error) {
type nextEnt struct {
to providerSetRef
from providerSetRef
pos token.Pos
}
pm := new(typeutil.Map) // to *providerInfo
visited := make(map[providerSetRef]struct{})
var next []nextEnt
for _, ref := range sets {
next = append(next, nextEnt{to: ref})
}
for len(next) > 0 {
curr := next[0]
copy(next, next[1:])
next = next[:len(next)-1]
if _, skip := visited[curr.to]; skip {
continue
}
visited[curr.to] = struct{}{}
mod, err := mc.get(curr.to)
if err != nil {
if !curr.pos.IsValid() {
return nil, err
}
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
}
for _, p := range mod.providers {
if prev := pm.At(p.out); prev != nil {
pos := mc.fset.Position(p.pos)
typ := types.TypeString(p.out, nil)
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
if curr.from.importPath != "" {
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
}
pm.Set(p.out, p)
}
for _, imp := range mod.imports {
next = append(next, nextEnt{to: imp.providerSetRef, from: curr.to, pos: imp.pos})
}
}
return pm, nil
}
// A call represents a step of an injector function.
type call struct {
// importPath and funcName identify the provider function to call.
importPath string
funcName string
// args is a list of arguments to call the provider with. Each element is either:
// a) one of the givens (args[i] < len(given)) or
// b) the result of a previous provider call (args[i] >= len(given)).
args []int
// out is the type produced by this provider call.
out types.Type
// hasErr is true if the provider call returns an error.
hasErr bool
}
// providerInfo records the signature of a provider function.
type providerInfo struct {
importPath string
funcName string
pos token.Pos
args []types.Type
out types.Type
hasErr bool
}
// A providerSetRef is a parsed reference to a collection of providers.
type providerSetRef struct {
importPath string
name string
}
func parseProviderSetRef(ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
// TODO(light): verify that provider set name is an identifier before returning
i := strings.LastIndexByte(ref, '.')
if i == -1 {
return providerSetRef{importPath: pkg, name: ref}, nil
}
imp, name := ref[:i], ref[i+1:]
if strings.HasPrefix(imp, `"`) {
path, err := strconv.Unquote(imp)
if err != nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref)
}
return providerSetRef{importPath: path, name: name}, nil
}
_, obj := s.LookupParent(imp, pos)
if obj == nil {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, imp)
}
pn, ok := obj.(*types.PkgName)
if !ok {
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, imp)
}
return providerSetRef{importPath: pn.Imported().Path(), name: name}, nil
}
func (ref providerSetRef) String() string {
return strconv.Quote(ref.importPath) + "." + ref.name
}
type directive struct {
pos token.Pos
kind string
line string
}
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
const prefix = "goose:"
text := cg.Text()
for len(text) > 0 {
text = strings.TrimLeft(text, " \t\r\n")
if !strings.HasPrefix(text, prefix) {
break
}
line := text[len(prefix):]
if i := strings.IndexByte(line, '\n'); i != -1 {
line, text = line[:i], line[i+1:]
} else {
text = ""
}
if i := strings.IndexByte(line, ' '); i != -1 {
d = append(d, directive{
kind: line[:i],
line: strings.TrimSpace(line[i+1:]),
pos: cg.Pos(), // TODO(light): more precise position
})
} else {
d = append(d, directive{
kind: line,
pos: cg.Pos(), // TODO(light): more precise position
})
}
}
return d
}
// isInjectFile reports whether a given file is an injection template.
func isInjectFile(f *ast.File) bool {
// TODO(light): better determination
for _, cg := range f.Comments {
text := cg.Text()
if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") {
return true
}
}
return false
}
// zeroValue returns the shortest expression that evaluates to the zero
// value for the given type.
func zeroValue(t types.Type, qf types.Qualifier) string {
switch u := t.Underlying().(type) {
case *types.Array, *types.Struct:
return types.TypeString(t, qf) + "{}"
case *types.Basic:
info := u.Info()
switch {
case info&types.IsBoolean != 0:
return "false"
case info&(types.IsInteger|types.IsFloat|types.IsComplex) != 0:
return "0"
case info&types.IsString != 0:
return `""`
default:
panic("unreachable")
}
case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
return "nil"
default:
panic("unreachable")
}
}
// typeVariableName invents a variable name derived from the type name
// or returns the empty string if one could not be found.
func typeVariableName(t types.Type) string {
if p, ok := t.(*types.Pointer); ok {
t = p.Elem()
}
tn, ok := t.(*types.Named)
if !ok {
return ""
}
// TODO(light): include package name when appropriate
return unexport(tn.Obj().Name())
}
// unexport converts a name that is potentially exported to an unexported name.
func unexport(name string) string {
r, sz := utf8.DecodeRuneInString(name)
if !unicode.IsUpper(r) {
// foo -> foo
return name
}
r2, sz2 := utf8.DecodeRuneInString(name[sz:])
if !unicode.IsUpper(r2) {
// Foo -> foo
return string(unicode.ToLower(r)) + name[sz:]
}
// UPPERWord -> upperWord
sbuf := new(strings.Builder)
sbuf.WriteRune(unicode.ToLower(r))
i := sz
r, sz = r2, sz2
for unicode.IsUpper(r) && sz > 0 {
r2, sz2 := utf8.DecodeRuneInString(name[i+sz:])
if sz2 > 0 && unicode.IsLower(r2) {
break
}
i += sz
sbuf.WriteRune(unicode.ToLower(r))
r, sz = r2, sz2
}
sbuf.WriteString(name[i:])
return sbuf.String()
}
// disambiguate picks a unique name, preferring name if it is already unique.
func disambiguate(name string, collides func(string) bool) string {
if !collides(name) {
return name
}
buf := []byte(name)
if len(buf) > 0 && buf[len(buf)-1] >= '0' && buf[len(buf)-1] <= '9' {
buf = append(buf, '_')
}
base := len(buf)
for n := 2; ; n++ {
buf = strconv.AppendInt(buf[:base], int64(n), 10)
sbuf := string(buf)
if !collides(sbuf) {
return sbuf
}
}
}
var errorType = types.Universe.Lookup("error").Type()