wire: add check command (google/go-cloud#207)

In the internal package, this expands the wire.Load function to run
the same solver as wire.Generate would on any injector function. For
completeness, I also print the injector functions in the gowire show
command.

A subsequent PR will add this as a step to Go Cloud CI.

Updates google/go-cloud#30
This commit is contained in:
Ross Light
2018-07-19 16:04:26 -07:00
parent 2943de1153
commit 5f0dd9ee8f
3 changed files with 188 additions and 66 deletions

View File

@@ -42,14 +42,18 @@ func main() {
err = generate(".") err = generate(".")
case len(os.Args) == 2 && os.Args[1] == "show": case len(os.Args) == 2 && os.Args[1] == "show":
err = show(".") err = show(".")
case len(os.Args) == 2:
err = generate(os.Args[1])
case len(os.Args) > 2 && os.Args[1] == "show": case len(os.Args) > 2 && os.Args[1] == "show":
err = show(os.Args[2:]...) err = show(os.Args[2:]...)
case len(os.Args) == 2 && os.Args[1] == "check":
err = check(".")
case len(os.Args) > 2 && os.Args[1] == "check":
err = check(os.Args[2:]...)
case len(os.Args) == 2:
err = generate(os.Args[1])
case len(os.Args) == 3 && os.Args[1] == "gen": case len(os.Args) == 3 && os.Args[1] == "gen":
err = generate(os.Args[2]) err = generate(os.Args[2])
default: default:
fmt.Fprintln(os.Stderr, "gowire: usage: gowire [gen] [PKG] | gowire show [...]") fmt.Fprintln(os.Stderr, "gowire: usage: gowire [gen] [PKG] | gowire show [...] | gowire check [...]")
os.Exit(64) os.Exit(64)
} }
if err != nil { if err != nil {
@@ -91,6 +95,7 @@ func generate(pkg string) error {
// Given one or more packages, show will find all the provider sets // Given one or more packages, show will find all the provider sets
// declared as top-level variables and print what other provider sets it // declared as top-level variables and print what other provider sets it
// imports and what outputs it can produce, given possible inputs. // imports and what outputs it can produce, given possible inputs.
// It also lists any injector functions defined in the package.
func show(pkgs ...string) error { func show(pkgs ...string) error {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
@@ -144,7 +149,38 @@ func show(pkgs ...string) error {
} }
} }
} }
if len(info.Injectors) > 0 {
injectors := append([]*wire.Injector(nil), info.Injectors...)
sort.Slice(injectors, func(i, j int) bool {
if injectors[i].ImportPath == injectors[j].ImportPath {
return injectors[i].FuncName < injectors[j].FuncName
} }
return injectors[i].ImportPath < injectors[j].ImportPath
})
fmt.Printf("%sInjectors:%s\n", redBold, reset)
for _, in := range injectors {
fmt.Printf("\t%v\n", in)
}
}
}
if len(errs) > 0 {
logErrors(errs)
return errors.New("error loading packages")
}
return nil
}
// check runs the check subcommand.
//
// Given one or more packages, check will print any type-checking or
// Wire errors found with top-level variable provider sets or injector
// functions.
func check(pkgs ...string) error {
wd, err := os.Getwd()
if err != nil {
return err
}
_, errs := wire.Load(&build.Default, wd, pkgs)
if len(errs) > 0 { if len(errs) > 0 {
logErrors(errs) logErrors(errs)
return errors.New("error loading packages") return errors.New("error loading packages")

View File

@@ -147,33 +147,21 @@ type Value struct {
// the provider sets' transitive dependencies. It may return both errors // the provider sets' transitive dependencies. It may return both errors
// and Info. // and Info.
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) {
ec := new(errorCollector) prog, errs := load(bctx, wd, pkgs)
conf := &loader.Config{ if len(errs) > 0 {
Build: bctx, return nil, errs
Cwd: wd,
TypeChecker: types.Config{
Error: func(err error) {
ec.add(err)
},
},
TypeCheckFuncBodies: func(string) bool { return false },
}
for _, p := range pkgs {
conf.Import(p)
}
prog, err := conf.Load()
if len(ec.errors) > 0 {
return nil, ec.errors
}
if err != nil {
return nil, []error{err}
} }
info := &Info{ info := &Info{
Fset: prog.Fset, Fset: prog.Fset,
Sets: make(map[ProviderSetID]*ProviderSet), Sets: make(map[ProviderSetID]*ProviderSet),
} }
oc := newObjectCache(prog) oc := newObjectCache(prog)
ec := new(errorCollector)
for _, pkgInfo := range prog.InitialPackages() { for _, pkgInfo := range prog.InitialPackages() {
if isWireImport(pkgInfo.Pkg.Path()) {
// The marker function package confuses analysis.
continue
}
scope := pkgInfo.Pkg.Scope() scope := pkgInfo.Pkg.Scope()
for _, name := range scope.Names() { for _, name := range scope.Names() {
obj := scope.Lookup(name) obj := scope.Lookup(name)
@@ -191,16 +179,129 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) {
id := ProviderSetID{ImportPath: pset.PkgPath, VarName: name} id := ProviderSetID{ImportPath: pset.PkgPath, VarName: name}
info.Sets[id] = pset info.Sets[id] = pset
} }
for _, f := range pkgInfo.Files {
for _, decl := range f.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
buildCall := isInjector(&pkgInfo.Info, fn)
if buildCall == nil {
continue
}
set, errs := oc.processNewSet(pkgInfo, buildCall)
if len(errs) > 0 {
ec.add(notePositionAll(prog.Fset.Position(fn.Pos()), errs)...)
continue
}
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
ins, out, err := injectorFuncSignature(sig)
if err != nil {
if w, ok := err.(*wireErr); ok {
ec.add(notePosition(w.position, fmt.Errorf("inject %s: %v", fn.Name.Name, w.error)))
} else {
ec.add(notePosition(prog.Fset.Position(fn.Pos()), fmt.Errorf("inject %s: %v", fn.Name.Name, err)))
}
continue
}
_, errs = solve(prog.Fset, out.out, ins, set)
if len(errs) > 0 {
ec.add(mapErrors(errs, func(e error) error {
if w, ok := e.(*wireErr); ok {
return notePosition(w.position, fmt.Errorf("inject %s: %v", fn.Name.Name, w.error))
}
return notePosition(prog.Fset.Position(fn.Pos()), fmt.Errorf("inject %s: %v", fn.Name.Name, e))
})...)
continue
}
info.Injectors = append(info.Injectors, &Injector{
ImportPath: pkgInfo.Pkg.Path(),
FuncName: fn.Name.Name,
})
}
}
} }
return info, ec.errors return info, ec.errors
} }
// load typechecks the packages, including function body type checking
// for the packages directly named.
func load(bctx *build.Context, wd string, pkgs []string) (*loader.Program, []error) {
var foundPkgs []*build.Package
ec := new(errorCollector)
for _, name := range pkgs {
p, err := bctx.Import(name, wd, build.FindOnly)
if err != nil {
ec.add(err)
continue
}
foundPkgs = append(foundPkgs, p)
}
if len(ec.errors) > 0 {
return nil, ec.errors
}
conf := &loader.Config{
Build: bctx,
Cwd: wd,
TypeChecker: types.Config{
Error: func(err error) {
ec.add(err)
},
},
TypeCheckFuncBodies: func(path string) bool {
return importPathInPkgList(foundPkgs, path)
},
FindPackage: func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error) {
// Optimistically try to load in the package with normal build tags.
pkg, err := bctx.Import(importPath, fromDir, mode)
// If this is the generated package, then load it in with the
// wireinject build tag to pick up the injector template. Since
// the *build.Context is shared between calls to FindPackage, this
// uses a copy.
if pkg != nil && importPathInPkgList(foundPkgs, pkg.ImportPath) {
bctx2 := new(build.Context)
*bctx2 = *bctx
n := len(bctx2.BuildTags)
bctx2.BuildTags = append(bctx2.BuildTags[:n:n], "wireinject")
pkg, err = bctx2.Import(importPath, fromDir, mode)
}
return pkg, err
},
}
for _, name := range pkgs {
conf.Import(name)
}
prog, err := conf.Load()
if len(ec.errors) > 0 {
return nil, ec.errors
}
if err != nil {
return nil, []error{err}
}
return prog, nil
}
func importPathInPkgList(pkgs []*build.Package, path string) bool {
for _, p := range pkgs {
if path == p.ImportPath {
return true
}
}
return false
}
// Info holds the result of Load. // Info holds the result of Load.
type Info struct { type Info struct {
Fset *token.FileSet Fset *token.FileSet
// Sets contains all the provider sets in the initial packages. // Sets contains all the provider sets in the initial packages.
Sets map[ProviderSetID]*ProviderSet Sets map[ProviderSetID]*ProviderSet
// Injectors contains all the injector functions in the initial packages.
// The order is undefined.
Injectors []*Injector
} }
// A ProviderSetID identifies a named provider set. // A ProviderSetID identifies a named provider set.
@@ -214,6 +315,17 @@ func (id ProviderSetID) String() string {
return strconv.Quote(id.ImportPath) + "." + id.VarName return strconv.Quote(id.ImportPath) + "." + id.VarName
} }
// An Injector describes an injector function.
type Injector struct {
ImportPath string
FuncName string
}
// String returns the injector name as ""path/to/pkg".Foo".
func (in *Injector) String() string {
return strconv.Quote(in.ImportPath) + "." + in.FuncName
}
// objectCache is a lazily evaluated mapping of objects to Wire structures. // objectCache is a lazily evaluated mapping of objects to Wire structures.
type objectCache struct { type objectCache struct {
prog *loader.Program prog *loader.Program
@@ -462,6 +574,19 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro
return provider, nil return provider, nil
} }
func injectorFuncSignature(sig *types.Signature) ([]types.Type, outputSignature, error) {
out, err := funcOutput(sig)
if err != nil {
return nil, outputSignature{}, err
}
params := sig.Params()
given := make([]types.Type, params.Len())
for i := 0; i < params.Len(); i++ {
given[i] = params.At(i).Type()
}
return given, out, nil
}
type outputSignature struct { type outputSignature struct {
out types.Type out types.Type
cleanup bool cleanup bool
@@ -653,7 +778,7 @@ func isInjector(info *types.Info, fn *ast.FuncDecl) *ast.CallExpr {
} }
} }
buildObj := qualifiedIdentObject(info, call.Fun) buildObj := qualifiedIdentObject(info, call.Fun)
if !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" { if buildObj == nil || buildObj.Pkg() == nil || !isWireImport(buildObj.Pkg().Path()) || buildObj.Name() != "Build" {
return nil return nil
} }
return call return call

View File

@@ -39,48 +39,9 @@ import (
// Generate performs dependency injection for a single package, // Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code. // returning the gofmt'd Go source code.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) {
mainPkg, err := bctx.Import(pkg, wd, build.FindOnly) prog, errs := load(bctx, wd, []string{pkg})
if err != nil { if len(errs) > 0 {
return nil, []error{fmt.Errorf("load: %v", err)} return nil, errs
}
ec := new(errorCollector)
conf := &loader.Config{
Build: bctx,
Cwd: wd,
TypeChecker: types.Config{
Error: func(err error) {
ec.add(err)
},
},
TypeCheckFuncBodies: func(path string) bool {
return path == mainPkg.ImportPath
},
FindPackage: func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error) {
// Optimistically try to load in the package with normal build tags.
pkg, err := bctx.Import(importPath, fromDir, mode)
// If this is the generated package, then load it in with the
// wireinject build tag to pick up the injector template. Since
// the *build.Context is shared between calls to FindPackage, this
// uses a copy.
if pkg != nil && pkg.ImportPath == mainPkg.ImportPath {
bctx2 := new(build.Context)
*bctx2 = *bctx
n := len(bctx2.BuildTags)
bctx2.BuildTags = append(bctx2.BuildTags[:n:n], "wireinject")
pkg, err = bctx2.Import(importPath, fromDir, mode)
}
return pkg, err
},
}
conf.Import(pkg)
prog, err := conf.Load()
if len(ec.errors) > 0 {
return nil, ec.errors
}
if err != nil {
return nil, []error{err}
} }
if len(prog.InitialPackages()) != 1 { if len(prog.InitialPackages()) != 1 {
// This is more of a violated precondition than anything else. // This is more of a violated precondition than anything else.