goose: allow non-injector code to live along with injectors
Previously, goose would ignore declarations in the //+build gooseinject files that were not injectors. This meant that if you wanted to write application-specific providers, you would need to place them in a separate file, away from the goose injectors. This means that a typical application would have three handwritten files: one for the abstract business logic, one for the platform-specific providers, one for the platform-specific injector declarations. This change allows the two platform-specific files to be merged into one: the //+build gooseinject file. goose will now copy these declarations out to goose_gen.go. This requires a bit of hackery, since the generated file may have different identifiers for the imported packages, so goose will do some light AST rewriting to address these cases. (Historical note: this was the first change made externally, so also in here are the copyright headers and other housekeeping changes.) Reviewed-by: Tuo Shan <shantuo@google.com> Reviewed-by: kokoro <noreply+kokoro@google.com>
This commit is contained in:
493
internal/goose/copyast.go
Normal file
493
internal/goose/copyast.go
Normal file
@@ -0,0 +1,493 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package goose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
)
|
||||
|
||||
// copyAST performs a deep copy of an AST. *ast.Ident identity will be
|
||||
// preserved.
|
||||
//
|
||||
// This allows using astutil.Apply to rewrite an AST without modifying
|
||||
// the original AST.
|
||||
func copyAST(original ast.Node) ast.Node {
|
||||
// This function is necessarily long. No utility function exists to do this
|
||||
// clone, as most any attempt would need to have customization options, which
|
||||
// would need to be as expressive as Apply. A possibility to shorten the code
|
||||
// here would be to use reflection, but that trades clarity for shorter code.
|
||||
|
||||
m := make(map[ast.Node]ast.Node)
|
||||
astutil.Apply(original, nil, func(c *astutil.Cursor) bool {
|
||||
switch node := c.Node().(type) {
|
||||
case nil:
|
||||
// No-op.
|
||||
case *ast.ArrayType:
|
||||
m[node] = &ast.ArrayType{
|
||||
Lbrack: node.Lbrack,
|
||||
Len: exprFromMap(m, node.Len),
|
||||
Elt: exprFromMap(m, node.Elt),
|
||||
}
|
||||
case *ast.AssignStmt:
|
||||
m[node] = &ast.AssignStmt{
|
||||
Lhs: copyExprList(m, node.Lhs),
|
||||
TokPos: node.TokPos,
|
||||
Tok: node.Tok,
|
||||
Rhs: copyExprList(m, node.Rhs),
|
||||
}
|
||||
case *ast.BadDecl:
|
||||
m[node] = &ast.BadDecl{
|
||||
From: node.From,
|
||||
To: node.To,
|
||||
}
|
||||
case *ast.BadExpr:
|
||||
m[node] = &ast.BadExpr{
|
||||
From: node.From,
|
||||
To: node.To,
|
||||
}
|
||||
case *ast.BadStmt:
|
||||
m[node] = &ast.BadStmt{
|
||||
From: node.From,
|
||||
To: node.To,
|
||||
}
|
||||
case *ast.BasicLit:
|
||||
m[node] = &ast.BasicLit{
|
||||
ValuePos: node.ValuePos,
|
||||
Kind: node.Kind,
|
||||
Value: node.Value,
|
||||
}
|
||||
case *ast.BinaryExpr:
|
||||
m[node] = &ast.BinaryExpr{
|
||||
X: exprFromMap(m, node.X),
|
||||
OpPos: node.OpPos,
|
||||
Op: node.Op,
|
||||
Y: exprFromMap(m, node.Y),
|
||||
}
|
||||
case *ast.BlockStmt:
|
||||
m[node] = &ast.BlockStmt{
|
||||
Lbrace: node.Lbrace,
|
||||
List: copyStmtList(m, node.List),
|
||||
Rbrace: node.Rbrace,
|
||||
}
|
||||
case *ast.BranchStmt:
|
||||
m[node] = &ast.BranchStmt{
|
||||
TokPos: node.TokPos,
|
||||
Tok: node.Tok,
|
||||
Label: identFromMap(m, node.Label),
|
||||
}
|
||||
case *ast.CallExpr:
|
||||
m[node] = &ast.CallExpr{
|
||||
Fun: exprFromMap(m, node.Fun),
|
||||
Lparen: node.Lparen,
|
||||
Args: copyExprList(m, node.Args),
|
||||
Ellipsis: node.Ellipsis,
|
||||
Rparen: node.Rparen,
|
||||
}
|
||||
case *ast.CaseClause:
|
||||
m[node] = &ast.CaseClause{
|
||||
Case: node.Case,
|
||||
List: copyExprList(m, node.List),
|
||||
Colon: node.Colon,
|
||||
Body: copyStmtList(m, node.Body),
|
||||
}
|
||||
case *ast.ChanType:
|
||||
m[node] = &ast.ChanType{
|
||||
Begin: node.Begin,
|
||||
Arrow: node.Arrow,
|
||||
Dir: node.Dir,
|
||||
Value: exprFromMap(m, node.Value),
|
||||
}
|
||||
case *ast.CommClause:
|
||||
m[node] = &ast.CommClause{
|
||||
Case: node.Case,
|
||||
Comm: stmtFromMap(m, node.Comm),
|
||||
Colon: node.Colon,
|
||||
Body: copyStmtList(m, node.Body),
|
||||
}
|
||||
case *ast.Comment:
|
||||
m[node] = &ast.Comment{
|
||||
Slash: node.Slash,
|
||||
Text: node.Text,
|
||||
}
|
||||
case *ast.CommentGroup:
|
||||
cg := new(ast.CommentGroup)
|
||||
if node.List != nil {
|
||||
cg.List = make([]*ast.Comment, len(node.List))
|
||||
for i := range node.List {
|
||||
cg.List[i] = m[node.List[i]].(*ast.Comment)
|
||||
}
|
||||
}
|
||||
m[node] = cg
|
||||
case *ast.CompositeLit:
|
||||
m[node] = &ast.CompositeLit{
|
||||
Type: exprFromMap(m, node.Type),
|
||||
Lbrace: node.Lbrace,
|
||||
Elts: copyExprList(m, node.Elts),
|
||||
Rbrace: node.Rbrace,
|
||||
}
|
||||
case *ast.DeclStmt:
|
||||
m[node] = &ast.DeclStmt{
|
||||
Decl: m[node.Decl].(ast.Decl),
|
||||
}
|
||||
case *ast.DeferStmt:
|
||||
m[node] = &ast.DeferStmt{
|
||||
Defer: node.Defer,
|
||||
Call: callExprFromMap(m, node.Call),
|
||||
}
|
||||
case *ast.Ellipsis:
|
||||
m[node] = &ast.Ellipsis{
|
||||
Ellipsis: node.Ellipsis,
|
||||
Elt: exprFromMap(m, node.Elt),
|
||||
}
|
||||
case *ast.EmptyStmt:
|
||||
m[node] = &ast.EmptyStmt{
|
||||
Semicolon: node.Semicolon,
|
||||
Implicit: node.Implicit,
|
||||
}
|
||||
case *ast.ExprStmt:
|
||||
m[node] = &ast.ExprStmt{
|
||||
X: exprFromMap(m, node.X),
|
||||
}
|
||||
case *ast.Field:
|
||||
m[node] = &ast.Field{
|
||||
Doc: commentGroupFromMap(m, node.Doc),
|
||||
Names: copyIdentList(m, node.Names),
|
||||
Type: exprFromMap(m, node.Type),
|
||||
Tag: basicLitFromMap(m, node.Tag),
|
||||
Comment: commentGroupFromMap(m, node.Comment),
|
||||
}
|
||||
case *ast.FieldList:
|
||||
fl := &ast.FieldList{
|
||||
Opening: node.Opening,
|
||||
Closing: node.Closing,
|
||||
}
|
||||
if node.List != nil {
|
||||
fl.List = make([]*ast.Field, len(node.List))
|
||||
for i := range node.List {
|
||||
fl.List[i] = m[node.List[i]].(*ast.Field)
|
||||
}
|
||||
}
|
||||
m[node] = fl
|
||||
case *ast.ForStmt:
|
||||
m[node] = &ast.ForStmt{
|
||||
For: node.For,
|
||||
Init: stmtFromMap(m, node.Init),
|
||||
Cond: exprFromMap(m, node.Cond),
|
||||
Post: stmtFromMap(m, node.Post),
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
}
|
||||
case *ast.FuncDecl:
|
||||
m[node] = &ast.FuncDecl{
|
||||
Doc: commentGroupFromMap(m, node.Doc),
|
||||
Recv: fieldListFromMap(m, node.Recv),
|
||||
Name: identFromMap(m, node.Name),
|
||||
Type: funcTypeFromMap(m, node.Type),
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
}
|
||||
case *ast.FuncLit:
|
||||
m[node] = &ast.FuncLit{
|
||||
Type: funcTypeFromMap(m, node.Type),
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
}
|
||||
case *ast.FuncType:
|
||||
m[node] = &ast.FuncType{
|
||||
Func: node.Func,
|
||||
Params: fieldListFromMap(m, node.Params),
|
||||
Results: fieldListFromMap(m, node.Results),
|
||||
}
|
||||
case *ast.GenDecl:
|
||||
decl := &ast.GenDecl{
|
||||
Doc: commentGroupFromMap(m, node.Doc),
|
||||
TokPos: node.TokPos,
|
||||
Tok: node.Tok,
|
||||
Lparen: node.Lparen,
|
||||
Rparen: node.Rparen,
|
||||
}
|
||||
if node.Specs != nil {
|
||||
decl.Specs = make([]ast.Spec, len(node.Specs))
|
||||
for i := range node.Specs {
|
||||
decl.Specs[i] = m[node.Specs[i]].(ast.Spec)
|
||||
}
|
||||
}
|
||||
m[node] = decl
|
||||
case *ast.GoStmt:
|
||||
m[node] = &ast.GoStmt{
|
||||
Go: node.Go,
|
||||
Call: callExprFromMap(m, node.Call),
|
||||
}
|
||||
case *ast.Ident:
|
||||
// Keep identifiers the same identity so they can be conveniently
|
||||
// used with the original *types.Info.
|
||||
m[node] = node
|
||||
case *ast.IfStmt:
|
||||
m[node] = &ast.IfStmt{
|
||||
If: node.If,
|
||||
Init: stmtFromMap(m, node.Init),
|
||||
Cond: exprFromMap(m, node.Cond),
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
Else: stmtFromMap(m, node.Else),
|
||||
}
|
||||
case *ast.ImportSpec:
|
||||
m[node] = &ast.ImportSpec{
|
||||
Doc: commentGroupFromMap(m, node.Doc),
|
||||
Name: identFromMap(m, node.Name),
|
||||
Path: basicLitFromMap(m, node.Path),
|
||||
Comment: commentGroupFromMap(m, node.Comment),
|
||||
EndPos: node.EndPos,
|
||||
}
|
||||
case *ast.IncDecStmt:
|
||||
m[node] = &ast.IncDecStmt{
|
||||
X: exprFromMap(m, node.X),
|
||||
TokPos: node.TokPos,
|
||||
Tok: node.Tok,
|
||||
}
|
||||
case *ast.IndexExpr:
|
||||
m[node] = &ast.IndexExpr{
|
||||
X: exprFromMap(m, node.X),
|
||||
Lbrack: node.Lbrack,
|
||||
Index: exprFromMap(m, node.Index),
|
||||
Rbrack: node.Rbrack,
|
||||
}
|
||||
case *ast.InterfaceType:
|
||||
m[node] = &ast.InterfaceType{
|
||||
Interface: node.Interface,
|
||||
Methods: fieldListFromMap(m, node.Methods),
|
||||
Incomplete: node.Incomplete,
|
||||
}
|
||||
case *ast.KeyValueExpr:
|
||||
m[node] = &ast.KeyValueExpr{
|
||||
Key: exprFromMap(m, node.Key),
|
||||
Colon: node.Colon,
|
||||
Value: exprFromMap(m, node.Value),
|
||||
}
|
||||
case *ast.LabeledStmt:
|
||||
m[node] = &ast.LabeledStmt{
|
||||
Label: identFromMap(m, node.Label),
|
||||
Colon: node.Colon,
|
||||
Stmt: stmtFromMap(m, node.Stmt),
|
||||
}
|
||||
case *ast.MapType:
|
||||
m[node] = &ast.MapType{
|
||||
Map: node.Map,
|
||||
Key: exprFromMap(m, node.Key),
|
||||
Value: exprFromMap(m, node.Value),
|
||||
}
|
||||
case *ast.ParenExpr:
|
||||
m[node] = &ast.ParenExpr{
|
||||
Lparen: node.Lparen,
|
||||
X: exprFromMap(m, node.X),
|
||||
Rparen: node.Rparen,
|
||||
}
|
||||
case *ast.RangeStmt:
|
||||
m[node] = &ast.RangeStmt{
|
||||
For: node.For,
|
||||
Key: exprFromMap(m, node.Key),
|
||||
Value: exprFromMap(m, node.Value),
|
||||
TokPos: node.TokPos,
|
||||
Tok: node.Tok,
|
||||
X: exprFromMap(m, node.X),
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
}
|
||||
case *ast.ReturnStmt:
|
||||
m[node] = &ast.ReturnStmt{
|
||||
Return: node.Return,
|
||||
Results: copyExprList(m, node.Results),
|
||||
}
|
||||
case *ast.SelectStmt:
|
||||
m[node] = &ast.SelectStmt{
|
||||
Select: node.Select,
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
}
|
||||
case *ast.SelectorExpr:
|
||||
m[node] = &ast.SelectorExpr{
|
||||
X: exprFromMap(m, node.X),
|
||||
Sel: identFromMap(m, node.Sel),
|
||||
}
|
||||
case *ast.SendStmt:
|
||||
m[node] = &ast.SendStmt{
|
||||
Chan: exprFromMap(m, node.Chan),
|
||||
Arrow: node.Arrow,
|
||||
Value: exprFromMap(m, node.Value),
|
||||
}
|
||||
case *ast.SliceExpr:
|
||||
m[node] = &ast.SliceExpr{
|
||||
X: exprFromMap(m, node.X),
|
||||
Lbrack: node.Lbrack,
|
||||
Low: exprFromMap(m, node.Low),
|
||||
High: exprFromMap(m, node.High),
|
||||
Max: exprFromMap(m, node.Max),
|
||||
Slice3: node.Slice3,
|
||||
Rbrack: node.Rbrack,
|
||||
}
|
||||
case *ast.StarExpr:
|
||||
m[node] = &ast.StarExpr{
|
||||
Star: node.Star,
|
||||
X: exprFromMap(m, node.X),
|
||||
}
|
||||
case *ast.StructType:
|
||||
m[node] = &ast.StructType{
|
||||
Struct: node.Struct,
|
||||
Fields: fieldListFromMap(m, node.Fields),
|
||||
Incomplete: node.Incomplete,
|
||||
}
|
||||
case *ast.SwitchStmt:
|
||||
m[node] = &ast.SwitchStmt{
|
||||
Switch: node.Switch,
|
||||
Init: stmtFromMap(m, node.Init),
|
||||
Tag: exprFromMap(m, node.Tag),
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
}
|
||||
case *ast.TypeAssertExpr:
|
||||
m[node] = &ast.TypeAssertExpr{
|
||||
X: exprFromMap(m, node.X),
|
||||
Lparen: node.Lparen,
|
||||
Type: exprFromMap(m, node.Type),
|
||||
Rparen: node.Rparen,
|
||||
}
|
||||
case *ast.TypeSpec:
|
||||
m[node] = &ast.TypeSpec{
|
||||
Doc: commentGroupFromMap(m, node.Doc),
|
||||
Name: identFromMap(m, node.Name),
|
||||
Assign: node.Assign,
|
||||
Type: exprFromMap(m, node.Type),
|
||||
Comment: commentGroupFromMap(m, node.Comment),
|
||||
}
|
||||
case *ast.TypeSwitchStmt:
|
||||
m[node] = &ast.TypeSwitchStmt{
|
||||
Switch: node.Switch,
|
||||
Init: stmtFromMap(m, node.Init),
|
||||
Assign: stmtFromMap(m, node.Assign),
|
||||
Body: blockStmtFromMap(m, node.Body),
|
||||
}
|
||||
case *ast.UnaryExpr:
|
||||
m[node] = &ast.UnaryExpr{
|
||||
OpPos: node.OpPos,
|
||||
Op: node.Op,
|
||||
X: exprFromMap(m, node.X),
|
||||
}
|
||||
case *ast.ValueSpec:
|
||||
m[node] = &ast.ValueSpec{
|
||||
Doc: commentGroupFromMap(m, node.Doc),
|
||||
Names: copyIdentList(m, node.Names),
|
||||
Type: exprFromMap(m, node.Type),
|
||||
Values: copyExprList(m, node.Values),
|
||||
Comment: commentGroupFromMap(m, node.Comment),
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unhandled AST node: %T", node))
|
||||
}
|
||||
return true
|
||||
})
|
||||
return m[original]
|
||||
}
|
||||
|
||||
func commentGroupFromMap(m map[ast.Node]ast.Node, key *ast.CommentGroup) *ast.CommentGroup {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(*ast.CommentGroup)
|
||||
}
|
||||
|
||||
func exprFromMap(m map[ast.Node]ast.Node, key ast.Expr) ast.Expr {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(ast.Expr)
|
||||
}
|
||||
|
||||
func stmtFromMap(m map[ast.Node]ast.Node, key ast.Stmt) ast.Stmt {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(ast.Stmt)
|
||||
}
|
||||
|
||||
func identFromMap(m map[ast.Node]ast.Node, key *ast.Ident) *ast.Ident {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(*ast.Ident)
|
||||
}
|
||||
|
||||
func blockStmtFromMap(m map[ast.Node]ast.Node, key *ast.BlockStmt) *ast.BlockStmt {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(*ast.BlockStmt)
|
||||
}
|
||||
|
||||
func fieldListFromMap(m map[ast.Node]ast.Node, key *ast.FieldList) *ast.FieldList {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(*ast.FieldList)
|
||||
}
|
||||
|
||||
func callExprFromMap(m map[ast.Node]ast.Node, key *ast.CallExpr) *ast.CallExpr {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(*ast.CallExpr)
|
||||
}
|
||||
|
||||
func basicLitFromMap(m map[ast.Node]ast.Node, key *ast.BasicLit) *ast.BasicLit {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(*ast.BasicLit)
|
||||
}
|
||||
|
||||
func funcTypeFromMap(m map[ast.Node]ast.Node, key *ast.FuncType) *ast.FuncType {
|
||||
if key == nil {
|
||||
return nil
|
||||
}
|
||||
return m[key].(*ast.FuncType)
|
||||
}
|
||||
|
||||
func copyExprList(m map[ast.Node]ast.Node, exprs []ast.Expr) []ast.Expr {
|
||||
if exprs == nil {
|
||||
return nil
|
||||
}
|
||||
newExprs := make([]ast.Expr, len(exprs))
|
||||
for i := range exprs {
|
||||
newExprs[i] = m[exprs[i]].(ast.Expr)
|
||||
}
|
||||
return newExprs
|
||||
}
|
||||
|
||||
func copyStmtList(m map[ast.Node]ast.Node, stmts []ast.Stmt) []ast.Stmt {
|
||||
if stmts == nil {
|
||||
return nil
|
||||
}
|
||||
newStmts := make([]ast.Stmt, len(stmts))
|
||||
for i := range stmts {
|
||||
newStmts[i] = m[stmts[i]].(ast.Stmt)
|
||||
}
|
||||
return newStmts
|
||||
}
|
||||
|
||||
func copyIdentList(m map[ast.Node]ast.Node, idents []*ast.Ident) []*ast.Ident {
|
||||
if idents == nil {
|
||||
return nil
|
||||
}
|
||||
newIdents := make([]*ast.Ident, len(idents))
|
||||
for i := range idents {
|
||||
newIdents[i] = m[idents[i]].(*ast.Ident)
|
||||
}
|
||||
return newIdents
|
||||
}
|
||||
Reference in New Issue
Block a user