wire: report an error if a func with wire.Build in it is an invalid injector (google/go-cloud#487)

This commit is contained in:
Robert van Gent
2018-09-27 15:30:13 -07:00
committed by Ross Light
parent 3bc7933406
commit ec7cb36215
6 changed files with 124 additions and 34 deletions

View File

@@ -193,7 +193,11 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) {
if !ok {
continue
}
buildCall := isInjector(&pkgInfo.Info, fn)
buildCall, err := findInjectorBuild(&pkgInfo.Info, fn)
if err != nil {
ec.add(notePosition(prog.Fset.Position(fn.Pos()), fmt.Errorf("inject %s: %v", fn.Name.Name, err)))
continue
}
if buildCall == nil {
continue
}
@@ -770,53 +774,59 @@ func processInterfaceValue(fset *token.FileSet, info *types.Info, call *ast.Call
}, nil
}
// isInjector checks whether a given function declaration is an
// injector template, returning the wire.Build call. It returns nil if
// the function is not an injector template.
func isInjector(info *types.Info, fn *ast.FuncDecl) *ast.CallExpr {
// findInjectorBuild returns the wire.Build call if fn is an injector template.
// It returns nil if the function is not an injector template.
func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error) {
if fn.Body == nil {
return nil
return nil, nil
}
var only *ast.ExprStmt
numStatements := 0
invalid := false
var wireBuildCall *ast.CallExpr
for _, stmt := range fn.Body.List {
switch stmt := stmt.(type) {
case *ast.ExprStmt:
if only != nil {
return nil
numStatements++
if numStatements > 1 {
invalid = true
}
only = stmt
call, ok := stmt.X.(*ast.CallExpr)
if !ok {
continue
}
if qualifiedIdentObject(info, call.Fun) == types.Universe.Lookup("panic") {
if len(call.Args) != 1 {
continue
}
call, ok = call.Args[0].(*ast.CallExpr)
if !ok {
continue
}
}
buildObj := qualifiedIdentObject(info, call.Fun)
if buildObj == nil || buildObj.Pkg() == nil || !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" {
continue
}
wireBuildCall = call
case *ast.EmptyStmt:
// Do nothing.
case *ast.ReturnStmt:
// Allow the function to end in a return.
if only == nil {
return nil
if numStatements == 0 {
return nil, nil
}
default:
return nil
invalid = true
}
}
if only == nil {
return nil
if wireBuildCall == nil {
return nil, nil
}
call, ok := only.X.(*ast.CallExpr)
if !ok {
return nil
if invalid {
return nil, errors.New("a call to wire.Build indicates that this function is an injector, but injectors must consist of only the wire.Build call and an optional return")
}
if qualifiedIdentObject(info, call.Fun) == types.Universe.Lookup("panic") {
if len(call.Args) != 1 {
return nil
}
call, ok = call.Args[0].(*ast.CallExpr)
if !ok {
return nil
}
}
buildObj := qualifiedIdentObject(info, call.Fun)
if buildObj == nil || buildObj.Pkg() == nil || !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" {
return nil
}
return call
return wireBuildCall, nil
}
func isWireImport(path string) bool {

View File

@@ -0,0 +1,36 @@
// Copyright 2018 The Go Cloud Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
)
func main() {
foo := injectFoo()
bar := injectBar()
fmt.Println(foo)
fmt.Println(bar)
}
type Foo int
type Bar int
func provideFoo() Foo {
return Foo(42)
}
func provideBar() Bar {
return Bar(99)
}

View File

@@ -0,0 +1,33 @@
// Copyright 2018 The Go Cloud Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//+build wireinject
package main
import (
"github.com/google/go-cloud/wire"
)
func injectFoo() Foo {
// This non-call statement makes this an invalid injector.
_ = 42
panic(wire.Build(provideFoo))
}
func injectBar() Bar {
// Two call statements are also invalid.
panic(wire.Build(provideBar))
panic(wire.Build(provideBar))
}

View File

@@ -0,0 +1 @@
example.com/foo

View File

@@ -0,0 +1,4 @@
a call to wire.Build indicates that this function is an injector, but injectors
must consist of only the wire.Build call and an optional return
a call to wire.Build indicates that this function is an injector, but injectors
must consist of only the wire.Build call and an optional return

View File

@@ -75,7 +75,11 @@ func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*as
if !ok {
continue
}
buildCall := isInjector(&pkgInfo.Info, fn)
buildCall, err := findInjectorBuild(&pkgInfo.Info, fn)
if err != nil {
ec.add(err)
continue
}
if buildCall == nil {
continue
}
@@ -113,7 +117,9 @@ func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) {
for _, decl := range f.Decls {
switch decl := decl.(type) {
case *ast.FuncDecl:
if isInjector(info, decl) != nil {
// OK to ignore error, as any error cases should already have
// been filtered out.
if buildCall, _ := findInjectorBuild(info, decl); buildCall != nil {
continue
}
case *ast.GenDecl: