goose: allow non-injector code to live along with injectors
Previously, goose would ignore declarations in the //+build gooseinject files that were not injectors. This meant that if you wanted to write application-specific providers, you would need to place them in a separate file, away from the goose injectors. This means that a typical application would have three handwritten files: one for the abstract business logic, one for the platform-specific providers, one for the platform-specific injector declarations. This change allows the two platform-specific files to be merged into one: the //+build gooseinject file. goose will now copy these declarations out to goose_gen.go. This requires a bit of hackery, since the generated file may have different identifiers for the imported packages, so goose will do some light AST rewriting to address these cases. (Historical note: this was the first change made externally, so also in here are the copyright headers and other housekeeping changes.) Reviewed-by: Tuo Shan <shantuo@google.com> Reviewed-by: kokoro <noreply+kokoro@google.com>
This commit is contained in:
@@ -1,3 +1,17 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package goose provides compile-time dependency injection logic as a
|
||||
// Go library.
|
||||
package goose
|
||||
@@ -8,14 +22,17 @@ import (
|
||||
"go/ast"
|
||||
"go/build"
|
||||
"go/format"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/go/loader"
|
||||
)
|
||||
|
||||
@@ -50,7 +67,25 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
}
|
||||
pkgInfo := prog.InitialPackages()[0]
|
||||
g := newGen(prog, pkgInfo.Pkg.Path())
|
||||
oc := newObjectCache(prog)
|
||||
injectorFiles, err := generateInjectors(g, pkgInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info)
|
||||
goSrc := g.frame()
|
||||
fmtSrc, err := format.Source(goSrc)
|
||||
if err != nil {
|
||||
// This is likely a bug from a poorly generated source file.
|
||||
// Return an error and the unformatted source.
|
||||
return goSrc, err
|
||||
}
|
||||
return fmtSrc, nil
|
||||
}
|
||||
|
||||
// generateInjectors generates the injectors for a given package.
|
||||
func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*ast.File, _ error) {
|
||||
oc := newObjectCache(g.prog)
|
||||
injectorFiles = make([]*ast.File, 0, len(pkgInfo.Files))
|
||||
for _, f := range pkgInfo.Files {
|
||||
for _, decl := range f.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
@@ -61,24 +96,54 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
if useCall == nil {
|
||||
continue
|
||||
}
|
||||
if len(injectorFiles) == 0 || injectorFiles[len(injectorFiles)-1] != f {
|
||||
// This is the first injector generated for this file.
|
||||
// Write a file header.
|
||||
name := filepath.Base(g.prog.Fset.File(f.Pos()).Name())
|
||||
g.p("// Injectors from %s:\n\n", name)
|
||||
injectorFiles = append(injectorFiles, f)
|
||||
}
|
||||
set, err := oc.processNewSet(pkgInfo, useCall)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
|
||||
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
|
||||
}
|
||||
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||
if err := g.inject(prog.Fset, fn.Name.Name, sig, set); err != nil {
|
||||
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
|
||||
if err := g.inject(g.prog.Fset, fn.Name.Name, sig, set); err != nil {
|
||||
return nil, fmt.Errorf("%v: %v", g.prog.Fset.Position(fn.Pos()), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
goSrc := g.frame()
|
||||
fmtSrc, err := format.Source(goSrc)
|
||||
if err != nil {
|
||||
// This is likely a bug from a poorly generated source file.
|
||||
// Return an error and the unformatted source.
|
||||
return goSrc, err
|
||||
return injectorFiles, nil
|
||||
}
|
||||
|
||||
// copyNonInjectorDecls copies any non-injector declarations from the
|
||||
// given files into the generated output.
|
||||
func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) {
|
||||
for _, f := range files {
|
||||
name := filepath.Base(g.prog.Fset.File(f.Pos()).Name())
|
||||
first := true
|
||||
for _, decl := range f.Decls {
|
||||
switch decl := decl.(type) {
|
||||
case *ast.FuncDecl:
|
||||
if isInjector(info, decl) != nil {
|
||||
continue
|
||||
}
|
||||
case *ast.GenDecl:
|
||||
if decl.Tok == token.IMPORT {
|
||||
continue
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if first {
|
||||
g.p("// %s:\n\n", name)
|
||||
first = false
|
||||
}
|
||||
// TODO(light): Add line number at top of each declaration.
|
||||
g.writeAST(g.prog.Fset, info, decl)
|
||||
g.p("\n\n")
|
||||
}
|
||||
}
|
||||
return fmtSrc, nil
|
||||
}
|
||||
|
||||
// gen is the generator state.
|
||||
@@ -334,10 +399,127 @@ func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set
|
||||
if returnsErr {
|
||||
g.p(", nil")
|
||||
}
|
||||
g.p("\n}\n")
|
||||
g.p("\n}\n\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeAST prints an AST node into the generated output, rewriting any
|
||||
// package references it encounters.
|
||||
func (g *gen) writeAST(fset *token.FileSet, info *types.Info, node ast.Node) {
|
||||
start, end := node.Pos(), node.End()
|
||||
node = copyAST(node)
|
||||
// First, rewrite all package names. This lets us know all the
|
||||
// potentially colliding identifiers.
|
||||
node = astutil.Apply(node, func(c *astutil.Cursor) bool {
|
||||
switch node := c.Node().(type) {
|
||||
case *ast.Ident:
|
||||
// This is an unqualified identifier (qualified identifiers are peeled off below).
|
||||
obj := info.ObjectOf(node)
|
||||
if obj == nil {
|
||||
return false
|
||||
}
|
||||
if pkg := obj.Pkg(); pkg != nil && obj.Parent() == pkg.Scope() && pkg.Path() != g.currPackage {
|
||||
// An identifier from either a dot import or read from a different package.
|
||||
newPkgID := g.qualifyImport(pkg.Path())
|
||||
c.Replace(&ast.SelectorExpr{
|
||||
X: ast.NewIdent(newPkgID),
|
||||
Sel: ast.NewIdent(node.Name),
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case *ast.SelectorExpr:
|
||||
pkgIdent, ok := node.X.(*ast.Ident)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
pkgName, ok := info.ObjectOf(pkgIdent).(*types.PkgName)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
// This is a qualified identifier. Rewrite and avoid visiting subexpressions.
|
||||
newPkgID := g.qualifyImport(pkgName.Imported().Path())
|
||||
c.Replace(&ast.SelectorExpr{
|
||||
X: ast.NewIdent(newPkgID),
|
||||
Sel: ast.NewIdent(node.Sel.Name),
|
||||
})
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}, nil)
|
||||
// Now that we have all the identifiers, rename any variables declared
|
||||
// in this scope to not collide.
|
||||
newNames := make(map[types.Object]string)
|
||||
inNewNames := func(n string) bool {
|
||||
for _, other := range newNames {
|
||||
if other == n {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
var scopeStack []*types.Scope
|
||||
pkgScope := g.prog.Package(g.currPackage).Pkg.Scope()
|
||||
node = astutil.Apply(node, func(c *astutil.Cursor) bool {
|
||||
if scope := info.Scopes[c.Node()]; scope != nil {
|
||||
scopeStack = append(scopeStack, scope)
|
||||
}
|
||||
id, ok := c.Node().(*ast.Ident)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
obj := info.ObjectOf(id)
|
||||
if obj == nil {
|
||||
// We rewrote this identifier earlier, so it does not need
|
||||
// further rewriting.
|
||||
return true
|
||||
}
|
||||
if n, ok := newNames[obj]; ok {
|
||||
// We picked a new name for this symbol. Rewrite it.
|
||||
c.Replace(ast.NewIdent(n))
|
||||
return false
|
||||
}
|
||||
if par := obj.Parent(); par == nil || par == pkgScope {
|
||||
// Don't rename methods, field names, or top-level identifiers.
|
||||
return true
|
||||
}
|
||||
|
||||
// Rename any symbols defined within writeAST's node that conflict
|
||||
// with any symbols in the generated file.
|
||||
objName := obj.Name()
|
||||
if pos := obj.Pos(); pos < start || end <= pos || !(g.nameInFileScope(objName) || inNewNames(objName)) {
|
||||
return true
|
||||
}
|
||||
newName := disambiguate(objName, func(n string) bool {
|
||||
if g.nameInFileScope(n) || inNewNames(n) {
|
||||
return true
|
||||
}
|
||||
if len(scopeStack) > 0 {
|
||||
// Avoid picking a name that conflicts with other names in the
|
||||
// current scope.
|
||||
_, obj := scopeStack[len(scopeStack)-1].LookupParent(n, 0)
|
||||
if obj != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
newNames[obj] = newName
|
||||
c.Replace(ast.NewIdent(newName))
|
||||
return false
|
||||
}, func(c *astutil.Cursor) bool {
|
||||
if info.Scopes[c.Node()] != nil {
|
||||
// Should be top of stack; pop it.
|
||||
scopeStack = scopeStack[:len(scopeStack)-1]
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err := printer.Fprint(&g.buf, fset, node); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gen) qualifiedID(path, sym string) string {
|
||||
name := g.qualifyImport(path)
|
||||
if name == "" {
|
||||
|
||||
Reference in New Issue
Block a user