From 26c8618466275de08d81be686f7b18729bf45cd4 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Mon, 26 Mar 2018 07:39:00 -0700 Subject: [PATCH] goose: dependency injection proof of concept See documentation and demo for usage and known limitations. Reviewed-by: Herbie Ong --- README.md | 203 ++++++++++ internal/goose/goose.go | 715 +++++++++++++++++++++++++++++++++++ internal/goose/goose_test.go | 406 ++++++++++++++++++++ main.go | 52 +++ 4 files changed, 1376 insertions(+) create mode 100644 README.md create mode 100644 internal/goose/goose.go create mode 100644 internal/goose/goose_test.go create mode 100644 main.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..8400a00 --- /dev/null +++ b/README.md @@ -0,0 +1,203 @@ +# goose: Compile-Time Dependency Injection for Go + +goose is a compile-time [dependency injection][] framework for Go, inspired by +[Dagger][]. It works by using Go code to specify dependencies, then +generating code to create those structures, mimicking the code that a user +might have hand-written. + +[dependency injection]: https://en.wikipedia.org/wiki/Dependency_injection +[Dagger]: https://google.github.io/dagger/ + +## Usage Guide + +### Defining Providers + +The primary mechanism in goose is the **provider**: a function that can +produce a value, annotated with the special `goose:provide` directive. These +functions are ordinary Go code and live in packages. + +```go +package module + +type Foo int + +// goose:provide + +// ProvideFoo returns a Foo. +func ProvideFoo() Foo { + return 42 +} +``` + +Providers are always part of a **module**: if there is no module name specified +on the `//goose:provide` line, then `Module` is used. + +Providers can specify dependencies with parameters: + +```go +package module + +// goose:provide SuperModule + +type Bar int + +// ProvideBar returns a Bar: a negative Foo. +func ProvideBar(foo Foo) Bar { + return Bar(-foo) +} +``` + +Providers can also return errors: + +```go +package module + +import ( + "context" + "errors" +) + +type Baz int + +// goose:provide SuperModule + +// ProvideBaz returns a value if Bar is not zero. +func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) { + if bar == 0 { + return 0, errors.New("cannot provide baz when bar is zero") + } + return Baz(bar), nil +} +``` + +Modules can import other modules. To import `Module` in `SuperModule`: + +```go +// goose:import SuperModule Module +``` + +### Injectors + +An application can use these providers by declaring an **injector**: a +generated function that calls providers in dependency order. + +An injector is declared by writing a function declaration without a body in a +file guarded by a `gooseinject` build tag. Let's say that the above providers +were defined in a package called `example.com/module`. The following would +declare an injector to obtain a `Baz`: + +```go +//+build gooseinject + +package main + +import ( + "context" + + "example.com/module" +) + +// goose:use module.SuperModule + +func initializeApp(ctx context.Context) (module.Baz, error) +``` + +Like providers, injectors can be parameterized on inputs (which then get sent to +providers) and can return errors. The `goose:use` directive specifies the +modules to use in the injection. Both `goose:use` and `goose:import` use the +same syntax for referencing modules: an optional import qualifier (either a +package name or a quoted import path) with a dot, followed by the module name. +For example: `SamePackageModule`, `foo.Bar`, or `"example.com/foo".Bar`. + +You can generate the injector using goose: + +``` +goose +``` + +Or you can add the line `//go:generate goose` to another file in your package to +use [`go generate`]: + +``` +go generate +``` + +(Adding the line to the injection declaration file will be silently ignored by +`go generate`.) + +goose will produce an implementation of the injector that looks something like +this: + +```go +// Code generated by goose. DO NOT EDIT. + +//+build !gooseinject + +package main + +import ( + "example.com/module" +) + +func initializeApp(ctx context.Context) (module.Baz, error) { + foo := module.ProvideFoo() + bar := module.ProvideBar(foo) + baz, err := module.ProvideBaz(ctx, bar) + if err != nil { + return 0, err + } + return baz, nil +} +``` + +As you can see, the output is very close to what a developer would write +themselves. Further, there is no dependency on goose at runtime: all of the +written code is just normal Go code, and can be used without goose. + +[`go generate`]: https://blog.golang.org/generate + +## Best Practices + +goose is still not mature yet, but guidance that applies to Dagger generally +applies to goose as well. In particular, when thinking about how to group a +package of providers, follow the same [guidance](https://google.github.io/dagger/testing.html#organize-modules-for-testability) as Dagger: + +> Some [...] bindings will have reasonable alternatives, especially for +> testing, and others will not. For example, there are likely to be +> alternative bindings for a type like `AuthManager`: one for testing, others +> for different authentication/authorization protocols. +> +> But on the other hand, if the `AuthManager` interface has a method that +> returns the currently logged-in user, you might want to [export a provider of +> `User` that simply calls `CurrentUser()`] on the `AuthManager`. That +> published binding is unlikely to ever need an alternative. +> +> Once you’ve classified your bindings into [...] bindings with reasonable +> alternatives [and] bindings without reasonable alternatives, consider +> arranging them into packages like this: +> +> - One [package] for each [...] binding with a reasonable alternative. (If +> you are also writing the alternatives, each one gets its own [package].) That +> [package] contains exactly one provider. +> - All [...] bindings with no reasonable alternatives go into [packages] +> organized along functional lines. +> - The [packages] should each include the no-reasonable-alternative [packages] that +> require the [...] bindings each provides. + +One goose-specific practice though: create one-off types where in Java you +would use a binding annotation. + +## Future Work + +- The names of imports and provider results in the generated code are not + actually as nice as shown above. I'd like to make them nicer in the + common cases while ensuring uniqueness. +- I'd like to support optional and multiple bindings. +- At the moment, the entire transitive closure of all dependencies are read + for providers. It might be better to have provider imports be opt-in, but + that seems like too many levels of magic. +- Currently, all dependency satisfaction is done using identity. I'd like to + use a limited form of assignability for interface types, but I'm unsure + how well this implicit satisfaction will work in practice. +- Errors emitted by goose are not very good, but it has all the information + it needs to emit better ones. diff --git a/internal/goose/goose.go b/internal/goose/goose.go new file mode 100644 index 0000000..d50a457 --- /dev/null +++ b/internal/goose/goose.go @@ -0,0 +1,715 @@ +// Package goose provides compile-time dependency injection logic as a +// Go library. +package goose + +import ( + "bytes" + "fmt" + "go/ast" + "go/build" + "go/format" + "go/parser" + "go/token" + "go/types" + "sort" + "strconv" + "strings" + + "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/types/typeutil" +) + +// Generate performs dependency injection for a single package, +// returning the gofmt'd Go source code. +func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { + // TODO(light): allow errors + // TODO(light): stop errors from printing to stderr + conf := &loader.Config{ + Build: new(build.Context), + ParserMode: parser.ParseComments, + Cwd: wd, + } + *conf.Build = *bctx + n := len(conf.Build.BuildTags) + conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject") + conf.Import(pkg) + prog, err := conf.Load() + if err != nil { + return nil, fmt.Errorf("load: %v", err) + } + if len(prog.InitialPackages()) != 1 { + // This is more of a violated precondition than anything else. + return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages())) + } + pkgInfo := prog.InitialPackages()[0] + g := newGen(pkgInfo.Pkg.Path()) + mc := newModuleCache(prog) + var directives []directive + for _, f := range pkgInfo.Files { + if !isInjectFile(f) { + continue + } + fileScope := pkgInfo.Scopes[f] + cmap := ast.NewCommentMap(prog.Fset, f, f.Comments) + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + directives = directives[:0] + for _, c := range cmap[fn] { + directives = extractDirectives(directives, c) + } + modules := make([]moduleRef, 0, len(directives)) + for _, d := range directives { + if d.kind != "use" { + return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind) + } + ref, err := parseModuleRef(d.line, fileScope, g.currPackage, d.pos) + if err != nil { + return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) + } + modules = append(modules, ref) + } + sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) + if err := g.inject(mc, fn.Name.Name, sig, modules); err != nil { + return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err) + } + } + } + goSrc := g.frame(pkgInfo.Pkg.Name()) + fmtSrc, err := format.Source(goSrc) + if err != nil { + // This is likely a bug from a poorly generated source file. + // Return an error and the unformatted source. + return goSrc, err + } + return fmtSrc, nil +} + +// gen is the generator state. +type gen struct { + currPackage string + buf bytes.Buffer + imports map[string]string + n int +} + +func newGen(pkg string) *gen { + return &gen{ + currPackage: pkg, + imports: make(map[string]string), + } +} + +// frame bakes the built up source body into an unformatted Go source file. +func (g *gen) frame(pkgName string) []byte { + if g.buf.Len() == 0 { + return nil + } + var buf bytes.Buffer + buf.WriteString("// Code generated by goose. DO NOT EDIT.\n\n//+build !gooseinject\n\npackage ") + buf.WriteString(pkgName) + buf.WriteString("\n\n") + if len(g.imports) > 0 { + buf.WriteString("import (\n") + imps := make([]string, 0, len(g.imports)) + for path := range g.imports { + imps = append(imps, path) + } + sort.Strings(imps) + for _, path := range imps { + fmt.Fprintf(&buf, "\t%s %q\n", g.imports[path], path) + } + buf.WriteString(")\n\n") + } + buf.Write(g.buf.Bytes()) + return buf.Bytes() +} + +// inject emits the code for an injector. +func (g *gen) inject(mc *moduleCache, name string, sig *types.Signature, modules []moduleRef) error { + results := sig.Results() + returnsErr := false + switch results.Len() { + case 0: + return fmt.Errorf("inject %s: no return values", name) + case 1: + // nothing special + case 2: + if t := results.At(1).Type(); !types.Identical(t, errorType) { + return fmt.Errorf("inject %s: second return type is %s; must be error", name, types.TypeString(t, nil)) + } + returnsErr = true + default: + return fmt.Errorf("inject %s: too many return values", name) + } + outType := results.At(0).Type() + params := sig.Params() + given := make([]types.Type, params.Len()) + for i := 0; i < params.Len(); i++ { + given[i] = params.At(i).Type() + } + calls, err := solve(mc, outType, given, modules) + if err != nil { + return err + } + for i := range calls { + if calls[i].hasErr && !returnsErr { + return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil)) + } + } + g.p("func %s(", name) + for i := 0; i < params.Len(); i++ { + if i > 0 { + g.p(", ") + } + pi := params.At(i) + g.p("%s %s", pi.Name(), types.TypeString(pi.Type(), g.qualifyPkg)) + } + if returnsErr { + g.p(") (%s, error) {\n", types.TypeString(outType, g.qualifyPkg)) + } else { + g.p(") %s {\n", types.TypeString(outType, g.qualifyPkg)) + } + zv := zeroValue(outType, g.qualifyPkg) + for i := range calls { + c := &calls[i] + g.p("\tv%d", i) + if c.hasErr { + g.p(", err") + } + g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName)) + for j, a := range c.args { + if j > 0 { + g.p(", ") + } + if a < params.Len() { + g.p("%s", params.At(a).Name()) + } else { + g.p("v%d", a-params.Len()) + } + } + g.p(")\n") + if c.hasErr { + g.p("\tif err != nil {\n") + // TODO(light): give information about failing provider + g.p("\t\treturn %s, err\n", zv) + g.p("\t}\n") + } + } + if len(calls) == 0 { + for i := range given { + if types.Identical(outType, given[i]) { + g.p("\treturn %s", params.At(i).Name()) + break + } + } + } else { + g.p("\treturn v%d", len(calls)-1) + } + if returnsErr { + g.p(", nil") + } + g.p("\n}\n") + return nil +} + +func (g *gen) qualifiedID(path, sym string) string { + name := g.qualifyImport(path) + if name == "" { + return sym + } + return name + "." + sym +} + +func (g *gen) qualifyImport(path string) string { + if path == g.currPackage { + return "" + } + if name := g.imports[path]; name != "" { + return name + } + name := fmt.Sprintf("pkg%d", g.n) + g.n++ + g.imports[path] = name + return name +} + +func (g *gen) qualifyPkg(pkg *types.Package) string { + return g.qualifyImport(pkg.Path()) +} + +func (g *gen) p(format string, args ...interface{}) { + fmt.Fprintf(&g.buf, format, args...) +} + +// A module describes a set of providers. The zero value is an empty +// module. +type module struct { + providers []*providerInfo + imports []moduleImport +} + +type moduleImport struct { + moduleRef + pos token.Pos +} + +const implicitModuleName = "Module" + +// findModules processes a package and extracts the modules declared in it. +func findModules(fset *token.FileSet, pkg *types.Package, typeInfo *types.Info, files []*ast.File) (map[string]*module, error) { + modules := make(map[string]*module) + var directives []directive + for _, f := range files { + fileScope := typeInfo.Scopes[f] + for _, c := range f.Comments { + directives = extractDirectives(directives[:0], c) + for _, d := range directives { + switch d.kind { + case "provide", "use": + // handled later + case "import": + if fileScope == nil { + return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name()) + } + var name, spec string + if strings.HasPrefix(d.line, `"`) { + name, spec = implicitModuleName, d.line + } else if i := strings.IndexByte(d.line, ' '); i != -1 { + name, spec = d.line[:i], d.line[i+1:] + } else { + name, spec = implicitModuleName, d.line + } + ref, err := parseModuleRef(spec, fileScope, pkg.Path(), d.pos) + if err != nil { + return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err) + } + if ref.importPath != pkg.Path() { + imported := false + for _, imp := range pkg.Imports() { + if ref.importPath == imp.Path() { + imported = true + break + } + } + if !imported { + return nil, fmt.Errorf("%v: module %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath) + } + } + if mod := modules[name]; mod != nil { + found := false + for _, other := range mod.imports { + if ref == other.moduleRef { + found = true + break + } + } + if !found { + mod.imports = append(mod.imports, moduleImport{moduleRef: ref, pos: d.pos}) + } + } else { + modules[name] = &module{ + imports: []moduleImport{{moduleRef: ref, pos: d.pos}}, + } + } + default: + return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind) + } + } + } + cmap := ast.NewCommentMap(fset, f, f.Comments) + for _, decl := range f.Decls { + directives = directives[:0] + for _, cg := range cmap[decl] { + directives = extractDirectives(directives, cg) + } + fn, isFunction := decl.(*ast.FuncDecl) + var providerModule string + for _, d := range directives { + if d.kind != "provide" { + continue + } + if providerModule != "" { + return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name) + } + if !isFunction { + return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos)) + } + if d.line == "" { + providerModule = implicitModuleName + } else { + // TODO(light): validate identifier + providerModule = d.line + } + } + if providerModule == "" { + continue + } + fpos := fn.Pos() + sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature) + r := sig.Results() + var hasErr bool + switch r.Len() { + case 1: + hasErr = false + case 2: + if t := r.At(1).Type(); !types.Identical(t, errorType) { + return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name) + } + hasErr = true + default: + return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name) + } + out := r.At(0).Type() + p := sig.Params() + provider := &providerInfo{ + importPath: pkg.Path(), + funcName: fn.Name.Name, + pos: fn.Pos(), + args: make([]types.Type, p.Len()), + out: out, + hasErr: hasErr, + } + for i := 0; i < p.Len(); i++ { + provider.args[i] = p.At(i).Type() + for j := 0; j < i; j++ { + if types.Identical(provider.args[i], provider.args[j]) { + return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil)) + } + } + } + if mod := modules[providerModule]; mod != nil { + for _, other := range mod.providers { + if types.Identical(other.out, provider.out) { + return nil, fmt.Errorf("%v: module %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerModule, types.TypeString(provider.out, nil), fset.Position(other.pos)) + } + } + mod.providers = append(mod.providers, provider) + } else { + modules[providerModule] = &module{ + providers: []*providerInfo{provider}, + } + } + } + } + return modules, nil +} + +// moduleCache is a lazily evaluated index of modules. +type moduleCache struct { + modules map[string]map[string]*module + fset *token.FileSet + prog *loader.Program +} + +func newModuleCache(prog *loader.Program) *moduleCache { + return &moduleCache{ + fset: prog.Fset, + prog: prog, + } +} + +func (mc *moduleCache) get(ref moduleRef) (*module, error) { + if mods, cached := mc.modules[ref.importPath]; cached { + mod := mods[ref.moduleName] + if mod == nil { + return nil, fmt.Errorf("no such module %s in package %q", ref.moduleName, ref.importPath) + } + return mod, nil + } + if mc.modules == nil { + mc.modules = make(map[string]map[string]*module) + } + pkg, info, files, err := mc.getpkg(ref.importPath) + if err != nil { + mc.modules[ref.importPath] = nil + return nil, fmt.Errorf("analyze package: %v", err) + } + mods, err := findModules(mc.fset, pkg, info, files) + if err != nil { + mc.modules[ref.importPath] = nil + return nil, err + } + mc.modules[ref.importPath] = mods + mod := mods[ref.moduleName] + if mod == nil { + return nil, fmt.Errorf("no such module %s in package %q", ref.moduleName, ref.importPath) + } + return mod, nil +} + +func (mc *moduleCache) getpkg(path string) (*types.Package, *types.Info, []*ast.File, error) { + // TODO(light): allow other implementations for testing + + pkg := mc.prog.Package(path) + if pkg == nil { + return nil, nil, nil, fmt.Errorf("package %q not found", path) + } + return pkg.Pkg, &pkg.Info, pkg.Files, nil +} + +// solve finds the sequence of calls required to produce an output type +// with an optional set of provided inputs. +func solve(mc *moduleCache, out types.Type, given []types.Type, modules []moduleRef) ([]call, error) { + for i, g := range given { + for _, h := range given[:i] { + if types.Identical(g, h) { + return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil)) + } + } + } + providers, err := buildProviderMap(mc, modules) + if err != nil { + return nil, err + } + + // Start building the mapping of type to local variable of the given type. + // The first len(given) local variables are the given types. + index := new(typeutil.Map) + for i, g := range given { + if p := providers.At(g); p != nil { + pp := p.(*providerInfo) + return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos)) + } + index.Set(g, i) + } + + // Topological sort of the directed graph defined by the providers + // using a depth-first search. The graph may contain cycles, which + // should trigger an error. + var calls []call + var visit func(trail []types.Type) error + visit = func(trail []types.Type) error { + typ := trail[len(trail)-1] + if index.At(typ) != nil { + return nil + } + for _, t := range trail[:len(trail)-1] { + if types.Identical(typ, t) { + // TODO(light): describe cycle + return fmt.Errorf("cycle for %s", types.TypeString(typ, nil)) + } + } + + p, _ := providers.At(typ).(*providerInfo) + if p == nil { + if len(trail) == 1 { + return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) + } + // TODO(light): give name of provider + return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil)) + } + for _, a := range p.args { + // TODO(light): this will discard grown trail arrays. + if err := visit(append(trail, a)); err != nil { + return err + } + } + args := make([]int, len(p.args)) + for i := range p.args { + args[i] = index.At(p.args[i]).(int) + } + index.Set(typ, len(given)+len(calls)) + calls = append(calls, call{ + importPath: p.importPath, + funcName: p.funcName, + args: args, + out: typ, + hasErr: p.hasErr, + }) + return nil + } + if err := visit([]types.Type{out}); err != nil { + return nil, err + } + return calls, nil +} + +func buildProviderMap(mc *moduleCache, modules []moduleRef) (*typeutil.Map, error) { + type nextEnt struct { + to moduleRef + + from moduleRef + pos token.Pos + } + + pm := new(typeutil.Map) // to *providerInfo + visited := make(map[moduleRef]struct{}) + var next []nextEnt + for _, ref := range modules { + next = append(next, nextEnt{to: ref}) + } + for len(next) > 0 { + curr := next[0] + copy(next, next[1:]) + next = next[:len(next)-1] + if _, skip := visited[curr.to]; skip { + continue + } + visited[curr.to] = struct{}{} + mod, err := mc.get(curr.to) + if err != nil { + if !curr.pos.IsValid() { + return nil, err + } + return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err) + } + for _, p := range mod.providers { + if prev := pm.At(p.out); prev != nil { + pos := mc.fset.Position(p.pos) + typ := types.TypeString(p.out, nil) + prevPos := mc.fset.Position(prev.(*providerInfo).pos) + if curr.from.importPath != "" { + return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) + } + return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos) + } + pm.Set(p.out, p) + } + for _, imp := range mod.imports { + next = append(next, nextEnt{to: imp.moduleRef, from: curr.to, pos: imp.pos}) + } + } + return pm, nil +} + +// A call represents a step of an injector function. +type call struct { + // importPath and funcName identify the provider function to call. + importPath string + funcName string + + // args is a list of arguments to call the provider with. Each element is either: + // a) one of the givens (args[i] < len(given)) or + // b) the result of a previous provider call (args[i] >= len(given)). + args []int + + // out is the type produced by this provider call. + out types.Type + + // hasErr is true if the provider call returns an error. + hasErr bool +} + +// providerInfo records the signature of a provider function. +type providerInfo struct { + importPath string + funcName string + pos token.Pos + args []types.Type + out types.Type + hasErr bool +} + +// A moduleRef is a parsed reference to a collection of providers. +type moduleRef struct { + importPath string + moduleName string +} + +func parseModuleRef(ref string, s *types.Scope, pkg string, pos token.Pos) (moduleRef, error) { + // TODO(light): verify that module name is an identifier before returning + + i := strings.LastIndexByte(ref, '.') + if i == -1 { + return moduleRef{importPath: pkg, moduleName: ref}, nil + } + imp, name := ref[:i], ref[i+1:] + if strings.HasPrefix(imp, `"`) { + path, err := strconv.Unquote(imp) + if err != nil { + return moduleRef{}, fmt.Errorf("parse module reference %q: bad import path", ref) + } + return moduleRef{importPath: path, moduleName: name}, nil + } + _, obj := s.LookupParent(imp, pos) + if obj == nil { + return moduleRef{}, fmt.Errorf("parse module reference %q: unknown identifier %s", ref, imp) + } + pn, ok := obj.(*types.PkgName) + if !ok { + return moduleRef{}, fmt.Errorf("parse module reference %q: %s does not name a package", ref, imp) + } + return moduleRef{importPath: pn.Imported().Path(), moduleName: name}, nil +} + +func (ref moduleRef) String() string { + return strconv.Quote(ref.importPath) + "." + ref.moduleName +} + +type directive struct { + pos token.Pos + kind string + line string +} + +func extractDirectives(d []directive, cg *ast.CommentGroup) []directive { + const prefix = "goose:" + text := cg.Text() + for len(text) > 0 { + text = strings.TrimLeft(text, " \t\r\n") + if !strings.HasPrefix(text, prefix) { + break + } + line := text[len(prefix):] + if i := strings.IndexByte(line, '\n'); i != -1 { + line, text = line[:i], line[i+1:] + } else { + text = "" + } + if i := strings.IndexByte(line, ' '); i != -1 { + d = append(d, directive{ + kind: line[:i], + line: strings.TrimSpace(line[i+1:]), + pos: cg.Pos(), // TODO(light): more precise position + }) + } else { + d = append(d, directive{ + kind: line, + pos: cg.Pos(), // TODO(light): more precise position + }) + } + } + return d +} + +// isInjectFile reports whether a given file is an injection template. +func isInjectFile(f *ast.File) bool { + // TODO(light): better determination + for _, cg := range f.Comments { + text := cg.Text() + if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") { + return true + } + } + return false +} + +// zeroValue returns the shortest expression that evaluates to the zero +// value for the given type. +func zeroValue(t types.Type, qf types.Qualifier) string { + switch u := t.Underlying().(type) { + case *types.Array, *types.Struct: + return types.TypeString(t, qf) + "{}" + case *types.Basic: + info := u.Info() + switch { + case info&types.IsBoolean != 0: + return "false" + case info&(types.IsInteger|types.IsFloat|types.IsComplex) != 0: + return "0" + case info&types.IsString != 0: + return `""` + default: + panic("unreachable") + } + case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice: + return "nil" + default: + panic("unreachable") + } +} + +var errorType = types.Universe.Lookup("error").Type() diff --git a/internal/goose/goose_test.go b/internal/goose/goose_test.go new file mode 100644 index 0000000..9441547 --- /dev/null +++ b/internal/goose/goose_test.go @@ -0,0 +1,406 @@ +package goose + +import ( + "bytes" + "fmt" + "go/build" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "testing" +) + +// TODO(light): pull this out into a testdata directory + +var tests = []struct { + name string + files map[string]string + pkg string + wantOutput string + wantError bool +}{ + { + name: "No-op build", + files: map[string]string{ + "foo/foo.go": `package main; import "fmt"; func main() { fmt.Println("Hello, World!"); }`, + }, + pkg: "foo", + wantOutput: "Hello, World!\n", + }, + { + name: "Niladic identity provider", + files: map[string]string{ + "foo/foo.go": `package main +import "fmt" +func main() { fmt.Println(injectedMessage()); } + +//goose:provide + +// provideMessage provides a friendly user greeting. +func provideMessage() string { return "Hello, World!"; } +`, + "foo/foo_goose.go": `//+build gooseinject + +package main + +//goose:use Module + +func injectedMessage() string +`, + }, + pkg: "foo", + wantOutput: "Hello, World!\n", + }, + { + name: "Missing use", + files: map[string]string{ + "foo/foo.go": `package main +import "fmt" +func main() { fmt.Println(injectedMessage()); } + +//goose:provide + +// provideMessage provides a friendly user greeting. +func provideMessage() string { return "Hello, World!"; } +`, + "foo/foo_goose.go": `//+build gooseinject + +package main + +func injectedMessage() string +`, + }, + pkg: "foo", + wantError: true, + }, + { + name: "Chain", + files: map[string]string{ + "foo/foo.go": `package main +import "fmt" +func main() { + fmt.Println(injectFooBar()) +} + +type Foo int +type FooBar int + +//goose:provide +func provideFoo() Foo { return 41 } + +//goose:provide +func provideFooBar(foo Foo) FooBar { return FooBar(foo) + 1 } +`, + "foo/foo_goose.go": `//+build gooseinject + +package main + +//goose:use Module + +func injectFooBar() FooBar +`, + }, + pkg: "foo", + wantOutput: "42\n", + }, + { + name: "Two deps", + files: map[string]string{ + "foo/foo.go": `package main +import "fmt" +func main() { + fmt.Println(injectFooBar()) +} + +type Foo int +type Bar int +type FooBar int + +//goose:provide +func provideFoo() Foo { return 40 } + +//goose:provide +func provideBar() Bar { return 2 } + +//goose:provide +func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) } +`, + "foo/foo_goose.go": `//+build gooseinject + +package main + +//goose:use Module + +func injectFooBar() FooBar +`, + }, + pkg: "foo", + wantOutput: "42\n", + }, + { + name: "Inject input", + files: map[string]string{ + "foo/foo.go": `package main +import "fmt" +func main() { + fmt.Println(injectFooBar(40)) +} + +type Foo int +type Bar int +type FooBar int + +//goose:provide +func provideBar() Bar { return 2 } + +//goose:provide +func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) } +`, + "foo/foo_goose.go": `//+build gooseinject + +package main + +//goose:use Module + +func injectFooBar(foo Foo) FooBar +`, + }, + pkg: "foo", + wantOutput: "42\n", + }, + { + name: "Inject input conflict", + files: map[string]string{ + "foo/foo.go": `package main +import "fmt" +func main() { + fmt.Println(injectBar(40)) +} + +type Foo int +type Bar int + +//goose:provide +func provideFoo() Foo { return -888 } + +//goose:provide +func provideBar(foo Foo) Bar { return 2 } +`, + "foo/foo_goose.go": `//+build gooseinject + +package main + +//goose:use Module + +func injectBar(foo Foo) Bar +`, + }, + pkg: "foo", + wantError: true, + }, + { + name: "Return error", + files: map[string]string{ + "foo/foo.go": `package main +import "errors" +import "fmt" +import "strings" +func main() { + foo, err := injectFoo() + fmt.Println(foo) + if err == nil { + fmt.Println("") + } else { + fmt.Println(strings.Contains(err.Error(), "there is no Foo")) + } +} + +type Foo int + +//goose:provide +func provideFoo() (Foo, error) { return 42, errors.New("there is no Foo") } +`, + "foo/foo_goose.go": `//+build gooseinject + +package main + +//goose:use Module + +func injectFoo() (Foo, error) +`, + }, + pkg: "foo", + wantOutput: "0\ntrue\n", + }, +} + +func TestGeneratedCode(t *testing.T) { + if _, err := os.Stat(filepath.Join(build.Default.GOROOT, "bin", "go")); err != nil { + t.Fatalf("go toolchain not available: %v", err) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gopath, err := ioutil.TempDir("", "goose_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(gopath) + bctx := &build.Context{ + GOARCH: build.Default.GOARCH, + GOOS: build.Default.GOOS, + GOROOT: build.Default.GOROOT, + GOPATH: gopath, + CgoEnabled: build.Default.CgoEnabled, + Compiler: build.Default.Compiler, + ReleaseTags: build.Default.ReleaseTags, + } + for name, content := range test.files { + p := filepath.Join(gopath, "src", filepath.FromSlash(name)) + if err := os.MkdirAll(filepath.Dir(p), 0777); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(p, []byte(content), 0666); err != nil { + t.Fatal(err) + } + } + gen, err := Generate(bctx, gopath, test.pkg) + if len(gen) > 0 { + defer t.Logf("goose_gen.go:\n%s", gen) + } + if err != nil { + if !test.wantError { + t.Fatalf("goose: %v", err) + } + return + } + if err == nil && test.wantError { + t.Fatal("goose succeeded; want error") + } + if len(gen) > 0 { + genPath := filepath.Join(gopath, "src", filepath.FromSlash(test.pkg), "goose_gen.go") + if err := ioutil.WriteFile(genPath, gen, 0666); err != nil { + t.Fatal(err) + } + } + testExePath := filepath.Join(gopath, "bin", "testprog") + if err := runGo(bctx, "build", "-o", testExePath, test.pkg); err != nil { + t.Fatal("build:", err) + } + out, err := exec.Command(testExePath).Output() + if err != nil { + t.Fatal("run compiled program:", err) + } + if string(out) != test.wantOutput { + t.Errorf("compiled program output = %q; want %q", out, test.wantOutput) + } + }) + } +} + +func TestDeterminism(t *testing.T) { + runs := 10 + if testing.Short() { + runs = 3 + } + for _, test := range tests { + if test.wantError { + continue + } + t.Run(test.name, func(t *testing.T) { + gopath, err := ioutil.TempDir("", "goose_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(gopath) + bctx := &build.Context{ + GOARCH: build.Default.GOARCH, + GOOS: build.Default.GOOS, + GOROOT: build.Default.GOROOT, + GOPATH: gopath, + CgoEnabled: build.Default.CgoEnabled, + Compiler: build.Default.Compiler, + ReleaseTags: build.Default.ReleaseTags, + } + for name, content := range test.files { + p := filepath.Join(gopath, "src", filepath.FromSlash(name)) + if err := os.MkdirAll(filepath.Dir(p), 0777); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(p, []byte(content), 0666); err != nil { + t.Fatal(err) + } + } + gold, err := Generate(bctx, gopath, test.pkg) + if err != nil { + t.Fatal("goose:", err) + } + goldstr := string(gold) + for i := 0; i < runs-1; i++ { + out, err := Generate(bctx, gopath, test.pkg) + if err != nil { + t.Fatal("goose (on subsequent run):", err) + } + if !bytes.Equal(gold, out) { + t.Fatalf("goose output differs when run repeatedly on same input:\n%s", diff(goldstr, string(out))) + } + } + }) + } +} + +func runGo(bctx *build.Context, args ...string) error { + exe := filepath.Join(bctx.GOROOT, "bin", "go") + c := exec.Command(exe, args...) + c.Env = append(os.Environ(), "GOROOT="+bctx.GOROOT, "GOARCH="+bctx.GOARCH, "GOOS="+bctx.GOOS, "GOPATH="+bctx.GOPATH) + if bctx.CgoEnabled { + c.Env = append(c.Env, "CGO_ENABLED=1") + } else { + c.Env = append(c.Env, "CGO_ENABLED=0") + } + // TODO(someday): set -compiler flag if needed. + out, err := c.CombinedOutput() + if err != nil { + if len(out) > 0 { + return fmt.Errorf("%v; output:\n%s", err, out) + } + return err + } + return nil +} + +func diff(want, got string) string { + d, err := runDiff([]byte(want), []byte(got)) + if err == nil { + return string(d) + } + return "*** got:\n" + got + "\n\n*** want:\n" + want +} + +func runDiff(a, b []byte) ([]byte, error) { + fa, err := ioutil.TempFile("", "goose_test_diff") + if err != nil { + return nil, err + } + defer func() { + os.Remove(fa.Name()) + fa.Close() + }() + fb, err := ioutil.TempFile("", "goose_test_diff") + if err != nil { + return nil, err + } + defer func() { + os.Remove(fb.Name()) + fb.Close() + }() + if _, err := fa.Write(a); err != nil { + return nil, err + } + if _, err := fb.Write(b); err != nil { + return nil, err + } + c := exec.Command("diff", "-u", fa.Name(), fb.Name()) + out, err := c.Output() + return out, err +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..30db70b --- /dev/null +++ b/main.go @@ -0,0 +1,52 @@ +// goose is a compile-time dependency injection tool. +// +// See README.md for an overview. +package main + +import ( + "fmt" + "go/build" + "io/ioutil" + "os" + "path/filepath" + + "codename/goose/internal/goose" +) + +func main() { + var pkg string + switch len(os.Args) { + case 1: + pkg = "." + case 2: + pkg = os.Args[1] + default: + fmt.Fprintln(os.Stderr, "goose: usage: goose [PKG]") + os.Exit(64) + } + wd, err := os.Getwd() + if err != nil { + fmt.Fprintln(os.Stderr, "goose:", err) + os.Exit(1) + } + pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly) + if err != nil { + fmt.Fprintln(os.Stderr, "goose:", err) + os.Exit(1) + } + out, err := goose.Generate(&build.Default, wd, pkg) + if err != nil { + fmt.Fprintln(os.Stderr, "goose:", err) + os.Exit(1) + } + if len(out) == 0 { + // No Goose directives, don't write anything. + fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg) + return + } + p := filepath.Join(pkgInfo.Dir, "goose_gen.go") + if err := ioutil.WriteFile(p, out, 0666); err != nil { + fmt.Fprintln(os.Stderr, "goose:", err) + os.Exit(1) + } +}