wire: report an error if a func with wire.Build in it is an invalid injector (google/go-cloud#487)

This commit is contained in:
Robert van Gent
2018-09-27 15:30:13 -07:00
committed by Ross Light
parent 3bc7933406
commit ec7cb36215
6 changed files with 124 additions and 34 deletions

View File

@@ -193,7 +193,11 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) {
if !ok {
continue
}
buildCall := isInjector(&pkgInfo.Info, fn)
buildCall, err := findInjectorBuild(&pkgInfo.Info, fn)
if err != nil {
ec.add(notePosition(prog.Fset.Position(fn.Pos()), fmt.Errorf("inject %s: %v", fn.Name.Name, err)))
continue
}
if buildCall == nil {
continue
}
@@ -770,53 +774,59 @@ func processInterfaceValue(fset *token.FileSet, info *types.Info, call *ast.Call
}, nil
}
// isInjector checks whether a given function declaration is an
// injector template, returning the wire.Build call. It returns nil if
// the function is not an injector template.
func isInjector(info *types.Info, fn *ast.FuncDecl) *ast.CallExpr {
// findInjectorBuild returns the wire.Build call if fn is an injector template.
// It returns nil if the function is not an injector template.
func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error) {
if fn.Body == nil {
return nil
return nil, nil
}
var only *ast.ExprStmt
numStatements := 0
invalid := false
var wireBuildCall *ast.CallExpr
for _, stmt := range fn.Body.List {
switch stmt := stmt.(type) {
case *ast.ExprStmt:
if only != nil {
return nil
numStatements++
if numStatements > 1 {
invalid = true
}
only = stmt
call, ok := stmt.X.(*ast.CallExpr)
if !ok {
continue
}
if qualifiedIdentObject(info, call.Fun) == types.Universe.Lookup("panic") {
if len(call.Args) != 1 {
continue
}
call, ok = call.Args[0].(*ast.CallExpr)
if !ok {
continue
}
}
buildObj := qualifiedIdentObject(info, call.Fun)
if buildObj == nil || buildObj.Pkg() == nil || !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" {
continue
}
wireBuildCall = call
case *ast.EmptyStmt:
// Do nothing.
case *ast.ReturnStmt:
// Allow the function to end in a return.
if only == nil {
return nil
if numStatements == 0 {
return nil, nil
}
default:
return nil
invalid = true
}
}
if only == nil {
return nil
if wireBuildCall == nil {
return nil, nil
}
call, ok := only.X.(*ast.CallExpr)
if !ok {
return nil
if invalid {
return nil, errors.New("a call to wire.Build indicates that this function is an injector, but injectors must consist of only the wire.Build call and an optional return")
}
if qualifiedIdentObject(info, call.Fun) == types.Universe.Lookup("panic") {
if len(call.Args) != 1 {
return nil
}
call, ok = call.Args[0].(*ast.CallExpr)
if !ok {
return nil
}
}
buildObj := qualifiedIdentObject(info, call.Fun)
if buildObj == nil || buildObj.Pkg() == nil || !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" {
return nil
}
return call
return wireBuildCall, nil
}
func isWireImport(path string) bool {