goose: dependency injection proof of concept
See documentation and demo for usage and known limitations. Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
203
README.md
Normal file
203
README.md
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
# goose: Compile-Time Dependency Injection for Go
|
||||||
|
|
||||||
|
goose is a compile-time [dependency injection][] framework for Go, inspired by
|
||||||
|
[Dagger][]. It works by using Go code to specify dependencies, then
|
||||||
|
generating code to create those structures, mimicking the code that a user
|
||||||
|
might have hand-written.
|
||||||
|
|
||||||
|
[dependency injection]: https://en.wikipedia.org/wiki/Dependency_injection
|
||||||
|
[Dagger]: https://google.github.io/dagger/
|
||||||
|
|
||||||
|
## Usage Guide
|
||||||
|
|
||||||
|
### Defining Providers
|
||||||
|
|
||||||
|
The primary mechanism in goose is the **provider**: a function that can
|
||||||
|
produce a value, annotated with the special `goose:provide` directive. These
|
||||||
|
functions are ordinary Go code and live in packages.
|
||||||
|
|
||||||
|
```go
|
||||||
|
package module
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
|
||||||
|
// goose:provide
|
||||||
|
|
||||||
|
// ProvideFoo returns a Foo.
|
||||||
|
func ProvideFoo() Foo {
|
||||||
|
return 42
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Providers are always part of a **module**: if there is no module name specified
|
||||||
|
on the `//goose:provide` line, then `Module` is used.
|
||||||
|
|
||||||
|
Providers can specify dependencies with parameters:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package module
|
||||||
|
|
||||||
|
// goose:provide SuperModule
|
||||||
|
|
||||||
|
type Bar int
|
||||||
|
|
||||||
|
// ProvideBar returns a Bar: a negative Foo.
|
||||||
|
func ProvideBar(foo Foo) Bar {
|
||||||
|
return Bar(-foo)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Providers can also return errors:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package module
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Baz int
|
||||||
|
|
||||||
|
// goose:provide SuperModule
|
||||||
|
|
||||||
|
// ProvideBaz returns a value if Bar is not zero.
|
||||||
|
func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) {
|
||||||
|
if bar == 0 {
|
||||||
|
return 0, errors.New("cannot provide baz when bar is zero")
|
||||||
|
}
|
||||||
|
return Baz(bar), nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Modules can import other modules. To import `Module` in `SuperModule`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// goose:import SuperModule Module
|
||||||
|
```
|
||||||
|
|
||||||
|
### Injectors
|
||||||
|
|
||||||
|
An application can use these providers by declaring an **injector**: a
|
||||||
|
generated function that calls providers in dependency order.
|
||||||
|
|
||||||
|
An injector is declared by writing a function declaration without a body in a
|
||||||
|
file guarded by a `gooseinject` build tag. Let's say that the above providers
|
||||||
|
were defined in a package called `example.com/module`. The following would
|
||||||
|
declare an injector to obtain a `Baz`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"example.com/module"
|
||||||
|
)
|
||||||
|
|
||||||
|
// goose:use module.SuperModule
|
||||||
|
|
||||||
|
func initializeApp(ctx context.Context) (module.Baz, error)
|
||||||
|
```
|
||||||
|
|
||||||
|
Like providers, injectors can be parameterized on inputs (which then get sent to
|
||||||
|
providers) and can return errors. The `goose:use` directive specifies the
|
||||||
|
modules to use in the injection. Both `goose:use` and `goose:import` use the
|
||||||
|
same syntax for referencing modules: an optional import qualifier (either a
|
||||||
|
package name or a quoted import path) with a dot, followed by the module name.
|
||||||
|
For example: `SamePackageModule`, `foo.Bar`, or `"example.com/foo".Bar`.
|
||||||
|
|
||||||
|
You can generate the injector using goose:
|
||||||
|
|
||||||
|
```
|
||||||
|
goose
|
||||||
|
```
|
||||||
|
|
||||||
|
Or you can add the line `//go:generate goose` to another file in your package to
|
||||||
|
use [`go generate`]:
|
||||||
|
|
||||||
|
```
|
||||||
|
go generate
|
||||||
|
```
|
||||||
|
|
||||||
|
(Adding the line to the injection declaration file will be silently ignored by
|
||||||
|
`go generate`.)
|
||||||
|
|
||||||
|
goose will produce an implementation of the injector that looks something like
|
||||||
|
this:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Code generated by goose. DO NOT EDIT.
|
||||||
|
|
||||||
|
//+build !gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"example.com/module"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initializeApp(ctx context.Context) (module.Baz, error) {
|
||||||
|
foo := module.ProvideFoo()
|
||||||
|
bar := module.ProvideBar(foo)
|
||||||
|
baz, err := module.ProvideBaz(ctx, bar)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return baz, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
As you can see, the output is very close to what a developer would write
|
||||||
|
themselves. Further, there is no dependency on goose at runtime: all of the
|
||||||
|
written code is just normal Go code, and can be used without goose.
|
||||||
|
|
||||||
|
[`go generate`]: https://blog.golang.org/generate
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
goose is still not mature yet, but guidance that applies to Dagger generally
|
||||||
|
applies to goose as well. In particular, when thinking about how to group a
|
||||||
|
package of providers, follow the same [guidance](https://google.github.io/dagger/testing.html#organize-modules-for-testability) as Dagger:
|
||||||
|
|
||||||
|
> Some [...] bindings will have reasonable alternatives, especially for
|
||||||
|
> testing, and others will not. For example, there are likely to be
|
||||||
|
> alternative bindings for a type like `AuthManager`: one for testing, others
|
||||||
|
> for different authentication/authorization protocols.
|
||||||
|
>
|
||||||
|
> But on the other hand, if the `AuthManager` interface has a method that
|
||||||
|
> returns the currently logged-in user, you might want to [export a provider of
|
||||||
|
> `User` that simply calls `CurrentUser()`] on the `AuthManager`. That
|
||||||
|
> published binding is unlikely to ever need an alternative.
|
||||||
|
>
|
||||||
|
> Once you’ve classified your bindings into [...] bindings with reasonable
|
||||||
|
> alternatives [and] bindings without reasonable alternatives, consider
|
||||||
|
> arranging them into packages like this:
|
||||||
|
>
|
||||||
|
> - One [package] for each [...] binding with a reasonable alternative. (If
|
||||||
|
> you are also writing the alternatives, each one gets its own [package].) That
|
||||||
|
> [package] contains exactly one provider.
|
||||||
|
> - All [...] bindings with no reasonable alternatives go into [packages]
|
||||||
|
> organized along functional lines.
|
||||||
|
> - The [packages] should each include the no-reasonable-alternative [packages] that
|
||||||
|
> require the [...] bindings each provides.
|
||||||
|
|
||||||
|
One goose-specific practice though: create one-off types where in Java you
|
||||||
|
would use a binding annotation.
|
||||||
|
|
||||||
|
## Future Work
|
||||||
|
|
||||||
|
- The names of imports and provider results in the generated code are not
|
||||||
|
actually as nice as shown above. I'd like to make them nicer in the
|
||||||
|
common cases while ensuring uniqueness.
|
||||||
|
- I'd like to support optional and multiple bindings.
|
||||||
|
- At the moment, the entire transitive closure of all dependencies are read
|
||||||
|
for providers. It might be better to have provider imports be opt-in, but
|
||||||
|
that seems like too many levels of magic.
|
||||||
|
- Currently, all dependency satisfaction is done using identity. I'd like to
|
||||||
|
use a limited form of assignability for interface types, but I'm unsure
|
||||||
|
how well this implicit satisfaction will work in practice.
|
||||||
|
- Errors emitted by goose are not very good, but it has all the information
|
||||||
|
it needs to emit better ones.
|
||||||
715
internal/goose/goose.go
Normal file
715
internal/goose/goose.go
Normal file
@@ -0,0 +1,715 @@
|
|||||||
|
// Package goose provides compile-time dependency injection logic as a
|
||||||
|
// Go library.
|
||||||
|
package goose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"go/ast"
|
||||||
|
"go/build"
|
||||||
|
"go/format"
|
||||||
|
"go/parser"
|
||||||
|
"go/token"
|
||||||
|
"go/types"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/tools/go/loader"
|
||||||
|
"golang.org/x/tools/go/types/typeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate performs dependency injection for a single package,
|
||||||
|
// returning the gofmt'd Go source code.
|
||||||
|
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||||
|
// TODO(light): allow errors
|
||||||
|
// TODO(light): stop errors from printing to stderr
|
||||||
|
conf := &loader.Config{
|
||||||
|
Build: new(build.Context),
|
||||||
|
ParserMode: parser.ParseComments,
|
||||||
|
Cwd: wd,
|
||||||
|
}
|
||||||
|
*conf.Build = *bctx
|
||||||
|
n := len(conf.Build.BuildTags)
|
||||||
|
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
|
||||||
|
conf.Import(pkg)
|
||||||
|
prog, err := conf.Load()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load: %v", err)
|
||||||
|
}
|
||||||
|
if len(prog.InitialPackages()) != 1 {
|
||||||
|
// This is more of a violated precondition than anything else.
|
||||||
|
return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))
|
||||||
|
}
|
||||||
|
pkgInfo := prog.InitialPackages()[0]
|
||||||
|
g := newGen(pkgInfo.Pkg.Path())
|
||||||
|
mc := newModuleCache(prog)
|
||||||
|
var directives []directive
|
||||||
|
for _, f := range pkgInfo.Files {
|
||||||
|
if !isInjectFile(f) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fileScope := pkgInfo.Scopes[f]
|
||||||
|
cmap := ast.NewCommentMap(prog.Fset, f, f.Comments)
|
||||||
|
for _, decl := range f.Decls {
|
||||||
|
fn, ok := decl.(*ast.FuncDecl)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
directives = directives[:0]
|
||||||
|
for _, c := range cmap[fn] {
|
||||||
|
directives = extractDirectives(directives, c)
|
||||||
|
}
|
||||||
|
modules := make([]moduleRef, 0, len(directives))
|
||||||
|
for _, d := range directives {
|
||||||
|
if d.kind != "use" {
|
||||||
|
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
|
||||||
|
}
|
||||||
|
ref, err := parseModuleRef(d.line, fileScope, g.currPackage, d.pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err)
|
||||||
|
}
|
||||||
|
modules = append(modules, ref)
|
||||||
|
}
|
||||||
|
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||||
|
if err := g.inject(mc, fn.Name.Name, sig, modules); err != nil {
|
||||||
|
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
goSrc := g.frame(pkgInfo.Pkg.Name())
|
||||||
|
fmtSrc, err := format.Source(goSrc)
|
||||||
|
if err != nil {
|
||||||
|
// This is likely a bug from a poorly generated source file.
|
||||||
|
// Return an error and the unformatted source.
|
||||||
|
return goSrc, err
|
||||||
|
}
|
||||||
|
return fmtSrc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// gen is the generator state.
|
||||||
|
type gen struct {
|
||||||
|
currPackage string
|
||||||
|
buf bytes.Buffer
|
||||||
|
imports map[string]string
|
||||||
|
n int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGen(pkg string) *gen {
|
||||||
|
return &gen{
|
||||||
|
currPackage: pkg,
|
||||||
|
imports: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// frame bakes the built up source body into an unformatted Go source file.
|
||||||
|
func (g *gen) frame(pkgName string) []byte {
|
||||||
|
if g.buf.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.WriteString("// Code generated by goose. DO NOT EDIT.\n\n//+build !gooseinject\n\npackage ")
|
||||||
|
buf.WriteString(pkgName)
|
||||||
|
buf.WriteString("\n\n")
|
||||||
|
if len(g.imports) > 0 {
|
||||||
|
buf.WriteString("import (\n")
|
||||||
|
imps := make([]string, 0, len(g.imports))
|
||||||
|
for path := range g.imports {
|
||||||
|
imps = append(imps, path)
|
||||||
|
}
|
||||||
|
sort.Strings(imps)
|
||||||
|
for _, path := range imps {
|
||||||
|
fmt.Fprintf(&buf, "\t%s %q\n", g.imports[path], path)
|
||||||
|
}
|
||||||
|
buf.WriteString(")\n\n")
|
||||||
|
}
|
||||||
|
buf.Write(g.buf.Bytes())
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// inject emits the code for an injector.
|
||||||
|
func (g *gen) inject(mc *moduleCache, name string, sig *types.Signature, modules []moduleRef) error {
|
||||||
|
results := sig.Results()
|
||||||
|
returnsErr := false
|
||||||
|
switch results.Len() {
|
||||||
|
case 0:
|
||||||
|
return fmt.Errorf("inject %s: no return values", name)
|
||||||
|
case 1:
|
||||||
|
// nothing special
|
||||||
|
case 2:
|
||||||
|
if t := results.At(1).Type(); !types.Identical(t, errorType) {
|
||||||
|
return fmt.Errorf("inject %s: second return type is %s; must be error", name, types.TypeString(t, nil))
|
||||||
|
}
|
||||||
|
returnsErr = true
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("inject %s: too many return values", name)
|
||||||
|
}
|
||||||
|
outType := results.At(0).Type()
|
||||||
|
params := sig.Params()
|
||||||
|
given := make([]types.Type, params.Len())
|
||||||
|
for i := 0; i < params.Len(); i++ {
|
||||||
|
given[i] = params.At(i).Type()
|
||||||
|
}
|
||||||
|
calls, err := solve(mc, outType, given, modules)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for i := range calls {
|
||||||
|
if calls[i].hasErr && !returnsErr {
|
||||||
|
return fmt.Errorf("inject %s: provider for %s returns error but injection not allowed to fail", name, types.TypeString(calls[i].out, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.p("func %s(", name)
|
||||||
|
for i := 0; i < params.Len(); i++ {
|
||||||
|
if i > 0 {
|
||||||
|
g.p(", ")
|
||||||
|
}
|
||||||
|
pi := params.At(i)
|
||||||
|
g.p("%s %s", pi.Name(), types.TypeString(pi.Type(), g.qualifyPkg))
|
||||||
|
}
|
||||||
|
if returnsErr {
|
||||||
|
g.p(") (%s, error) {\n", types.TypeString(outType, g.qualifyPkg))
|
||||||
|
} else {
|
||||||
|
g.p(") %s {\n", types.TypeString(outType, g.qualifyPkg))
|
||||||
|
}
|
||||||
|
zv := zeroValue(outType, g.qualifyPkg)
|
||||||
|
for i := range calls {
|
||||||
|
c := &calls[i]
|
||||||
|
g.p("\tv%d", i)
|
||||||
|
if c.hasErr {
|
||||||
|
g.p(", err")
|
||||||
|
}
|
||||||
|
g.p(" := %s(", g.qualifiedID(c.importPath, c.funcName))
|
||||||
|
for j, a := range c.args {
|
||||||
|
if j > 0 {
|
||||||
|
g.p(", ")
|
||||||
|
}
|
||||||
|
if a < params.Len() {
|
||||||
|
g.p("%s", params.At(a).Name())
|
||||||
|
} else {
|
||||||
|
g.p("v%d", a-params.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.p(")\n")
|
||||||
|
if c.hasErr {
|
||||||
|
g.p("\tif err != nil {\n")
|
||||||
|
// TODO(light): give information about failing provider
|
||||||
|
g.p("\t\treturn %s, err\n", zv)
|
||||||
|
g.p("\t}\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(calls) == 0 {
|
||||||
|
for i := range given {
|
||||||
|
if types.Identical(outType, given[i]) {
|
||||||
|
g.p("\treturn %s", params.At(i).Name())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
g.p("\treturn v%d", len(calls)-1)
|
||||||
|
}
|
||||||
|
if returnsErr {
|
||||||
|
g.p(", nil")
|
||||||
|
}
|
||||||
|
g.p("\n}\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gen) qualifiedID(path, sym string) string {
|
||||||
|
name := g.qualifyImport(path)
|
||||||
|
if name == "" {
|
||||||
|
return sym
|
||||||
|
}
|
||||||
|
return name + "." + sym
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gen) qualifyImport(path string) string {
|
||||||
|
if path == g.currPackage {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if name := g.imports[path]; name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
name := fmt.Sprintf("pkg%d", g.n)
|
||||||
|
g.n++
|
||||||
|
g.imports[path] = name
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gen) qualifyPkg(pkg *types.Package) string {
|
||||||
|
return g.qualifyImport(pkg.Path())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gen) p(format string, args ...interface{}) {
|
||||||
|
fmt.Fprintf(&g.buf, format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A module describes a set of providers. The zero value is an empty
|
||||||
|
// module.
|
||||||
|
type module struct {
|
||||||
|
providers []*providerInfo
|
||||||
|
imports []moduleImport
|
||||||
|
}
|
||||||
|
|
||||||
|
type moduleImport struct {
|
||||||
|
moduleRef
|
||||||
|
pos token.Pos
|
||||||
|
}
|
||||||
|
|
||||||
|
const implicitModuleName = "Module"
|
||||||
|
|
||||||
|
// findModules processes a package and extracts the modules declared in it.
|
||||||
|
func findModules(fset *token.FileSet, pkg *types.Package, typeInfo *types.Info, files []*ast.File) (map[string]*module, error) {
|
||||||
|
modules := make(map[string]*module)
|
||||||
|
var directives []directive
|
||||||
|
for _, f := range files {
|
||||||
|
fileScope := typeInfo.Scopes[f]
|
||||||
|
for _, c := range f.Comments {
|
||||||
|
directives = extractDirectives(directives[:0], c)
|
||||||
|
for _, d := range directives {
|
||||||
|
switch d.kind {
|
||||||
|
case "provide", "use":
|
||||||
|
// handled later
|
||||||
|
case "import":
|
||||||
|
if fileScope == nil {
|
||||||
|
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fset.File(f.Pos()).Name())
|
||||||
|
}
|
||||||
|
var name, spec string
|
||||||
|
if strings.HasPrefix(d.line, `"`) {
|
||||||
|
name, spec = implicitModuleName, d.line
|
||||||
|
} else if i := strings.IndexByte(d.line, ' '); i != -1 {
|
||||||
|
name, spec = d.line[:i], d.line[i+1:]
|
||||||
|
} else {
|
||||||
|
name, spec = implicitModuleName, d.line
|
||||||
|
}
|
||||||
|
ref, err := parseModuleRef(spec, fileScope, pkg.Path(), d.pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%v: %v", fset.Position(d.pos), err)
|
||||||
|
}
|
||||||
|
if ref.importPath != pkg.Path() {
|
||||||
|
imported := false
|
||||||
|
for _, imp := range pkg.Imports() {
|
||||||
|
if ref.importPath == imp.Path() {
|
||||||
|
imported = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !imported {
|
||||||
|
return nil, fmt.Errorf("%v: module %s imports %q which is not in the package's imports", fset.Position(d.pos), name, ref.importPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mod := modules[name]; mod != nil {
|
||||||
|
found := false
|
||||||
|
for _, other := range mod.imports {
|
||||||
|
if ref == other.moduleRef {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mod.imports = append(mod.imports, moduleImport{moduleRef: ref, pos: d.pos})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
modules[name] = &module{
|
||||||
|
imports: []moduleImport{{moduleRef: ref, pos: d.pos}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%v: unknown directive %s", fset.Position(d.pos), d.kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmap := ast.NewCommentMap(fset, f, f.Comments)
|
||||||
|
for _, decl := range f.Decls {
|
||||||
|
directives = directives[:0]
|
||||||
|
for _, cg := range cmap[decl] {
|
||||||
|
directives = extractDirectives(directives, cg)
|
||||||
|
}
|
||||||
|
fn, isFunction := decl.(*ast.FuncDecl)
|
||||||
|
var providerModule string
|
||||||
|
for _, d := range directives {
|
||||||
|
if d.kind != "provide" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if providerModule != "" {
|
||||||
|
return nil, fmt.Errorf("%v: multiple provide directives for %s", fset.Position(d.pos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
if !isFunction {
|
||||||
|
return nil, fmt.Errorf("%v: only functions can be marked as providers", fset.Position(d.pos))
|
||||||
|
}
|
||||||
|
if d.line == "" {
|
||||||
|
providerModule = implicitModuleName
|
||||||
|
} else {
|
||||||
|
// TODO(light): validate identifier
|
||||||
|
providerModule = d.line
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if providerModule == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fpos := fn.Pos()
|
||||||
|
sig := typeInfo.ObjectOf(fn.Name).Type().(*types.Signature)
|
||||||
|
r := sig.Results()
|
||||||
|
var hasErr bool
|
||||||
|
switch r.Len() {
|
||||||
|
case 1:
|
||||||
|
hasErr = false
|
||||||
|
case 2:
|
||||||
|
if t := r.At(1).Type(); !types.Identical(t, errorType) {
|
||||||
|
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error", fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
hasErr = true
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name.Name)
|
||||||
|
}
|
||||||
|
out := r.At(0).Type()
|
||||||
|
p := sig.Params()
|
||||||
|
provider := &providerInfo{
|
||||||
|
importPath: pkg.Path(),
|
||||||
|
funcName: fn.Name.Name,
|
||||||
|
pos: fn.Pos(),
|
||||||
|
args: make([]types.Type, p.Len()),
|
||||||
|
out: out,
|
||||||
|
hasErr: hasErr,
|
||||||
|
}
|
||||||
|
for i := 0; i < p.Len(); i++ {
|
||||||
|
provider.args[i] = p.At(i).Type()
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
if types.Identical(provider.args[i], provider.args[j]) {
|
||||||
|
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.args[j], nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mod := modules[providerModule]; mod != nil {
|
||||||
|
for _, other := range mod.providers {
|
||||||
|
if types.Identical(other.out, provider.out) {
|
||||||
|
return nil, fmt.Errorf("%v: module %s has multiple providers for %s (previous declaration at %v)", fset.Position(fpos), providerModule, types.TypeString(provider.out, nil), fset.Position(other.pos))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod.providers = append(mod.providers, provider)
|
||||||
|
} else {
|
||||||
|
modules[providerModule] = &module{
|
||||||
|
providers: []*providerInfo{provider},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// moduleCache is a lazily evaluated index of modules.
|
||||||
|
type moduleCache struct {
|
||||||
|
modules map[string]map[string]*module
|
||||||
|
fset *token.FileSet
|
||||||
|
prog *loader.Program
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModuleCache(prog *loader.Program) *moduleCache {
|
||||||
|
return &moduleCache{
|
||||||
|
fset: prog.Fset,
|
||||||
|
prog: prog,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *moduleCache) get(ref moduleRef) (*module, error) {
|
||||||
|
if mods, cached := mc.modules[ref.importPath]; cached {
|
||||||
|
mod := mods[ref.moduleName]
|
||||||
|
if mod == nil {
|
||||||
|
return nil, fmt.Errorf("no such module %s in package %q", ref.moduleName, ref.importPath)
|
||||||
|
}
|
||||||
|
return mod, nil
|
||||||
|
}
|
||||||
|
if mc.modules == nil {
|
||||||
|
mc.modules = make(map[string]map[string]*module)
|
||||||
|
}
|
||||||
|
pkg, info, files, err := mc.getpkg(ref.importPath)
|
||||||
|
if err != nil {
|
||||||
|
mc.modules[ref.importPath] = nil
|
||||||
|
return nil, fmt.Errorf("analyze package: %v", err)
|
||||||
|
}
|
||||||
|
mods, err := findModules(mc.fset, pkg, info, files)
|
||||||
|
if err != nil {
|
||||||
|
mc.modules[ref.importPath] = nil
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mc.modules[ref.importPath] = mods
|
||||||
|
mod := mods[ref.moduleName]
|
||||||
|
if mod == nil {
|
||||||
|
return nil, fmt.Errorf("no such module %s in package %q", ref.moduleName, ref.importPath)
|
||||||
|
}
|
||||||
|
return mod, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *moduleCache) getpkg(path string) (*types.Package, *types.Info, []*ast.File, error) {
|
||||||
|
// TODO(light): allow other implementations for testing
|
||||||
|
|
||||||
|
pkg := mc.prog.Package(path)
|
||||||
|
if pkg == nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("package %q not found", path)
|
||||||
|
}
|
||||||
|
return pkg.Pkg, &pkg.Info, pkg.Files, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// solve finds the sequence of calls required to produce an output type
|
||||||
|
// with an optional set of provided inputs.
|
||||||
|
func solve(mc *moduleCache, out types.Type, given []types.Type, modules []moduleRef) ([]call, error) {
|
||||||
|
for i, g := range given {
|
||||||
|
for _, h := range given[:i] {
|
||||||
|
if types.Identical(g, h) {
|
||||||
|
return nil, fmt.Errorf("multiple inputs of the same type %s", types.TypeString(g, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
providers, err := buildProviderMap(mc, modules)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start building the mapping of type to local variable of the given type.
|
||||||
|
// The first len(given) local variables are the given types.
|
||||||
|
index := new(typeutil.Map)
|
||||||
|
for i, g := range given {
|
||||||
|
if p := providers.At(g); p != nil {
|
||||||
|
pp := p.(*providerInfo)
|
||||||
|
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.funcName, mc.fset.Position(pp.pos))
|
||||||
|
}
|
||||||
|
index.Set(g, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topological sort of the directed graph defined by the providers
|
||||||
|
// using a depth-first search. The graph may contain cycles, which
|
||||||
|
// should trigger an error.
|
||||||
|
var calls []call
|
||||||
|
var visit func(trail []types.Type) error
|
||||||
|
visit = func(trail []types.Type) error {
|
||||||
|
typ := trail[len(trail)-1]
|
||||||
|
if index.At(typ) != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, t := range trail[:len(trail)-1] {
|
||||||
|
if types.Identical(typ, t) {
|
||||||
|
// TODO(light): describe cycle
|
||||||
|
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p, _ := providers.At(typ).(*providerInfo)
|
||||||
|
if p == nil {
|
||||||
|
if len(trail) == 1 {
|
||||||
|
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
||||||
|
}
|
||||||
|
// TODO(light): give name of provider
|
||||||
|
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2], nil))
|
||||||
|
}
|
||||||
|
for _, a := range p.args {
|
||||||
|
// TODO(light): this will discard grown trail arrays.
|
||||||
|
if err := visit(append(trail, a)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args := make([]int, len(p.args))
|
||||||
|
for i := range p.args {
|
||||||
|
args[i] = index.At(p.args[i]).(int)
|
||||||
|
}
|
||||||
|
index.Set(typ, len(given)+len(calls))
|
||||||
|
calls = append(calls, call{
|
||||||
|
importPath: p.importPath,
|
||||||
|
funcName: p.funcName,
|
||||||
|
args: args,
|
||||||
|
out: typ,
|
||||||
|
hasErr: p.hasErr,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := visit([]types.Type{out}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildProviderMap(mc *moduleCache, modules []moduleRef) (*typeutil.Map, error) {
|
||||||
|
type nextEnt struct {
|
||||||
|
to moduleRef
|
||||||
|
|
||||||
|
from moduleRef
|
||||||
|
pos token.Pos
|
||||||
|
}
|
||||||
|
|
||||||
|
pm := new(typeutil.Map) // to *providerInfo
|
||||||
|
visited := make(map[moduleRef]struct{})
|
||||||
|
var next []nextEnt
|
||||||
|
for _, ref := range modules {
|
||||||
|
next = append(next, nextEnt{to: ref})
|
||||||
|
}
|
||||||
|
for len(next) > 0 {
|
||||||
|
curr := next[0]
|
||||||
|
copy(next, next[1:])
|
||||||
|
next = next[:len(next)-1]
|
||||||
|
if _, skip := visited[curr.to]; skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
visited[curr.to] = struct{}{}
|
||||||
|
mod, err := mc.get(curr.to)
|
||||||
|
if err != nil {
|
||||||
|
if !curr.pos.IsValid() {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
|
||||||
|
}
|
||||||
|
for _, p := range mod.providers {
|
||||||
|
if prev := pm.At(p.out); prev != nil {
|
||||||
|
pos := mc.fset.Position(p.pos)
|
||||||
|
typ := types.TypeString(p.out, nil)
|
||||||
|
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
||||||
|
if curr.from.importPath != "" {
|
||||||
|
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
|
||||||
|
}
|
||||||
|
pm.Set(p.out, p)
|
||||||
|
}
|
||||||
|
for _, imp := range mod.imports {
|
||||||
|
next = append(next, nextEnt{to: imp.moduleRef, from: curr.to, pos: imp.pos})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A call represents a step of an injector function.
|
||||||
|
type call struct {
|
||||||
|
// importPath and funcName identify the provider function to call.
|
||||||
|
importPath string
|
||||||
|
funcName string
|
||||||
|
|
||||||
|
// args is a list of arguments to call the provider with. Each element is either:
|
||||||
|
// a) one of the givens (args[i] < len(given)) or
|
||||||
|
// b) the result of a previous provider call (args[i] >= len(given)).
|
||||||
|
args []int
|
||||||
|
|
||||||
|
// out is the type produced by this provider call.
|
||||||
|
out types.Type
|
||||||
|
|
||||||
|
// hasErr is true if the provider call returns an error.
|
||||||
|
hasErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerInfo records the signature of a provider function.
|
||||||
|
type providerInfo struct {
|
||||||
|
importPath string
|
||||||
|
funcName string
|
||||||
|
pos token.Pos
|
||||||
|
args []types.Type
|
||||||
|
out types.Type
|
||||||
|
hasErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// A moduleRef is a parsed reference to a collection of providers.
|
||||||
|
type moduleRef struct {
|
||||||
|
importPath string
|
||||||
|
moduleName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseModuleRef(ref string, s *types.Scope, pkg string, pos token.Pos) (moduleRef, error) {
|
||||||
|
// TODO(light): verify that module name is an identifier before returning
|
||||||
|
|
||||||
|
i := strings.LastIndexByte(ref, '.')
|
||||||
|
if i == -1 {
|
||||||
|
return moduleRef{importPath: pkg, moduleName: ref}, nil
|
||||||
|
}
|
||||||
|
imp, name := ref[:i], ref[i+1:]
|
||||||
|
if strings.HasPrefix(imp, `"`) {
|
||||||
|
path, err := strconv.Unquote(imp)
|
||||||
|
if err != nil {
|
||||||
|
return moduleRef{}, fmt.Errorf("parse module reference %q: bad import path", ref)
|
||||||
|
}
|
||||||
|
return moduleRef{importPath: path, moduleName: name}, nil
|
||||||
|
}
|
||||||
|
_, obj := s.LookupParent(imp, pos)
|
||||||
|
if obj == nil {
|
||||||
|
return moduleRef{}, fmt.Errorf("parse module reference %q: unknown identifier %s", ref, imp)
|
||||||
|
}
|
||||||
|
pn, ok := obj.(*types.PkgName)
|
||||||
|
if !ok {
|
||||||
|
return moduleRef{}, fmt.Errorf("parse module reference %q: %s does not name a package", ref, imp)
|
||||||
|
}
|
||||||
|
return moduleRef{importPath: pn.Imported().Path(), moduleName: name}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ref moduleRef) String() string {
|
||||||
|
return strconv.Quote(ref.importPath) + "." + ref.moduleName
|
||||||
|
}
|
||||||
|
|
||||||
|
type directive struct {
|
||||||
|
pos token.Pos
|
||||||
|
kind string
|
||||||
|
line string
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
|
||||||
|
const prefix = "goose:"
|
||||||
|
text := cg.Text()
|
||||||
|
for len(text) > 0 {
|
||||||
|
text = strings.TrimLeft(text, " \t\r\n")
|
||||||
|
if !strings.HasPrefix(text, prefix) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
line := text[len(prefix):]
|
||||||
|
if i := strings.IndexByte(line, '\n'); i != -1 {
|
||||||
|
line, text = line[:i], line[i+1:]
|
||||||
|
} else {
|
||||||
|
text = ""
|
||||||
|
}
|
||||||
|
if i := strings.IndexByte(line, ' '); i != -1 {
|
||||||
|
d = append(d, directive{
|
||||||
|
kind: line[:i],
|
||||||
|
line: strings.TrimSpace(line[i+1:]),
|
||||||
|
pos: cg.Pos(), // TODO(light): more precise position
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
d = append(d, directive{
|
||||||
|
kind: line,
|
||||||
|
pos: cg.Pos(), // TODO(light): more precise position
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// isInjectFile reports whether a given file is an injection template.
|
||||||
|
func isInjectFile(f *ast.File) bool {
|
||||||
|
// TODO(light): better determination
|
||||||
|
for _, cg := range f.Comments {
|
||||||
|
text := cg.Text()
|
||||||
|
if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// zeroValue returns the shortest expression that evaluates to the zero
|
||||||
|
// value for the given type.
|
||||||
|
func zeroValue(t types.Type, qf types.Qualifier) string {
|
||||||
|
switch u := t.Underlying().(type) {
|
||||||
|
case *types.Array, *types.Struct:
|
||||||
|
return types.TypeString(t, qf) + "{}"
|
||||||
|
case *types.Basic:
|
||||||
|
info := u.Info()
|
||||||
|
switch {
|
||||||
|
case info&types.IsBoolean != 0:
|
||||||
|
return "false"
|
||||||
|
case info&(types.IsInteger|types.IsFloat|types.IsComplex) != 0:
|
||||||
|
return "0"
|
||||||
|
case info&types.IsString != 0:
|
||||||
|
return `""`
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
|
||||||
|
return "nil"
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorType = types.Universe.Lookup("error").Type()
|
||||||
406
internal/goose/goose_test.go
Normal file
406
internal/goose/goose_test.go
Normal file
@@ -0,0 +1,406 @@
|
|||||||
|
package goose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"go/build"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO(light): pull this out into a testdata directory
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
files map[string]string
|
||||||
|
pkg string
|
||||||
|
wantOutput string
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No-op build",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main; import "fmt"; func main() { fmt.Println("Hello, World!"); }`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantOutput: "Hello, World!\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Niladic identity provider",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main
|
||||||
|
import "fmt"
|
||||||
|
func main() { fmt.Println(injectedMessage()); }
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
|
||||||
|
// provideMessage provides a friendly user greeting.
|
||||||
|
func provideMessage() string { return "Hello, World!"; }
|
||||||
|
`,
|
||||||
|
"foo/foo_goose.go": `//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Module
|
||||||
|
|
||||||
|
func injectedMessage() string
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantOutput: "Hello, World!\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing use",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main
|
||||||
|
import "fmt"
|
||||||
|
func main() { fmt.Println(injectedMessage()); }
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
|
||||||
|
// provideMessage provides a friendly user greeting.
|
||||||
|
func provideMessage() string { return "Hello, World!"; }
|
||||||
|
`,
|
||||||
|
"foo/foo_goose.go": `//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
func injectedMessage() string
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Chain",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main
|
||||||
|
import "fmt"
|
||||||
|
func main() {
|
||||||
|
fmt.Println(injectFooBar())
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
type FooBar int
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideFoo() Foo { return 41 }
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideFooBar(foo Foo) FooBar { return FooBar(foo) + 1 }
|
||||||
|
`,
|
||||||
|
"foo/foo_goose.go": `//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Module
|
||||||
|
|
||||||
|
func injectFooBar() FooBar
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantOutput: "42\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Two deps",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main
|
||||||
|
import "fmt"
|
||||||
|
func main() {
|
||||||
|
fmt.Println(injectFooBar())
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
type Bar int
|
||||||
|
type FooBar int
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideFoo() Foo { return 40 }
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideBar() Bar { return 2 }
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) }
|
||||||
|
`,
|
||||||
|
"foo/foo_goose.go": `//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Module
|
||||||
|
|
||||||
|
func injectFooBar() FooBar
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantOutput: "42\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Inject input",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main
|
||||||
|
import "fmt"
|
||||||
|
func main() {
|
||||||
|
fmt.Println(injectFooBar(40))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
type Bar int
|
||||||
|
type FooBar int
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideBar() Bar { return 2 }
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) }
|
||||||
|
`,
|
||||||
|
"foo/foo_goose.go": `//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Module
|
||||||
|
|
||||||
|
func injectFooBar(foo Foo) FooBar
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantOutput: "42\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Inject input conflict",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main
|
||||||
|
import "fmt"
|
||||||
|
func main() {
|
||||||
|
fmt.Println(injectBar(40))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
type Bar int
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideFoo() Foo { return -888 }
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideBar(foo Foo) Bar { return 2 }
|
||||||
|
`,
|
||||||
|
"foo/foo_goose.go": `//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Module
|
||||||
|
|
||||||
|
func injectBar(foo Foo) Bar
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Return error",
|
||||||
|
files: map[string]string{
|
||||||
|
"foo/foo.go": `package main
|
||||||
|
import "errors"
|
||||||
|
import "fmt"
|
||||||
|
import "strings"
|
||||||
|
func main() {
|
||||||
|
foo, err := injectFoo()
|
||||||
|
fmt.Println(foo)
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println("<nil>")
|
||||||
|
} else {
|
||||||
|
fmt.Println(strings.Contains(err.Error(), "there is no Foo"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Foo int
|
||||||
|
|
||||||
|
//goose:provide
|
||||||
|
func provideFoo() (Foo, error) { return 42, errors.New("there is no Foo") }
|
||||||
|
`,
|
||||||
|
"foo/foo_goose.go": `//+build gooseinject
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
//goose:use Module
|
||||||
|
|
||||||
|
func injectFoo() (Foo, error)
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
pkg: "foo",
|
||||||
|
wantOutput: "0\ntrue\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratedCode(t *testing.T) {
|
||||||
|
if _, err := os.Stat(filepath.Join(build.Default.GOROOT, "bin", "go")); err != nil {
|
||||||
|
t.Fatalf("go toolchain not available: %v", err)
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
gopath, err := ioutil.TempDir("", "goose_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(gopath)
|
||||||
|
bctx := &build.Context{
|
||||||
|
GOARCH: build.Default.GOARCH,
|
||||||
|
GOOS: build.Default.GOOS,
|
||||||
|
GOROOT: build.Default.GOROOT,
|
||||||
|
GOPATH: gopath,
|
||||||
|
CgoEnabled: build.Default.CgoEnabled,
|
||||||
|
Compiler: build.Default.Compiler,
|
||||||
|
ReleaseTags: build.Default.ReleaseTags,
|
||||||
|
}
|
||||||
|
for name, content := range test.files {
|
||||||
|
p := filepath.Join(gopath, "src", filepath.FromSlash(name))
|
||||||
|
if err := os.MkdirAll(filepath.Dir(p), 0777); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := ioutil.WriteFile(p, []byte(content), 0666); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gen, err := Generate(bctx, gopath, test.pkg)
|
||||||
|
if len(gen) > 0 {
|
||||||
|
defer t.Logf("goose_gen.go:\n%s", gen)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if !test.wantError {
|
||||||
|
t.Fatalf("goose: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && test.wantError {
|
||||||
|
t.Fatal("goose succeeded; want error")
|
||||||
|
}
|
||||||
|
if len(gen) > 0 {
|
||||||
|
genPath := filepath.Join(gopath, "src", filepath.FromSlash(test.pkg), "goose_gen.go")
|
||||||
|
if err := ioutil.WriteFile(genPath, gen, 0666); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
testExePath := filepath.Join(gopath, "bin", "testprog")
|
||||||
|
if err := runGo(bctx, "build", "-o", testExePath, test.pkg); err != nil {
|
||||||
|
t.Fatal("build:", err)
|
||||||
|
}
|
||||||
|
out, err := exec.Command(testExePath).Output()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("run compiled program:", err)
|
||||||
|
}
|
||||||
|
if string(out) != test.wantOutput {
|
||||||
|
t.Errorf("compiled program output = %q; want %q", out, test.wantOutput)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterminism(t *testing.T) {
|
||||||
|
runs := 10
|
||||||
|
if testing.Short() {
|
||||||
|
runs = 3
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
if test.wantError {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
gopath, err := ioutil.TempDir("", "goose_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(gopath)
|
||||||
|
bctx := &build.Context{
|
||||||
|
GOARCH: build.Default.GOARCH,
|
||||||
|
GOOS: build.Default.GOOS,
|
||||||
|
GOROOT: build.Default.GOROOT,
|
||||||
|
GOPATH: gopath,
|
||||||
|
CgoEnabled: build.Default.CgoEnabled,
|
||||||
|
Compiler: build.Default.Compiler,
|
||||||
|
ReleaseTags: build.Default.ReleaseTags,
|
||||||
|
}
|
||||||
|
for name, content := range test.files {
|
||||||
|
p := filepath.Join(gopath, "src", filepath.FromSlash(name))
|
||||||
|
if err := os.MkdirAll(filepath.Dir(p), 0777); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := ioutil.WriteFile(p, []byte(content), 0666); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gold, err := Generate(bctx, gopath, test.pkg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("goose:", err)
|
||||||
|
}
|
||||||
|
goldstr := string(gold)
|
||||||
|
for i := 0; i < runs-1; i++ {
|
||||||
|
out, err := Generate(bctx, gopath, test.pkg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("goose (on subsequent run):", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gold, out) {
|
||||||
|
t.Fatalf("goose output differs when run repeatedly on same input:\n%s", diff(goldstr, string(out)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runGo(bctx *build.Context, args ...string) error {
|
||||||
|
exe := filepath.Join(bctx.GOROOT, "bin", "go")
|
||||||
|
c := exec.Command(exe, args...)
|
||||||
|
c.Env = append(os.Environ(), "GOROOT="+bctx.GOROOT, "GOARCH="+bctx.GOARCH, "GOOS="+bctx.GOOS, "GOPATH="+bctx.GOPATH)
|
||||||
|
if bctx.CgoEnabled {
|
||||||
|
c.Env = append(c.Env, "CGO_ENABLED=1")
|
||||||
|
} else {
|
||||||
|
c.Env = append(c.Env, "CGO_ENABLED=0")
|
||||||
|
}
|
||||||
|
// TODO(someday): set -compiler flag if needed.
|
||||||
|
out, err := c.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
if len(out) > 0 {
|
||||||
|
return fmt.Errorf("%v; output:\n%s", err, out)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func diff(want, got string) string {
|
||||||
|
d, err := runDiff([]byte(want), []byte(got))
|
||||||
|
if err == nil {
|
||||||
|
return string(d)
|
||||||
|
}
|
||||||
|
return "*** got:\n" + got + "\n\n*** want:\n" + want
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDiff(a, b []byte) ([]byte, error) {
|
||||||
|
fa, err := ioutil.TempFile("", "goose_test_diff")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
os.Remove(fa.Name())
|
||||||
|
fa.Close()
|
||||||
|
}()
|
||||||
|
fb, err := ioutil.TempFile("", "goose_test_diff")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
os.Remove(fb.Name())
|
||||||
|
fb.Close()
|
||||||
|
}()
|
||||||
|
if _, err := fa.Write(a); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := fb.Write(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c := exec.Command("diff", "-u", fa.Name(), fb.Name())
|
||||||
|
out, err := c.Output()
|
||||||
|
return out, err
|
||||||
|
}
|
||||||
52
main.go
Normal file
52
main.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// goose is a compile-time dependency injection tool.
|
||||||
|
//
|
||||||
|
// See README.md for an overview.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"go/build"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"codename/goose/internal/goose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var pkg string
|
||||||
|
switch len(os.Args) {
|
||||||
|
case 1:
|
||||||
|
pkg = "."
|
||||||
|
case 2:
|
||||||
|
pkg = os.Args[1]
|
||||||
|
default:
|
||||||
|
fmt.Fprintln(os.Stderr, "goose: usage: goose [PKG]")
|
||||||
|
os.Exit(64)
|
||||||
|
}
|
||||||
|
wd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
out, err := goose.Generate(&build.Default, wd, pkg)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
// No Goose directives, don't write anything.
|
||||||
|
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
|
||||||
|
if err := ioutil.WriteFile(p, out, 0666); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user