diff --git a/internal/wire/parse.go b/internal/wire/parse.go index a1d939d..1b65312 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -193,7 +193,11 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { if !ok { continue } - buildCall := isInjector(&pkgInfo.Info, fn) + buildCall, err := findInjectorBuild(&pkgInfo.Info, fn) + if err != nil { + ec.add(notePosition(prog.Fset.Position(fn.Pos()), fmt.Errorf("inject %s: %v", fn.Name.Name, err))) + continue + } if buildCall == nil { continue } @@ -770,53 +774,59 @@ func processInterfaceValue(fset *token.FileSet, info *types.Info, call *ast.Call }, nil } -// isInjector checks whether a given function declaration is an -// injector template, returning the wire.Build call. It returns nil if -// the function is not an injector template. -func isInjector(info *types.Info, fn *ast.FuncDecl) *ast.CallExpr { +// findInjectorBuild returns the wire.Build call if fn is an injector template. +// It returns nil if the function is not an injector template. +func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error) { if fn.Body == nil { - return nil + return nil, nil } - var only *ast.ExprStmt + numStatements := 0 + invalid := false + var wireBuildCall *ast.CallExpr for _, stmt := range fn.Body.List { switch stmt := stmt.(type) { case *ast.ExprStmt: - if only != nil { - return nil + numStatements++ + if numStatements > 1 { + invalid = true } - only = stmt + call, ok := stmt.X.(*ast.CallExpr) + if !ok { + continue + } + if qualifiedIdentObject(info, call.Fun) == types.Universe.Lookup("panic") { + if len(call.Args) != 1 { + continue + } + call, ok = call.Args[0].(*ast.CallExpr) + if !ok { + continue + } + } + buildObj := qualifiedIdentObject(info, call.Fun) + if buildObj == nil || buildObj.Pkg() == nil || !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" { + continue + } + wireBuildCall = call case *ast.EmptyStmt: // Do nothing. case *ast.ReturnStmt: // Allow the function to end in a return. - if only == nil { - return nil + if numStatements == 0 { + return nil, nil } default: - return nil + invalid = true } + } - if only == nil { - return nil + if wireBuildCall == nil { + return nil, nil } - call, ok := only.X.(*ast.CallExpr) - if !ok { - return nil + if invalid { + return nil, errors.New("a call to wire.Build indicates that this function is an injector, but injectors must consist of only the wire.Build call and an optional return") } - if qualifiedIdentObject(info, call.Fun) == types.Universe.Lookup("panic") { - if len(call.Args) != 1 { - return nil - } - call, ok = call.Args[0].(*ast.CallExpr) - if !ok { - return nil - } - } - buildObj := qualifiedIdentObject(info, call.Fun) - if buildObj == nil || buildObj.Pkg() == nil || !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" { - return nil - } - return call + return wireBuildCall, nil } func isWireImport(path string) bool { diff --git a/internal/wire/testdata/InvalidInjector/foo/foo.go b/internal/wire/testdata/InvalidInjector/foo/foo.go new file mode 100644 index 0000000..922b980 --- /dev/null +++ b/internal/wire/testdata/InvalidInjector/foo/foo.go @@ -0,0 +1,36 @@ +// Copyright 2018 The Go Cloud Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" +) + +func main() { + foo := injectFoo() + bar := injectBar() + fmt.Println(foo) + fmt.Println(bar) +} + +type Foo int +type Bar int + +func provideFoo() Foo { + return Foo(42) +} +func provideBar() Bar { + return Bar(99) +} diff --git a/internal/wire/testdata/InvalidInjector/foo/wire.go b/internal/wire/testdata/InvalidInjector/foo/wire.go new file mode 100644 index 0000000..128e88c --- /dev/null +++ b/internal/wire/testdata/InvalidInjector/foo/wire.go @@ -0,0 +1,33 @@ +// Copyright 2018 The Go Cloud Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build wireinject + +package main + +import ( + "github.com/google/go-cloud/wire" +) + +func injectFoo() Foo { + // This non-call statement makes this an invalid injector. + _ = 42 + panic(wire.Build(provideFoo)) +} + +func injectBar() Bar { + // Two call statements are also invalid. + panic(wire.Build(provideBar)) + panic(wire.Build(provideBar)) +} diff --git a/internal/wire/testdata/InvalidInjector/pkg b/internal/wire/testdata/InvalidInjector/pkg new file mode 100644 index 0000000..f7a5c8c --- /dev/null +++ b/internal/wire/testdata/InvalidInjector/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/InvalidInjector/want/wire_errs.txt b/internal/wire/testdata/InvalidInjector/want/wire_errs.txt new file mode 100644 index 0000000..fbd6607 --- /dev/null +++ b/internal/wire/testdata/InvalidInjector/want/wire_errs.txt @@ -0,0 +1,4 @@ +a call to wire.Build indicates that this function is an injector, but injectors +must consist of only the wire.Build call and an optional return +a call to wire.Build indicates that this function is an injector, but injectors +must consist of only the wire.Build call and an optional return diff --git a/internal/wire/wire.go b/internal/wire/wire.go index f064aac..ef748cb 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -75,7 +75,11 @@ func generateInjectors(g *gen, pkgInfo *loader.PackageInfo) (injectorFiles []*as if !ok { continue } - buildCall := isInjector(&pkgInfo.Info, fn) + buildCall, err := findInjectorBuild(&pkgInfo.Info, fn) + if err != nil { + ec.add(err) + continue + } if buildCall == nil { continue } @@ -113,7 +117,9 @@ func copyNonInjectorDecls(g *gen, files []*ast.File, info *types.Info) { for _, decl := range f.Decls { switch decl := decl.(type) { case *ast.FuncDecl: - if isInjector(info, decl) != nil { + // OK to ignore error, as any error cases should already have + // been filtered out. + if buildCall, _ := findInjectorBuild(info, decl); buildCall != nil { continue } case *ast.GenDecl: