goose: allow non-injector code to live along with injectors

Previously, goose would ignore declarations in the //+build gooseinject
files that were not injectors. This meant that if you wanted to write
application-specific providers, you would need to place them in a
separate file, away from the goose injectors. This means that a typical
application would have three handwritten files: one for the abstract
business logic, one for the platform-specific providers, one for the
platform-specific injector declarations.

This change allows the two platform-specific files to be merged into
one: the //+build gooseinject file. goose will now copy these
declarations out to goose_gen.go. This requires a bit of hackery, since
the generated file may have different identifiers for the imported
packages, so goose will do some light AST rewriting to address these
cases.

(Historical note: this was the first change made externally, so also in
here are the copyright headers and other housekeeping changes.)

Reviewed-by: Tuo Shan <shantuo@google.com>
Reviewed-by: kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ross Light
2018-05-01 14:46:39 -04:00
parent f8e446fa17
commit 235a7d8f80
57 changed files with 1501 additions and 47 deletions

View File

@@ -1,3 +1,17 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package goose provides compile-time dependency injection logic as a
// Go library.
package goose
@@ -8,14 +22,17 @@ import (
"go/ast"
"go/build"
"go/format"
"go/printer"
"go/token"
"go/types"
"path/filepath"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader"
)
@@ -50,7 +67,25 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
}
pkgInfo := prog.InitialPackages()[0]
g := newGen(prog, pkgInfo.Pkg.Path())
oc := newObjectCache(prog)
injectorFiles, err := generateInjectors(g, pkgInfo)
if err != nil {
return nil, err
}
copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info)
goSrc := g.frame()
fmtSrc, err := format.Source(goSrc)
if err != nil {
// This is likely a bug from a poorly generated source file.
// Return an error and the unformatted source.
return goSrc, err
}
return fmtSrc, nil
}
// generateInjectors generates the injectors for a given package.
func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*ast.File, _ error) {
oc := newObjectCache(g.prog)
injectorFiles = make([]*ast.File, 0, len(pkgInfo.Files))
for _, f := range pkgInfo.Files {
for _, decl := range f.Decls {
fn, ok := decl.(*ast.FuncDecl)
@@ -61,24 +96,54 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
if useCall == nil {
continue
}
if len(injectorFiles) == 0 || injectorFiles[len(injectorFiles)-1] != f {
// This is the first injector generated for this file.
// Write a file header.
name := filepath.Base(g.prog.Fset.File(f.Pos()).Name())
g.p("// Injectors from %s:\n\n", name)
injectorFiles = append(injectorFiles, f)
}
set, err := oc.processNewSet(pkgInfo, useCall)
if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
}
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
if err := g.inject(prog.Fset, fn.Name.Name, sig, set); err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
if err := g.inject(g.prog.Fset, fn.Name.Name, sig, set); err != nil {
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
}
}
}
goSrc := g.frame()
fmtSrc, err := format.Source(goSrc)
if err != nil {
// This is likely a bug from a poorly generated source file.
// Return an error and the unformatted source.
return goSrc, err
return injectorFiles, nil
}
// copyNonInjectorDecls copies any non-injector declarations from the
// given files into the generated output.
func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) {
for _, f := range files {
name := filepath.Base(g.prog.Fset.File(f.Pos()).Name())
first := true
for _, decl := range f.Decls {
switch decl := decl.(type) {
case *ast.FuncDecl:
if isInjector(info, decl) != nil {
continue
}
case *ast.GenDecl:
if decl.Tok == token.IMPORT {
continue
}
default:
continue
}
if first {
g.p("// %s:\n\n", name)
first = false
}
// TODO(light): Add line number at top of each declaration.
g.writeAST(g.prog.Fset, info, decl)
g.p("\n\n")
}
}
return fmtSrc, nil
}
// gen is the generator state.
@@ -334,10 +399,127 @@ func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set
if returnsErr {
g.p(", nil")
}
g.p("\n}\n")
g.p("\n}\n\n")
return nil
}
// writeAST prints an AST node into the generated output, rewriting any
// package references it encounters.
func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
start, end := node.Pos(), node.End()
node = copyAST(node)
// First, rewrite all package names. This lets us know all the
// potentially colliding identifiers.
node = astutil.Apply(node, func(c *astutil.Cursor) bool {
switch node := c.Node().(type) {
case *ast.Ident:
// This is an unqualified identifier (qualified identifiers are peeled off below).
obj := info.ObjectOf(node)
if obj == nil {
return false
}
if pkg := obj.Pkg(); pkg != nil && obj.Parent() == pkg.Scope() && pkg.Path() != g.currPackage {
// An identifier from either a dot import or read from a different package.
newPkgID := g.qualifyImport(pkg.Path())
c.Replace(&ast.SelectorExpr{
X: ast.NewIdent(newPkgID),
Sel: ast.NewIdent(node.Name),
})
return false
}
return true
case *ast.SelectorExpr:
pkgIdent, ok := node.X.(*ast.Ident)
if !ok {
return true
}
pkgName, ok := info.ObjectOf(pkgIdent).(*types.PkgName)
if !ok {
return true
}
// This is a qualified identifier. Rewrite and avoid visiting subexpressions.
newPkgID := g.qualifyImport(pkgName.Imported().Path())
c.Replace(&ast.SelectorExpr{
X: ast.NewIdent(newPkgID),
Sel: ast.NewIdent(node.Sel.Name),
})
return false
default:
return true
}
}, nil)
// Now that we have all the identifiers, rename any variables declared
// in this scope to not collide.
newNames := make(map[types.Object]string)
inNewNames := func(n string) bool {
for _, other := range newNames {
if other == n {
return true
}
}
return false
}
var scopeStack []*types.Scope
pkgScope := g.prog.Package(g.currPackage).Pkg.Scope()
node = astutil.Apply(node, func(c *astutil.Cursor) bool {
if scope := info.Scopes[c.Node()]; scope != nil {
scopeStack = append(scopeStack, scope)
}
id, ok := c.Node().(*ast.Ident)
if !ok {
return true
}
obj := info.ObjectOf(id)
if obj == nil {
// We rewrote this identifier earlier, so it does not need
// further rewriting.
return true
}
if n, ok := newNames[obj]; ok {
// We picked a new name for this symbol. Rewrite it.
c.Replace(ast.NewIdent(n))
return false
}
if par := obj.Parent(); par == nil || par == pkgScope {
// Don't rename methods, field names, or top-level identifiers.
return true
}
// Rename any symbols defined within writeAST's node that conflict
// with any symbols in the generated file.
objName := obj.Name()
if pos := obj.Pos(); pos < start || end <= pos || !(g.nameInFileScope(objName) || inNewNames(objName)) {
return true
}
newName := disambiguate(objName, func(n string) bool {
if g.nameInFileScope(n) || inNewNames(n) {
return true
}
if len(scopeStack) > 0 {
// Avoid picking a name that conflicts with other names in the
// current scope.
_, obj := scopeStack[len(scopeStack)-1].LookupParent(n, 0)
if obj != nil {
return true
}
}
return false
})
newNames[obj] = newName
c.Replace(ast.NewIdent(newName))
return false
}, func(c *astutil.Cursor) bool {
if info.Scopes[c.Node()] != nil {
// Should be top of stack; pop it.
scopeStack = scopeStack[:len(scopeStack)-1]
}
return true
})
if err := printer.Fprint(&g.buf, fset, node); err != nil {
panic(err)
}
}
func (g *gen) qualifiedID(path, sym string) string {
name := g.qualifyImport(path)
if name == "" {