goose: dependency injection proof of concept

See documentation and demo for usage and known limitations.

Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
Ross Light
2018-03-26 07:39:00 -07:00
commit 26c8618466
4 changed files with 1376 additions and 0 deletions

203
README.md Normal file
View File

@@ -0,0 +1,203 @@
# goose: Compile-Time Dependency Injection for Go
goose is a compile-time [dependency injection][] framework for Go, inspired by
[Dagger][]. It works by using Go code to specify dependencies, then
generating code to create those structures, mimicking the code that a user
might have hand-written.
[dependency injection]: https://en.wikipedia.org/wiki/Dependency_injection
[Dagger]: https://google.github.io/dagger/
## Usage Guide
### Defining Providers
The primary mechanism in goose is the **provider**: a function that can
produce a value, annotated with the special `goose:provide` directive. These
functions are ordinary Go code and live in packages.
```go
package module
type Foo int
// goose:provide
// ProvideFoo returns a Foo.
func ProvideFoo() Foo {
return 42
}
```
Providers are always part of a **module**: if there is no module name specified
on the `//goose:provide` line, then `Module` is used.
Providers can specify dependencies with parameters:
```go
package module
// goose:provide SuperModule
type Bar int
// ProvideBar returns a Bar: a negative Foo.
func ProvideBar(foo Foo) Bar {
return Bar(-foo)
}
```
Providers can also return errors:
```go
package module
import (
"context"
"errors"
)
type Baz int
// goose:provide SuperModule
// ProvideBaz returns a value if Bar is not zero.
func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) {
if bar == 0 {
return 0, errors.New("cannot provide baz when bar is zero")
}
return Baz(bar), nil
}
```
Modules can import other modules. To import `Module` in `SuperModule`:
```go
// goose:import SuperModule Module
```
### Injectors
An application can use these providers by declaring an **injector**: a
generated function that calls providers in dependency order.
An injector is declared by writing a function declaration without a body in a
file guarded by a `gooseinject` build tag. Let's say that the above providers
were defined in a package called `example.com/module`. The following would
declare an injector to obtain a `Baz`:
```go
//+build gooseinject
package main
import (
"context"
"example.com/module"
)
// goose:use module.SuperModule
func initializeApp(ctx context.Context) (module.Baz, error)
```
Like providers, injectors can be parameterized on inputs (which then get sent to
providers) and can return errors. The `goose:use` directive specifies the
modules to use in the injection. Both `goose:use` and `goose:import` use the
same syntax for referencing modules: an optional import qualifier (either a
package name or a quoted import path) with a dot, followed by the module name.
For example: `SamePackageModule`, `foo.Bar`, or `"example.com/foo".Bar`.
You can generate the injector using goose:
```
goose
```
Or you can add the line `//go:generate goose` to another file in your package to
use [`go generate`]:
```
go generate
```
(Adding the line to the injection declaration file will be silently ignored by
`go generate`.)
goose will produce an implementation of the injector that looks something like
this:
```go
// Code generated by goose. DO NOT EDIT.
//+build !gooseinject
package main
import (
"example.com/module"
)
func initializeApp(ctx context.Context) (module.Baz, error) {
foo := module.ProvideFoo()
bar := module.ProvideBar(foo)
baz, err := module.ProvideBaz(ctx, bar)
if err != nil {
return 0, err
}
return baz, nil
}
```
As you can see, the output is very close to what a developer would write
themselves. Further, there is no dependency on goose at runtime: all of the
written code is just normal Go code, and can be used without goose.
[`go generate`]: https://blog.golang.org/generate
## Best Practices
goose is still not mature yet, but guidance that applies to Dagger generally
applies to goose as well. In particular, when thinking about how to group a
package of providers, follow the same [guidance](https://google.github.io/dagger/testing.html#organize-modules-for-testability) as Dagger:
> Some [...] bindings will have reasonable alternatives, especially for
> testing, and others will not. For example, there are likely to be
> alternative bindings for a type like `AuthManager`: one for testing, others
> for different authentication/authorization protocols.
>
> But on the other hand, if the `AuthManager` interface has a method that
> returns the currently logged-in user, you might want to [export a provider of
> `User` that simply calls `CurrentUser()`] on the `AuthManager`. That
> published binding is unlikely to ever need an alternative.
>
> Once youve classified your bindings into [...] bindings with reasonable
> alternatives [and] bindings without reasonable alternatives, consider
> arranging them into packages like this:
>
> - One [package] for each [...] binding with a reasonable alternative. (If
> you are also writing the alternatives, each one gets its own [package].) That
> [package] contains exactly one provider.
> - All [...] bindings with no reasonable alternatives go into [packages]
> organized along functional lines.
> - The [packages] should each include the no-reasonable-alternative [packages] that
> require the [...] bindings each provides.
One goose-specific practice though: create one-off types where in Java you
would use a binding annotation.
## Future Work
- The names of imports and provider results in the generated code are not
actually as nice as shown above. I'd like to make them nicer in the
common cases while ensuring uniqueness.
- I'd like to support optional and multiple bindings.
- At the moment, the entire transitive closure of all dependencies are read
for providers. It might be better to have provider imports be opt-in, but
that seems like too many levels of magic.
- Currently, all dependency satisfaction is done using identity. I'd like to
use a limited form of assignability for interface types, but I'm unsure
how well this implicit satisfaction will work in practice.
- Errors emitted by goose are not very good, but it has all the information
it needs to emit better ones.

715
internal/goose/goose.go Normal file
View File

@@ -0,0 +1,715 @@
// Package goose provides compile-time dependency injection logic as a
// Go library.
package goose
import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"go/types"
"sort"
"strconv"
"strings"
"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/types/typeutil"
)
// Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
// TODO(light): allow errors
// TODO(light): stop errors from printing to stderr
conf := &loader.Config{
Build: new(build.Context),
ParserMode: parser.ParseComments,
Cwd: wd,
}
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
conf.Import(pkg)
prog, err := conf.Load()
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
if len(prog.InitialPackages()) != 1 {
// This is more of a violated precondition than anything else.
return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))
}
pkgInfo := prog.InitialPackages()[0]
g := newGen(pkgInfo.Pkg.Path())
mc := newModuleCache(prog)
var directives []directive
for _, f := range pkgInfo.Files {
if !isInjectFile(f) {
continue
}
fileScope := pkgInfo.Scopes[f]
cmap := ast.NewCommentMap(prog.Fset, f, f.Comments)
for _, decl := range f.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
directives = directives[:0]
for _, c := range cmap[fn] {
directives = extractDirectives(directives, c)
}
modules := make([]moduleRef, 0, len(directives))
for _, d := range directives {
if d.kind != "use" {
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
}
ref, err := parseModuleRef(d.line, fileScope, g.currPackage, d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
}
modules = append(modules, ref)
}
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
if err := g.inject(mc, fn.Name.Name, sig, modules); err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
}
}
}
goSrc := g.frame(pkgInfo.Pkg.Name())
fmtSrc, err := format.Source(goSrc)
if err != nil {
// This is likely a bug from a poorly generated source file.
// Return an error and the unformatted source.
return goSrc, err
}
return fmtSrc, nil
}
// gen is the generator state.
type gen struct {
currPackage string
buf bytes.Buffer
imports map[string]string
n int
}
func newGen(pkg string) *gen {
return &gen{
currPackage: pkg,
imports: make(map[string]string),
}
}
// frame bakes the built up source body into an unformatted Go source file.
func (g *gen) frame(pkgName string) []byte {
if g.buf.Len() == 0 {
return nil
}
var buf bytes.Buffer
buf.WriteString("// Code generated by goose. DO NOT EDIT.\n\n//+build !gooseinject\n\npackage ")
buf.WriteString(pkgName)
buf.WriteString("\n\n")
if len(g.imports) > 0 {
buf.WriteString("import (\n")
imps := make([]string, 0, len(g.imports))
for path := range g.imports {
imps = append(imps, path)
}
sort.Strings(imps)
for _, path := range imps {
fmt.Fprintf(&buf, "\t%s %q\n", g.imports[path], path)
}
buf.WriteString(")\n\n")
}
buf.Write(g.buf.Bytes())
return buf.Bytes()
}
// inject emits the code for an injector.
func (g *gen) inject(mc *moduleCache, name string, sig *types.Signature, modules []moduleRef) error {
results := sig.Results()
returnsErr := false
switch results.Len() {
case 0:
return fmt.Errorf("inject %s: no return values", name)
case 1:
// nothing special
case 2:
if t := results.At(1).Type(); !types.Identical(t, errorType) {
return fmt.Errorf("inject %s: second return type is %s; must be error", name, types.TypeString(t, nil))
}
returnsErr = true
default:
return fmt.Errorf("inject %s: too many return values", name)
}
outType := results.At(0).Type()
params := sig.Params()
given := make([]types.Type, params.Len())
for i := 0; i < params.Len(); i++ {
given[i] = params.At(i).Type()
}
calls, err := solve(mc, outType, given, modules)
if err != nil {
return err
}
for i := range calls {
if calls[i].hasErr && !returnsErr {
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
}
}
g.p("func %s(", name)
for i := 0; i < params.Len(); i++ {
if i > 0 {
g.p(", ")
}
pi := params.At(i)
g.p("%s %s", pi.Name(), types.TypeString(pi.Type(), g.qualifyPkg))
}
if returnsErr {
g.p(") (%s, error) {\n", types.TypeString(outType, g.qualifyPkg))
} else {
g.p(") %s {\n", types.TypeString(outType, g.qualifyPkg))
}
zv := zeroValue(outType, g.qualifyPkg)
for i := range calls {
c := &calls[i]
g.p("\tv%d", i)
if c.hasErr {
g.p(", err")
}
g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName))
for j, a := range c.args {
if j > 0 {
g.p(", ")
}
if a < params.Len() {
g.p("%s", params.At(a).Name())
} else {
g.p("v%d", a-params.Len())
}
}
g.p(")\n")
if c.hasErr {
g.p("\tif err != nil {\n")
// TODO(light): give information about failing provider
g.p("\t\treturn %s, err\n", zv)
g.p("\t}\n")
}
}
if len(calls) == 0 {
for i := range given {
if types.Identical(outType, given[i]) {
g.p("\treturn %s", params.At(i).Name())
break
}
}
} else {
g.p("\treturn v%d", len(calls)-1)
}
if returnsErr {
g.p(", nil")
}
g.p("\n}\n")
return nil
}
func (g *gen) qualifiedID(path, sym string) string {
name := g.qualifyImport(path)
if name == "" {
return sym
}
return name + "." + sym
}
func (g *gen) qualifyImport(path string) string {
if path == g.currPackage {
return ""
}
if name := g.imports[path]; name != "" {
return name
}
name := fmt.Sprintf("pkg%d", g.n)
g.n++
g.imports[path] = name
return name
}
func (g *gen) qualifyPkg(pkg *types.Package) string {
return g.qualifyImport(pkg.Path())
}
func (g *gen) p(format string, args ...interface{}) {
fmt.Fprintf(&g.buf, format, args...)
}
// A module describes a set of providers. The zero value is an empty
// module.
type module struct {
providers []*providerInfo
imports []moduleImport
}
type moduleImport struct {
moduleRef
pos token.Pos
}
const implicitModuleName = "Module"
// findModules processes a package and extracts the modules declared in it.
func findModules(fset *token.FileSet, pkg *types.Package, typeInfo *types.Info, files []*ast.File) (map[string]*module, error) {
modules := make(map[string]*module)
var directives []directive
for _, f := range files {
fileScope := typeInfo.Scopes[f]
for _, c := range f.Comments {
directives = extractDirectives(directives[:0], c)
for _, d := range directives {
switch d.kind {
case "provide", "use":
// handled later
case "import":
if fileScope == nil {
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name())
}
var name, spec string
if strings.HasPrefix(d.line, `"`) {
name, spec = implicitModuleName, d.line
} else if i := strings.IndexByte(d.line, ' '); i != -1 {
name, spec = d.line[:i], d.line[i+1:]
} else {
name, spec = implicitModuleName, d.line
}
ref, err := parseModuleRef(spec, fileScope, pkg.Path(), d.pos)
if err != nil {
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
}
if ref.importPath != pkg.Path() {
imported := false
for _, imp := range pkg.Imports() {
if ref.importPath == imp.Path() {
imported = true
break
}
}
if !imported {
return nil, fmt.Errorf("%v: module %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath)
}
}
if mod := modules[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.moduleRef {
found = true
break
}
}
if !found {
mod.imports = append(mod.imports, moduleImport{moduleRef: ref, pos: d.pos})
}
} else {
modules[name] = &module{
imports: []moduleImport{{moduleRef: ref, pos: d.pos}},
}
}
default:
return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind)
}
}
}
cmap := ast.NewCommentMap(fset, f, f.Comments)
for _, decl := range f.Decls {
directives = directives[:0]
for _, cg := range cmap[decl] {
directives = extractDirectives(directives, cg)
}
fn, isFunction := decl.(*ast.FuncDecl)
var providerModule string
for _, d := range directives {
if d.kind != "provide" {
continue
}
if providerModule != "" {
return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name)
}
if !isFunction {
return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos))
}
if d.line == "" {
providerModule = implicitModuleName
} else {
// TODO(light): validate identifier
providerModule = d.line
}
}
if providerModule == "" {
continue
}
fpos := fn.Pos()
sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
r := sig.Results()
var hasErr bool
switch r.Len() {
case 1:
hasErr = false
case 2:
if t := r.At(1).Type(); !types.Identical(t, errorType) {
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name)
}
hasErr = true
default:
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name)
}
out := r.At(0).Type()
p := sig.Params()
provider := &providerInfo{
importPath: pkg.Path(),
funcName: fn.Name.Name,
pos: fn.Pos(),
args: make([]types.Type, p.Len()),
out: out,
hasErr: hasErr,
}
for i := 0; i < p.Len(); i++ {
provider.args[i] = p.At(i).Type()
for j := 0; j < i; j++ {
if types.Identical(provider.args[i], provider.args[j]) {
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil))
}
}
}
if mod := modules[providerModule]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return nil, fmt.Errorf("%v: module %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerModule, types.TypeString(provider.out, nil), fset.Position(other.pos))
}
}
mod.providers = append(mod.providers, provider)
} else {
modules[providerModule] = &module{
providers: []*providerInfo{provider},
}
}
}
}
return modules, nil
}
// moduleCache is a lazily evaluated index of modules.
type moduleCache struct {
modules map[string]map[string]*module
fset *token.FileSet
prog *loader.Program
}
func newModuleCache(prog *loader.Program) *moduleCache {
return &moduleCache{
fset: prog.Fset,
prog: prog,
}
}
func (mc *moduleCache) get(ref moduleRef) (*module, error) {
if mods, cached := mc.modules[ref.importPath]; cached {
mod := mods[ref.moduleName]
if mod == nil {
return nil, fmt.Errorf("no such module %s in package %q", ref.moduleName, ref.importPath)
}
return mod, nil
}
if mc.modules == nil {
mc.modules = make(map[string]map[string]*module)
}
pkg, info, files, err := mc.getpkg(ref.importPath)
if err != nil {
mc.modules[ref.importPath] = nil
return nil, fmt.Errorf("analyze package: %v", err)
}
mods, err := findModules(mc.fset, pkg, info, files)
if err != nil {
mc.modules[ref.importPath] = nil
return nil, err
}
mc.modules[ref.importPath] = mods
mod := mods[ref.moduleName]
if mod == nil {
return nil, fmt.Errorf("no such module %s in package %q", ref.moduleName, ref.importPath)
}
return mod, nil
}
func (mc *moduleCache) getpkg(path string) (*types.Package, *types.Info, []*ast.File, error) {
// TODO(light): allow other implementations for testing
pkg := mc.prog.Package(path)
if pkg == nil {
return nil, nil, nil, fmt.Errorf("package %q not found", path)
}
return pkg.Pkg, &pkg.Info, pkg.Files, nil
}
// solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs.
func solve(mc *moduleCache, out types.Type, given []types.Type, modules []moduleRef) ([]call, error) {
for i, g := range given {
for _, h := range given[:i] {
if types.Identical(g, h) {
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
}
}
}
providers, err := buildProviderMap(mc, modules)
if err != nil {
return nil, err
}
// Start building the mapping of type to local variable of the given type.
// The first len(given) local variables are the given types.
index := new(typeutil.Map)
for i, g := range given {
if p := providers.At(g); p != nil {
pp := p.(*providerInfo)
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos))
}
index.Set(g, i)
}
// Topological sort of the directed graph defined by the providers
// using a depth-first search. The graph may contain cycles, which
// should trigger an error.
var calls []call
var visit func(trail []types.Type) error
visit = func(trail []types.Type) error {
typ := trail[len(trail)-1]
if index.At(typ) != nil {
return nil
}
for _, t := range trail[:len(trail)-1] {
if types.Identical(typ, t) {
// TODO(light): describe cycle
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
}
}
p, _ := providers.At(typ).(*providerInfo)
if p == nil {
if len(trail) == 1 {
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
}
// TODO(light): give name of provider
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil))
}
for _, a := range p.args {
// TODO(light): this will discard grown trail arrays.
if err := visit(append(trail, a)); err != nil {
return err
}
}
args := make([]int, len(p.args))
for i := range p.args {
args[i] = index.At(p.args[i]).(int)
}
index.Set(typ, len(given)+len(calls))
calls = append(calls, call{
importPath: p.importPath,
funcName: p.funcName,
args: args,
out: typ,
hasErr: p.hasErr,
})
return nil
}
if err := visit([]types.Type{out}); err != nil {
return nil, err
}
return calls, nil
}
func buildProviderMap(mc *moduleCache, modules []moduleRef) (*typeutil.Map, error) {
type nextEnt struct {
to moduleRef
from moduleRef
pos token.Pos
}
pm := new(typeutil.Map) // to *providerInfo
visited := make(map[moduleRef]struct{})
var next []nextEnt
for _, ref := range modules {
next = append(next, nextEnt{to: ref})
}
for len(next) > 0 {
curr := next[0]
copy(next, next[1:])
next = next[:len(next)-1]
if _, skip := visited[curr.to]; skip {
continue
}
visited[curr.to] = struct{}{}
mod, err := mc.get(curr.to)
if err != nil {
if !curr.pos.IsValid() {
return nil, err
}
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
}
for _, p := range mod.providers {
if prev := pm.At(p.out); prev != nil {
pos := mc.fset.Position(p.pos)
typ := types.TypeString(p.out, nil)
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
if curr.from.importPath != "" {
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
}
pm.Set(p.out, p)
}
for _, imp := range mod.imports {
next = append(next, nextEnt{to: imp.moduleRef, from: curr.to, pos: imp.pos})
}
}
return pm, nil
}
// A call represents a step of an injector function.
type call struct {
// importPath and funcName identify the provider function to call.
importPath string
funcName string
// args is a list of arguments to call the provider with. Each element is either:
// a) one of the givens (args[i] < len(given)) or
// b) the result of a previous provider call (args[i] >= len(given)).
args []int
// out is the type produced by this provider call.
out types.Type
// hasErr is true if the provider call returns an error.
hasErr bool
}
// providerInfo records the signature of a provider function.
type providerInfo struct {
importPath string
funcName string
pos token.Pos
args []types.Type
out types.Type
hasErr bool
}
// A moduleRef is a parsed reference to a collection of providers.
type moduleRef struct {
importPath string
moduleName string
}
func parseModuleRef(ref string, s *types.Scope, pkg string, pos token.Pos) (moduleRef, error) {
// TODO(light): verify that module name is an identifier before returning
i := strings.LastIndexByte(ref, '.')
if i == -1 {
return moduleRef{importPath: pkg, moduleName: ref}, nil
}
imp, name := ref[:i], ref[i+1:]
if strings.HasPrefix(imp, `"`) {
path, err := strconv.Unquote(imp)
if err != nil {
return moduleRef{}, fmt.Errorf("parse module reference %q: bad import path", ref)
}
return moduleRef{importPath: path, moduleName: name}, nil
}
_, obj := s.LookupParent(imp, pos)
if obj == nil {
return moduleRef{}, fmt.Errorf("parse module reference %q: unknown identifier %s", ref, imp)
}
pn, ok := obj.(*types.PkgName)
if !ok {
return moduleRef{}, fmt.Errorf("parse module reference %q: %s does not name a package", ref, imp)
}
return moduleRef{importPath: pn.Imported().Path(), moduleName: name}, nil
}
func (ref moduleRef) String() string {
return strconv.Quote(ref.importPath) + "." + ref.moduleName
}
type directive struct {
pos token.Pos
kind string
line string
}
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
const prefix = "goose:"
text := cg.Text()
for len(text) > 0 {
text = strings.TrimLeft(text, " \t\r\n")
if !strings.HasPrefix(text, prefix) {
break
}
line := text[len(prefix):]
if i := strings.IndexByte(line, '\n'); i != -1 {
line, text = line[:i], line[i+1:]
} else {
text = ""
}
if i := strings.IndexByte(line, ' '); i != -1 {
d = append(d, directive{
kind: line[:i],
line: strings.TrimSpace(line[i+1:]),
pos: cg.Pos(), // TODO(light): more precise position
})
} else {
d = append(d, directive{
kind: line,
pos: cg.Pos(), // TODO(light): more precise position
})
}
}
return d
}
// isInjectFile reports whether a given file is an injection template.
func isInjectFile(f *ast.File) bool {
// TODO(light): better determination
for _, cg := range f.Comments {
text := cg.Text()
if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") {
return true
}
}
return false
}
// zeroValue returns the shortest expression that evaluates to the zero
// value for the given type.
func zeroValue(t types.Type, qf types.Qualifier) string {
switch u := t.Underlying().(type) {
case *types.Array, *types.Struct:
return types.TypeString(t, qf) + "{}"
case *types.Basic:
info := u.Info()
switch {
case info&types.IsBoolean != 0:
return "false"
case info&(types.IsInteger|types.IsFloat|types.IsComplex) != 0:
return "0"
case info&types.IsString != 0:
return `""`
default:
panic("unreachable")
}
case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
return "nil"
default:
panic("unreachable")
}
}
var errorType = types.Universe.Lookup("error").Type()

View File

@@ -0,0 +1,406 @@
package goose
import (
"bytes"
"fmt"
"go/build"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"testing"
)
// TODO(light): pull this out into a testdata directory
var tests = []struct {
name string
files map[string]string
pkg string
wantOutput string
wantError bool
}{
{
name: "No-op build",
files: map[string]string{
"foo/foo.go": `package main; import "fmt"; func main() { fmt.Println("Hello, World!"); }`,
},
pkg: "foo",
wantOutput: "Hello, World!\n",
},
{
name: "Niladic identity provider",
files: map[string]string{
"foo/foo.go": `package main
import "fmt"
func main() { fmt.Println(injectedMessage()); }
//goose:provide
// provideMessage provides a friendly user greeting.
func provideMessage() string { return "Hello, World!"; }
`,
"foo/foo_goose.go": `//+build gooseinject
package main
//goose:use Module
func injectedMessage() string
`,
},
pkg: "foo",
wantOutput: "Hello, World!\n",
},
{
name: "Missing use",
files: map[string]string{
"foo/foo.go": `package main
import "fmt"
func main() { fmt.Println(injectedMessage()); }
//goose:provide
// provideMessage provides a friendly user greeting.
func provideMessage() string { return "Hello, World!"; }
`,
"foo/foo_goose.go": `//+build gooseinject
package main
func injectedMessage() string
`,
},
pkg: "foo",
wantError: true,
},
{
name: "Chain",
files: map[string]string{
"foo/foo.go": `package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type FooBar int
//goose:provide
func provideFoo() Foo { return 41 }
//goose:provide
func provideFooBar(foo Foo) FooBar { return FooBar(foo) + 1 }
`,
"foo/foo_goose.go": `//+build gooseinject
package main
//goose:use Module
func injectFooBar() FooBar
`,
},
pkg: "foo",
wantOutput: "42\n",
},
{
name: "Two deps",
files: map[string]string{
"foo/foo.go": `package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type Bar int
type FooBar int
//goose:provide
func provideFoo() Foo { return 40 }
//goose:provide
func provideBar() Bar { return 2 }
//goose:provide
func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) }
`,
"foo/foo_goose.go": `//+build gooseinject
package main
//goose:use Module
func injectFooBar() FooBar
`,
},
pkg: "foo",
wantOutput: "42\n",
},
{
name: "Inject input",
files: map[string]string{
"foo/foo.go": `package main
import "fmt"
func main() {
fmt.Println(injectFooBar(40))
}
type Foo int
type Bar int
type FooBar int
//goose:provide
func provideBar() Bar { return 2 }
//goose:provide
func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) }
`,
"foo/foo_goose.go": `//+build gooseinject
package main
//goose:use Module
func injectFooBar(foo Foo) FooBar
`,
},
pkg: "foo",
wantOutput: "42\n",
},
{
name: "Inject input conflict",
files: map[string]string{
"foo/foo.go": `package main
import "fmt"
func main() {
fmt.Println(injectBar(40))
}
type Foo int
type Bar int
//goose:provide
func provideFoo() Foo { return -888 }
//goose:provide
func provideBar(foo Foo) Bar { return 2 }
`,
"foo/foo_goose.go": `//+build gooseinject
package main
//goose:use Module
func injectBar(foo Foo) Bar
`,
},
pkg: "foo",
wantError: true,
},
{
name: "Return error",
files: map[string]string{
"foo/foo.go": `package main
import "errors"
import "fmt"
import "strings"
func main() {
foo, err := injectFoo()
fmt.Println(foo)
if err == nil {
fmt.Println("<nil>")
} else {
fmt.Println(strings.Contains(err.Error(), "there is no Foo"))
}
}
type Foo int
//goose:provide
func provideFoo() (Foo, error) { return 42, errors.New("there is no Foo") }
`,
"foo/foo_goose.go": `//+build gooseinject
package main
//goose:use Module
func injectFoo() (Foo, error)
`,
},
pkg: "foo",
wantOutput: "0\ntrue\n",
},
}
func TestGeneratedCode(t *testing.T) {
if _, err := os.Stat(filepath.Join(build.Default.GOROOT, "bin", "go")); err != nil {
t.Fatalf("go toolchain not available: %v", err)
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
gopath, err := ioutil.TempDir("", "goose_test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(gopath)
bctx := &build.Context{
GOARCH: build.Default.GOARCH,
GOOS: build.Default.GOOS,
GOROOT: build.Default.GOROOT,
GOPATH: gopath,
CgoEnabled: build.Default.CgoEnabled,
Compiler: build.Default.Compiler,
ReleaseTags: build.Default.ReleaseTags,
}
for name, content := range test.files {
p := filepath.Join(gopath, "src", filepath.FromSlash(name))
if err := os.MkdirAll(filepath.Dir(p), 0777); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(p, []byte(content), 0666); err != nil {
t.Fatal(err)
}
}
gen, err := Generate(bctx, gopath, test.pkg)
if len(gen) > 0 {
defer t.Logf("goose_gen.go:\n%s", gen)
}
if err != nil {
if !test.wantError {
t.Fatalf("goose: %v", err)
}
return
}
if err == nil && test.wantError {
t.Fatal("goose succeeded; want error")
}
if len(gen) > 0 {
genPath := filepath.Join(gopath, "src", filepath.FromSlash(test.pkg), "goose_gen.go")
if err := ioutil.WriteFile(genPath, gen, 0666); err != nil {
t.Fatal(err)
}
}
testExePath := filepath.Join(gopath, "bin", "testprog")
if err := runGo(bctx, "build", "-o", testExePath, test.pkg); err != nil {
t.Fatal("build:", err)
}
out, err := exec.Command(testExePath).Output()
if err != nil {
t.Fatal("run compiled program:", err)
}
if string(out) != test.wantOutput {
t.Errorf("compiled program output = %q; want %q", out, test.wantOutput)
}
})
}
}
func TestDeterminism(t *testing.T) {
runs := 10
if testing.Short() {
runs = 3
}
for _, test := range tests {
if test.wantError {
continue
}
t.Run(test.name, func(t *testing.T) {
gopath, err := ioutil.TempDir("", "goose_test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(gopath)
bctx := &build.Context{
GOARCH: build.Default.GOARCH,
GOOS: build.Default.GOOS,
GOROOT: build.Default.GOROOT,
GOPATH: gopath,
CgoEnabled: build.Default.CgoEnabled,
Compiler: build.Default.Compiler,
ReleaseTags: build.Default.ReleaseTags,
}
for name, content := range test.files {
p := filepath.Join(gopath, "src", filepath.FromSlash(name))
if err := os.MkdirAll(filepath.Dir(p), 0777); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(p, []byte(content), 0666); err != nil {
t.Fatal(err)
}
}
gold, err := Generate(bctx, gopath, test.pkg)
if err != nil {
t.Fatal("goose:", err)
}
goldstr := string(gold)
for i := 0; i < runs-1; i++ {
out, err := Generate(bctx, gopath, test.pkg)
if err != nil {
t.Fatal("goose (on subsequent run):", err)
}
if !bytes.Equal(gold, out) {
t.Fatalf("goose output differs when run repeatedly on same input:\n%s", diff(goldstr, string(out)))
}
}
})
}
}
func runGo(bctx *build.Context, args ...string) error {
exe := filepath.Join(bctx.GOROOT, "bin", "go")
c := exec.Command(exe, args...)
c.Env = append(os.Environ(), "GOROOT="+bctx.GOROOT, "GOARCH="+bctx.GOARCH, "GOOS="+bctx.GOOS, "GOPATH="+bctx.GOPATH)
if bctx.CgoEnabled {
c.Env = append(c.Env, "CGO_ENABLED=1")
} else {
c.Env = append(c.Env, "CGO_ENABLED=0")
}
// TODO(someday): set -compiler flag if needed.
out, err := c.CombinedOutput()
if err != nil {
if len(out) > 0 {
return fmt.Errorf("%v; output:\n%s", err, out)
}
return err
}
return nil
}
func diff(want, got string) string {
d, err := runDiff([]byte(want), []byte(got))
if err == nil {
return string(d)
}
return "*** got:\n" + got + "\n\n*** want:\n" + want
}
func runDiff(a, b []byte) ([]byte, error) {
fa, err := ioutil.TempFile("", "goose_test_diff")
if err != nil {
return nil, err
}
defer func() {
os.Remove(fa.Name())
fa.Close()
}()
fb, err := ioutil.TempFile("", "goose_test_diff")
if err != nil {
return nil, err
}
defer func() {
os.Remove(fb.Name())
fb.Close()
}()
if _, err := fa.Write(a); err != nil {
return nil, err
}
if _, err := fb.Write(b); err != nil {
return nil, err
}
c := exec.Command("diff", "-u", fa.Name(), fb.Name())
out, err := c.Output()
return out, err
}

52
main.go Normal file
View File

@@ -0,0 +1,52 @@
// goose is a compile-time dependency injection tool.
//
// See README.md for an overview.
package main
import (
"fmt"
"go/build"
"io/ioutil"
"os"
"path/filepath"
"codename/goose/internal/goose"
)
func main() {
var pkg string
switch len(os.Args) {
case 1:
pkg = "."
case 2:
pkg = os.Args[1]
default:
fmt.Fprintln(os.Stderr, "goose: usage: goose [PKG]")
os.Exit(64)
}
wd, err := os.Getwd()
if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
out, err := goose.Generate(&build.Default, wd, pkg)
if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
if len(out) == 0 {
// No Goose directives, don't write anything.
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
return
}
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
if err := ioutil.WriteFile(p, out, 0666); err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
}