goose: use marker functions instead of comments

To avoid making this CL too large, I did not migrate the existing goose
comments through the repository.  This will be addressed in a subsequent
CL.

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-27 13:44:54 -04:00
parent 13698e656a
commit f8e446fa17
59 changed files with 713 additions and 983 deletions

113
README.md
View File

@@ -13,34 +13,27 @@ might have hand-written.
### Defining Providers ### Defining Providers
The primary mechanism in goose is the **provider**: a function that can The primary mechanism in goose is the **provider**: a function that can
produce a value, annotated with the special `goose:provide` directive. These produce a value. These functions are ordinary Go code.
functions are otherwise ordinary Go code.
```go ```go
package foobarbaz package foobarbaz
type Foo int type Foo int
// goose:provide
// ProvideFoo returns a Foo. // ProvideFoo returns a Foo.
func ProvideFoo() Foo { func ProvideFoo() Foo {
return 42 return 42
} }
``` ```
Providers are always part of a **provider set**: if there is no provider set
named on the `//goose:provide` line, then the provider is added to the provider
set with the same name as the function (`ProvideFoo`, in this case).
Providers can specify dependencies with parameters: Providers can specify dependencies with parameters:
```go ```go
package foobarbaz package foobarbaz
type Bar int // ...
// goose:provide SuperSet type Bar int
// ProvideBar returns a Bar: a negative Foo. // ProvideBar returns a Bar: a negative Foo.
func ProvideBar(foo Foo) Bar { func ProvideBar(foo Foo) Bar {
@@ -58,9 +51,9 @@ import (
"errors" "errors"
) )
type Baz int // ...
// goose:provide SuperSet type Baz int
// ProvideBaz returns a value if Bar is not zero. // ProvideBaz returns a value if Bar is not zero.
func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) { func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) {
@@ -71,23 +64,36 @@ func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) {
} }
``` ```
Provider sets can import other provider sets. To add the `ProvideFoo` set to Providers can be grouped in **provider sets**. To add these providers to a new
`SuperSet`: set called `SuperSet`, use the `goose.NewSet` function:
```go ```go
// goose:import SuperSet ProvideFoo package foobarbaz
import (
// ...
"codename/goose"
)
// ...
var SuperSet = goose.NewSet(ProvideFoo, ProvideBar, ProvideBaz)
``` ```
You can also import provider sets in another package, provided that you have a You can also add other provider sets into a provider set.
Go import for the package:
```go ```go
// goose:import SuperSet "example.com/some/other/pkg".OtherSet package foobarbaz
```
A provider set reference is an optional import qualifier (either a package name import (
or a quoted import path, as seen above) ending with a dot, followed by the // ...
provider set name. "example.com/some/other/pkg"
)
// ...
var MegaSet = goose.NewSet(SuperSet, pkg.OtherSet)
```
### Injectors ### Injectors
@@ -95,32 +101,34 @@ An application wires up these providers with an **injector**: a function that
calls providers in dependency order. With goose, you write the injector's calls providers in dependency order. With goose, you write the injector's
signature, then goose generates the function's body. signature, then goose generates the function's body.
An injector is declared by writing a function declaration without a body in a An injector is declared by writing a function declaration whose body is a call
file guarded by a `gooseinject` build tag. Let's say that the above providers to `panic()` with a call to `goose.Use` as its argument. Let's say that the
were defined in a package called `example.com/foobarbaz`. The following would above providers were defined in a package called `example.com/foobarbaz`. The
declare an injector to obtain a `Baz`: following would declare an injector to obtain a `Baz`:
```go ```go
//+build gooseinject // +build gooseinject
// ^ build tag makes sure the stub is not built in the final build
package main package main
import ( import (
"context" "context"
"codename/goose"
"example.com/foobarbaz" "example.com/foobarbaz"
) )
// goose:use foobarbaz.SuperSet func initializeApp(ctx context.Context) (foobarbaz.Baz, error) {
panic(goose.Use(foobarbaz.MegaSet))
func initializeApp(ctx context.Context) (foobarbaz.Baz, error) }
``` ```
Like providers, injectors can be parameterized on inputs (which then get sent to Like providers, injectors can be parameterized on inputs (which then get sent to
providers) and can return errors. Each `goose:use` directive specifies a providers) and can return errors. Arguments to `goose.Use` are the same as
provider set to use in the injection. An injector can have one or more `goose.NewSet`: they form a provider set. This is the provider set that gets
`goose:use` directives. `goose:use` directives use the same syntax as used during code generation for that injector.
`goose:import` to reference provider sets.
You can generate the injector by invoking goose in the package directory: You can generate the injector by invoking goose in the package directory:
@@ -164,7 +172,7 @@ func initializeApp(ctx context.Context) (foobarbaz.Baz, error) {
``` ```
As you can see, the output is very close to what a developer would write As you can see, the output is very close to what a developer would write
themselves. Further, there is no dependency on goose at runtime: all of the themselves. Further, there is little dependency on goose at runtime: all of the
written code is just normal Go code, and can be used without goose. written code is just normal Go code, and can be used without goose.
[`go generate`]: https://blog.golang.org/generate [`go generate`]: https://blog.golang.org/generate
@@ -228,19 +236,21 @@ func (b *Bar) Foo() string {
return string(*b) return string(*b)
} }
//goose:provide BarFooer func ProvideBar() *Bar {
func provideBar() *Bar {
b := new(Bar) b := new(Bar)
*b = "Hello, World!" *b = "Hello, World!"
return b return b
} }
//goose:bind BarFooer Fooer *Bar var BarFooer = goose.NewSet(
ProvideBar,
goose.Bind(Fooer(nil), (*Bar)(nil)))
``` ```
The syntax is provider set name, interface type, and finally the concrete type. The first argument to `goose.Bind` is a nil value for the interface type and the
An interface binding does not necessarily need to have a provider in the same second argument is a zero value of the concrete type. An interface binding does
set that provides the concrete type. not necessarily need to have a provider in the same set that provides the
concrete type.
[type identity]: https://golang.org/ref/spec#Type_identity [type identity]: https://golang.org/ref/spec#Type_identity
[return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces [return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces
@@ -256,32 +266,31 @@ following providers:
type Foo int type Foo int
type Bar int type Bar int
//goose:provide Foo func ProvideFoo() Foo {
func provideFoo() Foo {
// ... // ...
} }
//goose:provide Bar func ProvideBar() Bar {
func provideBar() Bar {
// ... // ...
} }
//goose:provide
type FooBar struct { type FooBar struct {
Foo Foo Foo Foo
Bar Bar Bar Bar
} }
var Set = goose.NewSet(
ProvideFoo,
ProvideBar,
FooBar{})
``` ```
A generated injector for `FooBar` would look like this: A generated injector for `FooBar` would look like this:
```go ```go
func injectFooBar() FooBar { func injectFooBar() FooBar {
foo := provideFoo() foo := ProvideFoo()
bar := provideBar() bar := ProvideBar()
fooBar := FooBar{ fooBar := FooBar{
Foo: foo, Foo: foo,
Bar: bar, Bar: bar,
@@ -300,8 +309,6 @@ this to either return an aggregated cleanup function to the caller or to clean
up the resource if a later provider returns an error. up the resource if a later provider returns an error.
```go ```go
//goose:provide
func provideFile(log Logger, path Path) (*os.File, func(), error) { func provideFile(log Logger, path Path) (*os.File, func(), error) {
f, err := os.Open(string(path)) f, err := os.Open(string(path))
if err != nil { if err != nil {

View File

@@ -13,6 +13,7 @@ import (
"path/filepath" "path/filepath"
"reflect" "reflect"
"sort" "sort"
"strconv"
"strings" "strings"
"codename/goose/internal/goose" "codename/goose/internal/goose"
@@ -71,9 +72,9 @@ func generate(pkg string) error {
// show runs the show subcommand. // show runs the show subcommand.
// //
// Given one or more packages, show will find all the declared provider // Given one or more packages, show will find all the provider sets
// sets and print what other provider sets it imports and what outputs // declared as top-level variables and print what other provider sets it
// it can produce, given possible inputs. // imports and what outputs it can produce, given possible inputs.
func show(pkgs ...string) error { func show(pkgs ...string) error {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
@@ -89,11 +90,12 @@ func show(pkgs ...string) error {
} }
sort.Slice(keys, func(i, j int) bool { sort.Slice(keys, func(i, j int) bool {
if keys[i].ImportPath == keys[j].ImportPath { if keys[i].ImportPath == keys[j].ImportPath {
return keys[i].Name < keys[j].Name return keys[i].VarName < keys[j].VarName
} }
return keys[i].ImportPath < keys[j].ImportPath return keys[i].ImportPath < keys[j].ImportPath
}) })
// ANSI color codes. // ANSI color codes.
// TODO(light): Possibly use github.com/fatih/color?
const ( const (
reset = "\x1b[0m" reset = "\x1b[0m"
redBold = "\x1b[0;1;31m" redBold = "\x1b[0;1;31m"
@@ -116,7 +118,7 @@ func show(pkgs ...string) error {
switch v := v.(type) { switch v := v.(type) {
case *goose.Provider: case *goose.Provider:
out[types.TypeString(t, nil)] = v.Pos out[types.TypeString(t, nil)] = v.Pos
case goose.IfaceBinding: case *goose.IfaceBinding:
out[types.TypeString(t, nil)] = v.Pos out[types.TypeString(t, nil)] = v.Pos
default: default:
panic("unreachable") panic("unreachable")
@@ -134,19 +136,19 @@ func show(pkgs ...string) error {
type outGroup struct { type outGroup struct {
name string name string
inputs *typeutil.Map // values are not important inputs *typeutil.Map // values are not important
outputs *typeutil.Map // values are either *goose.Provider or goose.IfaceBinding outputs *typeutil.Map // values are either *goose.Provider or *goose.IfaceBinding
} }
// gather flattens a provider set into outputs grouped by the inputs // gather flattens a provider set into outputs grouped by the inputs
// required to create them. As it flattens the provider set, it records // required to create them. As it flattens the provider set, it records
// the visited provider sets as imports. // the visited named provider sets as imports.
func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) { func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
hash := typeutil.MakeHasher() hash := typeutil.MakeHasher()
// Map types to providers and bindings. // Map types to providers and bindings.
pm := new(typeutil.Map) pm := new(typeutil.Map)
pm.SetHasher(hash) pm.SetHasher(hash)
next := []goose.ProviderSetID{key} next := []*goose.ProviderSet{info.Sets[key]}
visited := make(map[goose.ProviderSetID]struct{}) visited := make(map[*goose.ProviderSet]struct{})
imports = make(map[string]struct{}) imports = make(map[string]struct{})
for len(next) > 0 { for len(next) > 0 {
curr := next[len(next)-1] curr := next[len(next)-1]
@@ -155,18 +157,17 @@ func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports ma
continue continue
} }
visited[curr] = struct{}{} visited[curr] = struct{}{}
if curr != key { if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) {
imports[curr.String()] = struct{}{} imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{}
} }
set := info.All[curr] for _, p := range curr.Providers {
for _, p := range set.Providers {
pm.Set(p.Out, p) pm.Set(p.Out, p)
} }
for _, b := range set.Bindings { for _, b := range curr.Bindings {
pm.Set(b.Iface, b) pm.Set(b.Iface, b)
} }
for _, imp := range set.Imports { for _, imp := range curr.Imports {
next = append(next, imp.ProviderSetID) next = append(next, imp)
} }
} }
@@ -238,7 +239,7 @@ func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports ma
inputs: in, inputs: in,
outputs: out, outputs: out,
}) })
case goose.IfaceBinding: case *goose.IfaceBinding:
i, ok := inputVisited.At(p.Provided).(int) i, ok := inputVisited.At(p.Provided).(int)
if !ok { if !ok {
stk = append(stk, curr, p.Provided) stk = append(stk, curr, p.Provided)
@@ -327,3 +328,8 @@ func sortSet(set interface{}) []string {
sort.Strings(a) sort.Strings(a)
return a return a
} }
func formatProviderSetName(importPath, varName string) string {
// Since varName is an identifier, it doesn't make sense to quote.
return strconv.Quote(importPath) + "." + varName
}

38
goose.go Normal file
View File

@@ -0,0 +1,38 @@
// Package goose contains directives for goose code generation.
package goose
// ProviderSet is a marker type that collects a group of providers.
type ProviderSet struct{}
// NewSet creates a new provider set that includes the providers in
// its arguments. Each argument is either an exported function value,
// an exported struct (zero) value, or a call to Bind.
func NewSet(...interface{}) ProviderSet {
return ProviderSet{}
}
// Use is placed in the body of an injector function to declare the
// providers to use. Its arguments are the same as NewSet. Its return
// value is an error message that can be sent to panic.
//
// Example:
//
// func injector(ctx context.Context) (*sql.DB, error) {
// panic(Use(otherpkg.Foo, myProviderFunc, goose.Bind()))
// }
func Use(...interface{}) string {
return "implementation not generated, run goose"
}
// A Binding maps an interface to a concrete type.
type Binding struct{}
// Bind declares that a concrete type should be used to satisfy a
// dependency on iface.
//
// Example:
//
// var MySet = goose.NewSet(goose.Bind(MyInterface(nil), new(MyStruct)))
func Bind(iface, to interface{}) Binding {
return Binding{}
}

View File

@@ -42,7 +42,7 @@ type call struct {
// solve finds the sequence of calls required to produce an output type // solve finds the sequence of calls required to produce an output type
// with an optional set of provided inputs. // with an optional set of provided inputs.
func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) { func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, error) {
for i, g := range given { for i, g := range given {
for _, h := range given[:i] { for _, h := range given[:i] {
if types.Identical(g, h) { if types.Identical(g, h) {
@@ -50,7 +50,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
} }
} }
} }
providers, err := buildProviderMap(mc, sets) providers, err := buildProviderMap(fset, set)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -61,7 +61,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
for i, g := range given { for i, g := range given {
if p := providers.At(g); p != nil { if p := providers.At(g); p != nil {
pp := p.(*Provider) pp := p.(*Provider)
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, mc.fset.Position(pp.Pos)) return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, fset.Position(pp.Pos))
} }
index.Set(g, i) index.Set(g, i)
} }
@@ -135,88 +135,70 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
return calls, nil return calls, nil
} }
func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) { func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, error) {
type nextEnt struct {
to symref
from symref
pos token.Pos
}
type binding struct { type binding struct {
IfaceBinding *IfaceBinding
pset symref set *ProviderSet
from symref
} }
pm := new(typeutil.Map) // to *providerInfo providerMap := new(typeutil.Map) // to *Provider
setMap := new(typeutil.Map) // to *ProviderSet, for error messages
var bindings []binding var bindings []binding
visited := make(map[symref]struct{}) visited := make(map[*ProviderSet]struct{})
var next []nextEnt next := []*ProviderSet{set}
for _, ref := range sets {
next = append(next, nextEnt{to: ref})
}
for len(next) > 0 { for len(next) > 0 {
curr := next[0] curr := next[0]
copy(next, next[1:]) copy(next, next[1:])
next = next[:len(next)-1] next = next[:len(next)-1]
if _, skip := visited[curr.to]; skip { if _, skip := visited[curr]; skip {
continue continue
} }
visited[curr.to] = struct{}{} visited[curr] = struct{}{}
pset, err := mc.get(curr.to) for _, p := range curr.Providers {
if err != nil { if providerMap.At(p.Out) != nil {
if !curr.pos.IsValid() { return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet))
return nil, err
} }
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err) providerMap.Set(p.Out, p)
setMap.Set(p.Out, curr)
} }
for _, p := range pset.Providers { for _, b := range curr.Bindings {
if prev := pm.At(p.Out); prev != nil {
pos := mc.fset.Position(p.Pos)
typ := types.TypeString(p.Out, nil)
prevPos := mc.fset.Position(prev.(*Provider).Pos)
if curr.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
}
pm.Set(p.Out, p)
}
for _, b := range pset.Bindings {
bindings = append(bindings, binding{ bindings = append(bindings, binding{
IfaceBinding: b, IfaceBinding: b,
pset: curr.to, set: curr,
from: curr.from,
}) })
} }
for _, imp := range pset.Imports { for _, imp := range curr.Imports {
next = append(next, nextEnt{to: imp.symref(), from: curr.to, pos: imp.Pos}) next = append(next, imp)
} }
} }
// Validate that bindings have their concrete type provided in the set.
// TODO(light): Move this validation up into provider set creation.
for _, b := range bindings { for _, b := range bindings {
if prev := pm.At(b.Iface); prev != nil { if providerMap.At(b.Iface) != nil {
pos := mc.fset.Position(b.Pos) return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet))
typ := types.TypeString(b.Iface, nil)
// TODO(light): Error message for conflicting with another interface binding will point at provider instead of binding.
prevPos := mc.fset.Position(prev.(*Provider).Pos)
if b.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
} }
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos) concrete := providerMap.At(b.Provided)
}
concrete := pm.At(b.Provided)
if concrete == nil { if concrete == nil {
pos := mc.fset.Position(b.Pos) pos := fset.Position(b.Pos)
typ := types.TypeString(b.Provided, nil) typ := types.TypeString(b.Provided, nil)
if b.from.importPath == "" {
// Concrete provider is imported directly by injector.
return nil, fmt.Errorf("%v: no binding for %s", pos, typ) return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
} }
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from) providerMap.Set(b.Iface, concrete)
setMap.Set(b.Iface, b.set)
} }
pm.Set(b.Iface, concrete) return providerMap, nil
} }
return pm, nil
// bindingConflictError creates a new error describing multiple bindings
// for the same output type.
func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error {
position := fset.Position(pos)
typString := types.TypeString(typ, nil)
if prevSet.Name == "" {
prevPosition := fset.Position(prevSet.Pos)
return fmt.Errorf("%v: multiple bindings for %s (previous binding at %v)",
position, typString, prevPosition)
}
return fmt.Errorf("%v: multiple bindings for %s (previous binding in %q.%s)",
position, typString, prevSet.PkgPath, prevSet.Name)
} }

View File

@@ -8,7 +8,7 @@ import (
"go/ast" "go/ast"
"go/build" "go/build"
"go/format" "go/format"
"go/parser" "go/token"
"go/types" "go/types"
"sort" "sort"
"strconv" "strconv"
@@ -22,8 +22,24 @@ import (
// Generate performs dependency injection for a single package, // Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code. // returning the gofmt'd Go source code.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
conf := newLoaderConfig(bctx, wd, true) mainPkg, err := bctx.Import(pkg, wd, build.FindOnly)
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
// TODO(light): Stop errors from printing to stderr.
conf := &loader.Config{
Build: new(build.Context),
Cwd: wd,
TypeCheckFuncBodies: func(path string) bool {
return path == mainPkg.ImportPath
},
}
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
// TODO(light): Only apply gooseinject build tag on main package.
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
conf.Import(pkg) conf.Import(pkg)
prog, err := conf.Load() prog, err := conf.Load()
if err != nil { if err != nil {
return nil, fmt.Errorf("load: %v", err) return nil, fmt.Errorf("load: %v", err)
@@ -34,47 +50,23 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
} }
pkgInfo := prog.InitialPackages()[0] pkgInfo := prog.InitialPackages()[0]
g := newGen(prog, pkgInfo.Pkg.Path()) g := newGen(prog, pkgInfo.Pkg.Path())
r := newImportResolver(conf, prog.Fset) oc := newObjectCache(prog)
mc := newProviderSetCache(prog, r)
for _, f := range pkgInfo.Files { for _, f := range pkgInfo.Files {
if !isInjectFile(f) {
continue
}
fileScope := pkgInfo.Scopes[f]
groups := parseFile(prog.Fset, f)
for _, decl := range f.Decls { for _, decl := range f.Decls {
fn, ok := decl.(*ast.FuncDecl) fn, ok := decl.(*ast.FuncDecl)
if !ok { if !ok {
continue continue
} }
var dg directiveGroup useCall := isInjector(&pkgInfo.Info, fn)
for _, dg = range groups { if useCall == nil {
if dg.decl == decl { continue
break
} }
} set, err := oc.processNewSet(pkgInfo, useCall)
if dg.decl != decl {
dg = directiveGroup{}
}
var sets []symref
for _, d := range dg.dirs {
if d.kind != "use" {
return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind)
}
args := d.args()
if len(args) == 0 {
return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos))
}
for _, arg := range args {
ref, err := parseSymbolRef(r, arg, fileScope, g.currPackage, d.pos)
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
}
sets = append(sets, ref)
}
} }
sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature)
if err := g.inject(mc, fn.Name.Name, sig, sets); err != nil { if err := g.inject(prog.Fset, fn.Name.Name, sig, set); err != nil {
return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err) return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err)
} }
} }
@@ -89,23 +81,6 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
return fmtSrc, nil return fmtSrc, nil
} }
func newLoaderConfig(bctx *build.Context, wd string, inject bool) *loader.Config {
// TODO(light): Stop errors from printing to stderr.
conf := &loader.Config{
Build: bctx,
ParserMode: parser.ParseComments,
Cwd: wd,
TypeCheckFuncBodies: func(string) bool { return false },
}
if inject {
conf.Build = new(build.Context)
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
}
return conf
}
// gen is the generator state. // gen is the generator state.
type gen struct { type gen struct {
currPackage string currPackage string
@@ -150,7 +125,7 @@ func (g *gen) frame() []byte {
} }
// inject emits the code for an injector. // inject emits the code for an injector.
func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error { func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set *ProviderSet) error {
results := sig.Results() results := sig.Results()
var returnsCleanup, returnsErr bool var returnsCleanup, returnsErr bool
switch results.Len() { switch results.Len() {
@@ -184,7 +159,7 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se
for i := 0; i < params.Len(); i++ { for i := 0; i < params.Len(); i++ {
given[i] = params.At(i).Type() given[i] = params.At(i).Type()
} }
calls, err := solve(mc, outType, given, sets) calls, err := solve(fset, outType, given, set)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -29,13 +29,19 @@ func TestGoose(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// The marker function package source is needed to have the test cases
// type check. loadTestCase places this file at the well-known import path.
gooseGo, err := ioutil.ReadFile(filepath.Join("..", "..", "goose.go"))
if err != nil {
t.Fatal(err)
}
tests := make([]*testCase, 0, len(testdataEnts)) tests := make([]*testCase, 0, len(testdataEnts))
for _, ent := range testdataEnts { for _, ent := range testdataEnts {
name := ent.Name() name := ent.Name()
if !ent.IsDir() || strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") { if !ent.IsDir() || strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
continue continue
} }
test, err := loadTestCase(filepath.Join(testRoot, name)) test, err := loadTestCase(filepath.Join(testRoot, name), gooseGo)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -227,7 +233,7 @@ type testCase struct {
// out.txt file containing the expected output, or the magic string "ERROR" // out.txt file containing the expected output, or the magic string "ERROR"
// if this test should cause generation to fail // if this test should cause generation to fail
// ... any Go files found recursively placed under GOPATH/src/... // ... any Go files found recursively placed under GOPATH/src/...
func loadTestCase(root string) (*testCase, error) { func loadTestCase(root string, gooseGoSrc []byte) (*testCase, error) {
name := filepath.Base(root) name := filepath.Base(root)
pkg, err := ioutil.ReadFile(filepath.Join(root, "pkg")) pkg, err := ioutil.ReadFile(filepath.Join(root, "pkg"))
if err != nil { if err != nil {
@@ -242,7 +248,9 @@ func loadTestCase(root string) (*testCase, error) {
wantError = true wantError = true
out = nil out = nil
} }
goFiles := make(map[string][]byte) goFiles := map[string][]byte{
"codename/goose/goose.go": gooseGoSrc,
}
err = filepath.Walk(root, func(src string, info os.FileInfo, err error) error { err = filepath.Walk(root, func(src string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err

View File

@@ -6,20 +6,28 @@ import (
"go/build" "go/build"
"go/token" "go/token"
"go/types" "go/types"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"unicode"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
) )
// A ProviderSet describes a set of providers. The zero value is an empty // A ProviderSet describes a set of providers. The zero value is an empty
// ProviderSet. // ProviderSet.
type ProviderSet struct { type ProviderSet struct {
// Pos is the position of the call to goose.NewSet or goose.Use that
// created the set.
Pos token.Pos
// PkgPath is the import path of the package that declared this set.
PkgPath string
// Name is the variable name of the set, if it came from a package
// variable.
Name string
Providers []*Provider Providers []*Provider
Bindings []IfaceBinding Bindings []*IfaceBinding
Imports []ProviderSetImport Imports []*ProviderSet
} }
// An IfaceBinding declares that a type should be used to satisfy inputs // An IfaceBinding declares that a type should be used to satisfy inputs
@@ -35,12 +43,6 @@ type IfaceBinding struct {
Pos token.Pos Pos token.Pos
} }
// A ProviderSetImport adds providers from one provider set into another.
type ProviderSetImport struct {
ProviderSetID
Pos token.Pos
}
// Provider records the signature of a provider. A provider is a // Provider records the signature of a provider. A provider is a
// single Go object, either a function or a named struct type. // single Go object, either a function or a named struct type.
type Provider struct { type Provider struct {
@@ -87,7 +89,12 @@ type ProviderInput struct {
// Load finds all the provider sets in the given packages, as well as // Load finds all the provider sets in the given packages, as well as
// the provider sets' transitive dependencies. // the provider sets' transitive dependencies.
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) { func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
conf := newLoaderConfig(bctx, wd, false) // TODO(light): Stop errors from printing to stderr.
conf := &loader.Config{
Build: bctx,
Cwd: wd,
TypeCheckFuncBodies: func(string) bool { return false },
}
for _, p := range pkgs { for _, p := range pkgs {
conf.Import(p) conf.Import(p)
} }
@@ -95,48 +102,26 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("load: %v", err) return nil, fmt.Errorf("load: %v", err)
} }
r := newImportResolver(conf, prog.Fset)
var next []string
initial := make(map[string]struct{})
for _, pkgInfo := range prog.InitialPackages() {
path := pkgInfo.Pkg.Path()
next = append(next, path)
initial[path] = struct{}{}
}
visited := make(map[string]struct{})
info := &Info{ info := &Info{
Fset: prog.Fset, Fset: prog.Fset,
Sets: make(map[ProviderSetID]*ProviderSet), Sets: make(map[ProviderSetID]*ProviderSet),
All: make(map[ProviderSetID]*ProviderSet),
} }
for len(next) > 0 { oc := newObjectCache(prog)
curr := next[len(next)-1] for _, pkgInfo := range prog.InitialPackages() {
next = next[:len(next)-1] scope := pkgInfo.Pkg.Scope()
if _, ok := visited[curr]; ok { for _, name := range scope.Names() {
item, err := oc.get(scope.Lookup(name))
if err != nil {
continue continue
} }
visited[curr] = struct{}{} pset, ok := item.(*ProviderSet)
pkgInfo := prog.Package(curr) if !ok {
sets, err := findProviderSets(findContext{ continue
fset: prog.Fset,
pkg: pkgInfo.Pkg,
typeInfo: &pkgInfo.Info,
r: r,
}, pkgInfo.Files)
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
path := pkgInfo.Pkg.Path()
for name, set := range sets {
info.All[ProviderSetID{path, name}] = set
for _, imp := range set.Imports {
next = append(next, imp.ImportPath)
}
}
if _, ok := initial[path]; ok {
for name, set := range sets {
info.Sets[ProviderSetID{path, name}] = set
} }
// pset.Name may not equal name, since it could be an alias to
// another provider set.
id := ProviderSetID{ImportPath: pset.PkgPath, VarName: name}
info.Sets[id] = pset
} }
} }
return info, nil return info, nil
@@ -148,257 +133,217 @@ type Info struct {
// Sets contains all the provider sets in the initial packages. // Sets contains all the provider sets in the initial packages.
Sets map[ProviderSetID]*ProviderSet Sets map[ProviderSetID]*ProviderSet
// All contains all the provider sets transitively depended on by the
// initial packages' provider sets.
All map[ProviderSetID]*ProviderSet
} }
// A ProviderSetID identifies a provider set. // A ProviderSetID identifies a named provider set.
type ProviderSetID struct { type ProviderSetID struct {
ImportPath string ImportPath string
Name string VarName string
} }
// String returns the ID as ""path/to/pkg".Foo". // String returns the ID as ""path/to/pkg".Foo".
func (id ProviderSetID) String() string { func (id ProviderSetID) String() string {
return id.symref().String() return strconv.Quote(id.ImportPath) + "." + id.VarName
} }
func (id ProviderSetID) symref() symref { // objectCache is a lazily evaluated mapping of objects to goose structures.
return symref{importPath: id.ImportPath, name: id.Name} type objectCache struct {
prog *loader.Program
objects map[objRef]interface{} // *Provider or *ProviderSet
} }
type findContext struct { type objRef struct {
fset *token.FileSet importPath string
pkg *types.Package name string
typeInfo *types.Info
r *importResolver
} }
// findProviderSets processes a package and extracts the provider sets declared in it. func newObjectCache(prog *loader.Program) *objectCache {
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*ProviderSet, error) { return &objectCache{
sets := make(map[string]*ProviderSet) prog: prog,
for _, f := range files { objects: make(map[objRef]interface{}),
fileScope := fctx.typeInfo.Scopes[f]
if fileScope == nil {
return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fctx.fset.File(f.Pos()).Name())
} }
for _, dg := range parseFile(fctx.fset, f) {
if dg.decl != nil {
if err := processDeclDirectives(fctx, sets, fileScope, dg); err != nil {
return nil, err
}
} else {
for _, d := range dg.dirs {
if err := processUnassociatedDirective(fctx, sets, fileScope, d); err != nil {
return nil, err
}
}
}
}
}
return sets, nil
} }
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration. // get converts a Go object into a goose structure. It may return a
func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error { // *Provider, a structProviderPair, an *IfaceBinding, or a *ProviderSet.
switch d.kind { func (oc *objectCache) get(obj types.Object) (interface{}, error) {
case "provide": ref := objRef{
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) importPath: obj.Pkg().Path(),
case "use": name: obj.Name(),
// Ignore, picked up by injector flow.
case "bind":
args := d.args()
if len(args) != 3 {
return fmt.Errorf("%v: invalid binding: expected TARGET IFACE TYPE", fctx.fset.Position(d.pos))
} }
ifaceRef, err := parseSymbolRef(fctx.r, args[1], scope, fctx.pkg.Path(), d.pos) if val, cached := oc.objects[ref]; cached {
if err != nil { if val == nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) return nil, fmt.Errorf("%v is not a provider or a provider set", obj)
} }
ifaceObj, err := ifaceRef.resolveObject(fctx.pkg) return val, nil
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
} }
ifaceDecl, ok := ifaceObj.(*types.TypeName) switch obj := obj.(type) {
if !ok { case *types.Var:
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), ifaceRef) spec := oc.varDecl(obj)
if len(spec.Values) == 0 {
return nil, fmt.Errorf("%v is not a provider or a provider set", obj)
} }
iface := ifaceDecl.Type() var i int
methodSet, ok := iface.Underlying().(*types.Interface) for i = range spec.Names {
if !ok { if spec.Names[i].Name == obj.Name() {
return fmt.Errorf("%v: %v does not name an interface type", fctx.fset.Position(d.pos), ifaceRef)
}
providedRef, err := parseSymbolRef(fctx.r, strings.TrimPrefix(args[2], "*"), scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
providedObj, err := providedRef.resolveObject(fctx.pkg)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
providedDecl, ok := providedObj.(*types.TypeName)
if !ok {
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), providedRef)
}
provided := providedDecl.Type()
if types.Identical(provided, iface) {
return fmt.Errorf("%v: cannot bind interface to itself", fctx.fset.Position(d.pos))
}
if strings.HasPrefix(args[2], "*") {
provided = types.NewPointer(provided)
}
if !types.Implements(provided, methodSet) {
return fmt.Errorf("%v: %s does not implement %s", fctx.fset.Position(d.pos), types.TypeString(provided, nil), types.TypeString(iface, nil))
}
name := args[0]
if pset := sets[name]; pset != nil {
pset.Bindings = append(pset.Bindings, IfaceBinding{
Iface: iface,
Provided: provided,
})
} else {
sets[name] = &ProviderSet{
Bindings: []IfaceBinding{{
Iface: iface,
Provided: provided,
}},
}
}
case "import":
args := d.args()
if len(args) < 2 {
return fmt.Errorf("%v: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
}
name := args[0]
for _, spec := range args[1:] {
ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
if err != nil {
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
}
if findImport(fctx.pkg, ref.importPath) == nil {
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
}
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.Imports {
if ref == other.symref() {
found = true
break break
} }
} }
if !found { return oc.processExpr(oc.prog.Package(obj.Pkg().Path()), spec.Values[i])
mod.Imports = append(mod.Imports, ProviderSetImport{ case *types.Func:
ProviderSetID: ProviderSetID{ p, err := processFuncProvider(oc.prog.Fset, obj)
ImportPath: ref.importPath, if err != nil {
Name: ref.name, oc.objects[ref] = nil
}, return nil, err
Pos: d.pos,
})
}
} else {
sets[name] = &ProviderSet{
Imports: []ProviderSetImport{{
ProviderSetID: ProviderSetID{
ImportPath: ref.importPath,
Name: ref.name,
},
Pos: d.pos,
}},
}
}
} }
oc.objects[ref] = p
return p, nil
default: default:
return fmt.Errorf("%v: unknown directive %s", fctx.fset.Position(d.pos), d.kind) oc.objects[ref] = nil
return nil, fmt.Errorf("%v is not a provider or a provider set", obj)
}
}
// varDecl finds the declaration that defines the given variable.
func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec {
// TODO(light): Walk files to build object -> declaration mapping, if more performant.
// Recommended by https://golang.org/s/types-tutorial
pkg := oc.prog.Package(obj.Pkg().Path())
pos := obj.Pos()
for _, f := range pkg.Files {
tokenFile := oc.prog.Fset.File(f.Pos())
if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() {
path, _ := astutil.PathEnclosingInterval(f, pos, pos)
for _, node := range path {
if spec, ok := node.(*ast.ValueSpec); ok {
return spec
}
}
}
} }
return nil return nil
} }
// processDeclDirectives processes the directives associated with a top-level declaration. // processExpr converts an expression into a goose structure. It may
func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, dg directiveGroup) error { // return a *Provider, a structProviderPair, an *IfaceBinding, or a
p, err := dg.single(fctx.fset, "provide") // *ProviderSet.
func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr) (interface{}, error) {
exprPos := oc.prog.Fset.Position(expr.Pos())
expr = astutil.Unparen(expr)
if obj := qualifiedIdentObject(&pkg.Info, expr); obj != nil {
item, err := oc.get(obj)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("%v: %v", exprPos, err)
} }
if !p.isValid() { return item, nil
return nil
} }
var providerSetName string if call, ok := expr.(*ast.CallExpr); ok {
if args := p.args(); len(args) == 1 { fnObj := qualifiedIdentObject(&pkg.Info, call.Fun)
// TODO(light): Validate identifier. if fnObj == nil || !isGooseImport(fnObj.Pkg().Path()) {
providerSetName = args[0] return nil, fmt.Errorf("%v: unknown pattern", exprPos)
} else if len(args) > 1 {
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(p.pos))
} }
switch decl := dg.decl.(type) { switch fnObj.Name() {
case *ast.FuncDecl: case "NewSet":
fn := fctx.typeInfo.ObjectOf(decl.Name).(*types.Func) pset, err := oc.processNewSet(pkg, call)
provider, err := processFuncProvider(fctx, fn)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("%v: %v", exprPos, err)
} }
if providerSetName == "" { return pset, nil
providerSetName = fn.Name() case "Bind":
} b, err := processBind(oc.prog.Fset, &pkg.Info, call)
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.Providers {
if types.Identical(other.Out, provider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
}
}
mod.Providers = append(mod.Providers, provider)
} else {
sets[providerSetName] = &ProviderSet{
Providers: []*Provider{provider},
}
}
case *ast.GenDecl:
if decl.Tok != token.TYPE {
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
}
if len(decl.Specs) != 1 {
// TODO(light): Tighten directive extraction to associate with particular specs.
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
}
typeName := fctx.typeInfo.ObjectOf(decl.Specs[0].(*ast.TypeSpec).Name).(*types.TypeName)
if _, ok := typeName.Type().(*types.Named).Underlying().(*types.Struct); !ok {
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
}
provider, err := processStructProvider(fctx, typeName)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("%v: %v", exprPos, err)
}
if providerSetName == "" {
providerSetName = typeName.Name()
}
ptrProvider := new(Provider)
*ptrProvider = *provider
ptrProvider.Out = types.NewPointer(provider.Out)
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.Providers {
if types.Identical(other.Out, provider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
}
if types.Identical(other.Out, ptrProvider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.Out, nil), fctx.fset.Position(other.Pos))
}
}
mod.Providers = append(mod.Providers, provider, ptrProvider)
} else {
sets[providerSetName] = &ProviderSet{
Providers: []*Provider{provider, ptrProvider},
}
} }
return b, nil
default: default:
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) return nil, fmt.Errorf("%v: unknown pattern", exprPos)
} }
return nil }
if tn := structArgType(&pkg.Info, expr); tn != nil {
p, err := processStructProvider(oc.prog.Fset, tn)
if err != nil {
return nil, fmt.Errorf("%v: %v", exprPos, err)
}
ptrp := new(Provider)
*ptrp = *p
ptrp.Out = types.NewPointer(p.Out)
return structProviderPair{p, ptrp}, nil
}
return nil, fmt.Errorf("%v: unknown pattern", exprPos)
} }
func processFuncProvider(fctx findContext, fn *types.Func) (*Provider, error) { type structProviderPair struct {
provider *Provider
ptrProvider *Provider
}
func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr) (*ProviderSet, error) {
// Assumes that call.Fun is goose.NewSet or goose.Use.
pset := &ProviderSet{
Pos: call.Pos(),
PkgPath: pkg.Pkg.Path(),
}
for _, arg := range call.Args {
item, err := oc.processExpr(pkg, arg)
if err != nil {
return nil, err
}
switch item := item.(type) {
case *Provider:
pset.Providers = append(pset.Providers, item)
case *ProviderSet:
pset.Imports = append(pset.Imports, item)
case *IfaceBinding:
pset.Bindings = append(pset.Bindings, item)
case structProviderPair:
pset.Providers = append(pset.Providers, item.provider, item.ptrProvider)
default:
panic("unknown item type")
}
}
return pset, nil
}
// structArgType attempts to interpret an expression as a simple struct type.
// It assumes any parentheses have been stripped.
func structArgType(info *types.Info, expr ast.Expr) *types.TypeName {
lit, ok := expr.(*ast.CompositeLit)
if !ok {
return nil
}
tn, ok := qualifiedIdentObject(info, lit.Type).(*types.TypeName)
if !ok {
return nil
}
if _, isStruct := tn.Type().Underlying().(*types.Struct); !isStruct {
return nil
}
return tn
}
// qualifiedIdentObject finds the object for an identifier or a
// qualified identifier, or nil if the object could not be found.
func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object {
switch expr := expr.(type) {
case *ast.Ident:
return info.ObjectOf(expr)
case *ast.SelectorExpr:
pkgName, ok := expr.X.(*ast.Ident)
if !ok {
return nil
}
if _, ok := info.ObjectOf(pkgName).(*types.PkgName); !ok {
return nil
}
return info.ObjectOf(expr.Sel)
default:
return nil
}
}
// processFuncProvider creates a provider for a function declaration.
func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) {
sig := fn.Type().(*types.Signature) sig := fn.Type().(*types.Signature)
fpos := fn.Pos() fpos := fn.Pos()
@@ -414,23 +359,23 @@ func processFuncProvider(fctx findContext, fn *types.Func) (*Provider, error) {
case types.Identical(t, cleanupType): case types.Identical(t, cleanupType):
hasCleanup, hasErr = true, false hasCleanup, hasErr = true, false
default: default:
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name()) return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fset.Position(fpos), fn.Name())
} }
case 3: case 3:
if t := r.At(1).Type(); !types.Identical(t, cleanupType) { if t := r.At(1).Type(); !types.Identical(t, cleanupType) {
return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name()) return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fset.Position(fpos), fn.Name())
} }
if t := r.At(2).Type(); !types.Identical(t, errorType) { if t := r.At(2).Type(); !types.Identical(t, errorType) {
return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name()) return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fset.Position(fpos), fn.Name())
} }
hasCleanup, hasErr = true, true hasCleanup, hasErr = true, true
default: default:
return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name()) return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name())
} }
out := r.At(0).Type() out := r.At(0).Type()
params := sig.Params() params := sig.Params()
provider := &Provider{ provider := &Provider{
ImportPath: fctx.pkg.Path(), ImportPath: fn.Pkg().Path(),
Name: fn.Name(), Name: fn.Name(),
Pos: fn.Pos(), Pos: fn.Pos(),
Args: make([]ProviderInput, params.Len()), Args: make([]ProviderInput, params.Len()),
@@ -444,20 +389,25 @@ func processFuncProvider(fctx findContext, fn *types.Func) (*Provider, error) {
} }
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil)) return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil))
} }
} }
} }
return provider, nil return provider, nil
} }
func processStructProvider(fctx findContext, typeName *types.TypeName) (*Provider, error) { // processStructProvider creates a provider for a named struct type.
// It only produces the non-pointer variant.
func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, error) {
out := typeName.Type() out := typeName.Type()
st := out.Underlying().(*types.Struct) st, ok := out.Underlying().(*types.Struct)
if !ok {
return nil, fmt.Errorf("%v does not name a struct", typeName)
}
pos := typeName.Pos() pos := typeName.Pos()
provider := &Provider{ provider := &Provider{
ImportPath: fctx.pkg.Path(), ImportPath: typeName.Pkg().Path(),
Name: typeName.Name(), Name: typeName.Name(),
Pos: pos, Pos: pos,
Args: make([]ProviderInput, st.NumFields()), Args: make([]ProviderInput, st.NumFields()),
@@ -473,332 +423,93 @@ func processStructProvider(fctx findContext, typeName *types.TypeName) (*Provide
provider.Fields[i] = f.Name() provider.Fields[i] = f.Name()
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.Args[j].Type, nil)) return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fset.Position(pos), types.TypeString(provider.Args[j].Type, nil))
} }
} }
} }
return provider, nil return provider, nil
} }
// providerSetCache is a lazily evaluated index of provider sets. // processBind creates an interface binding from a goose.Bind call.
type providerSetCache struct { func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) {
sets map[string]map[string]*ProviderSet // Assumes that call.Fun is goose.Bind.
fset *token.FileSet
prog *loader.Program
r *importResolver
}
func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCache { if len(call.Args) != 2 {
return &providerSetCache{ return nil, fmt.Errorf("%v: call to Bind takes exactly two arguments", fset.Position(call.Pos()))
fset: prog.Fset,
prog: prog,
r: r,
} }
} // TODO(light): Verify that arguments are simple expressions.
iface := info.TypeOf(call.Args[0])
func (mc *providerSetCache) get(ref symref) (*ProviderSet, error) { methodSet, ok := iface.Underlying().(*types.Interface)
if mods, cached := mc.sets[ref.importPath]; cached {
mod := mods[ref.name]
if mod == nil {
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
}
return mod, nil
}
if mc.sets == nil {
mc.sets = make(map[string]map[string]*ProviderSet)
}
pkg := mc.prog.Package(ref.importPath)
mods, err := findProviderSets(findContext{
fset: mc.fset,
pkg: pkg.Pkg,
typeInfo: &pkg.Info,
r: mc.r,
}, pkg.Files)
if err != nil {
mc.sets[ref.importPath] = nil
return nil, err
}
mc.sets[ref.importPath] = mods
mod := mods[ref.name]
if mod == nil {
return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath)
}
return mod, nil
}
// A symref is a parsed reference to a symbol (either a provider set or a Go object).
type symref struct {
importPath string
name string
}
func parseSymbolRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (symref, error) {
// TODO(light): Verify that provider set name is an identifier before returning.
i := strings.LastIndexByte(ref, '.')
if i == -1 {
return symref{importPath: pkg, name: ref}, nil
}
imp, name := ref[:i], ref[i+1:]
if strings.HasPrefix(imp, `"`) {
path, err := strconv.Unquote(imp)
if err != nil {
return symref{}, fmt.Errorf("parse symbol reference %q: bad import path", ref)
}
path, err = r.resolve(pos, path)
if err != nil {
return symref{}, fmt.Errorf("parse symbol reference %q: %v", ref, err)
}
return symref{importPath: path, name: name}, nil
}
_, obj := s.LookupParent(imp, pos)
if obj == nil {
return symref{}, fmt.Errorf("parse symbol reference %q: unknown identifier %s", ref, imp)
}
pn, ok := obj.(*types.PkgName)
if !ok { if !ok {
return symref{}, fmt.Errorf("parse symbol reference %q: %s does not name a package", ref, imp) return nil, fmt.Errorf("%v: first argument to bind must be of interface type; found %s", fset.Position(call.Pos()), types.TypeString(iface, nil))
} }
return symref{importPath: pn.Imported().Path(), name: name}, nil provided := info.TypeOf(call.Args[1])
if types.Identical(iface, provided) {
return nil, fmt.Errorf("%v: cannot bind interface to itself", fset.Position(call.Pos()))
}
if !types.Implements(provided, methodSet) {
return nil, fmt.Errorf("%v: %s does not implement %s", fset.Position(call.Pos()), types.TypeString(provided, nil), types.TypeString(iface, nil))
}
return &IfaceBinding{
Pos: call.Pos(),
Iface: iface,
Provided: provided,
}, nil
} }
func (ref symref) String() string { // isInjector checks whether a given function declaration is an
return strconv.Quote(ref.importPath) + "." + ref.name // injector template, returning the goose.Use call. It returns nil if
} // the function is not an injector template.
func isInjector(info *types.Info, fn *ast.FuncDecl) *ast.CallExpr {
func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) { if fn.Body == nil {
imp := findImport(pkg, ref.importPath)
if imp == nil {
return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref)
}
obj := imp.Scope().Lookup(ref.name)
if obj == nil {
return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name)
}
return obj, nil
}
type importResolver struct {
fset *token.FileSet
bctx *build.Context
findPackage func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error)
}
func newImportResolver(c *loader.Config, fset *token.FileSet) *importResolver {
r := &importResolver{
fset: fset,
bctx: c.Build,
findPackage: c.FindPackage,
}
if r.bctx == nil {
r.bctx = &build.Default
}
if r.findPackage == nil {
r.findPackage = (*build.Context).Import
}
return r
}
func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
dir := filepath.Dir(r.fset.File(pos).Name())
pkg, err := r.findPackage(r.bctx, path, dir, build.FindOnly)
if err != nil {
return "", err
}
return pkg.ImportPath, nil
}
func findImport(pkg *types.Package, path string) *types.Package {
if pkg.Path() == path {
return pkg
}
for _, imp := range pkg.Imports() {
if imp.Path() == path {
return imp
}
}
return nil return nil
}
// A directive is a parsed goose comment.
type directive struct {
pos token.Pos
kind string
line string
}
// A directiveGroup is a set of directives associated with a particular
// declaration.
type directiveGroup struct {
decl ast.Decl
dirs []directive
}
// parseFile extracts the directives from a file, grouped by declaration.
func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup {
cmap := ast.NewCommentMap(fset, f, f.Comments)
// Reserve first group for directives that don't associate with a
// declaration, like import.
groups := make([]directiveGroup, 1, len(f.Decls)+1)
// Walk declarations and add to groups.
for _, decl := range f.Decls {
grp := directiveGroup{decl: decl}
ast.Inspect(decl, func(node ast.Node) bool {
if g := cmap[node]; len(g) > 0 {
for _, cg := range g {
start := len(grp.dirs)
grp.dirs = extractDirectives(grp.dirs, cg)
// Move directives that don't associate into the unassociated group.
n := 0
for i := start; i < len(grp.dirs); i++ {
if k := grp.dirs[i].kind; k == "provide" || k == "use" {
grp.dirs[start+n] = grp.dirs[i]
n++
} else {
groups[0].dirs = append(groups[0].dirs, grp.dirs[i])
} }
var only *ast.ExprStmt
for _, stmt := range fn.Body.List {
switch stmt := stmt.(type) {
case *ast.ExprStmt:
if only != nil {
return nil
} }
grp.dirs = grp.dirs[:start+n] only = stmt
} case *ast.EmptyStmt:
delete(cmap, node) // Do nothing.
}
return true
})
if len(grp.dirs) > 0 {
groups = append(groups, grp)
}
}
// Place remaining directives into the unassociated group.
unassoc := &groups[0]
for _, g := range cmap {
for _, cg := range g {
unassoc.dirs = extractDirectives(unassoc.dirs, cg)
}
}
if len(unassoc.dirs) == 0 {
return groups[1:]
}
return groups
}
func extractDirectives(d []directive, cg *ast.CommentGroup) []directive {
const prefix = "goose:"
text := cg.Text()
for len(text) > 0 {
text = strings.TrimLeft(text, " \t\r\n")
if !strings.HasPrefix(text, prefix) {
break
}
line := text[len(prefix):]
// Text() is always newline terminated.
i := strings.IndexByte(line, '\n')
line, text = line[:i], line[i+1:]
if i := strings.IndexByte(line, ' '); i != -1 {
d = append(d, directive{
kind: line[:i],
line: strings.TrimSpace(line[i+1:]),
pos: cg.Pos(), // TODO(light): More precise position.
})
} else {
d = append(d, directive{
kind: line,
pos: cg.Pos(), // TODO(light): More precise position.
})
}
}
return d
}
// single finds at most one directive that matches the given kind.
func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) {
var found directive
ok := false
for _, d := range dg.dirs {
if d.kind != kind {
continue
}
if ok {
switch decl := dg.decl.(type) {
case *ast.FuncDecl:
return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, decl.Name.Name)
case *ast.GenDecl:
if decl.Tok == token.TYPE && len(decl.Specs) == 1 {
name := decl.Specs[0].(*ast.TypeSpec).Name.Name
return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, name)
}
return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind)
default: default:
return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind) return nil
} }
} }
found, ok = d, true panicCall, ok := only.X.(*ast.CallExpr)
if !ok {
return nil
} }
return found, nil panicIdent, ok := panicCall.Fun.(*ast.Ident)
if !ok {
return nil
}
if info.ObjectOf(panicIdent) != types.Universe.Lookup("panic") {
return nil
}
if len(panicCall.Args) != 1 {
return nil
}
useCall, ok := panicCall.Args[0].(*ast.CallExpr)
if !ok {
return nil
}
useObj := qualifiedIdentObject(info, useCall.Fun)
if !isGooseImport(useObj.Pkg().Path()) || useObj.Name() != "Use" {
return nil
}
return useCall
} }
func (d directive) isValid() bool { func isGooseImport(path string) bool {
return d.kind != "" // TODO(light): This is depending on details of the current loader.
} const vendorPart = "vendor/"
if i := strings.LastIndex(path, vendorPart); i != -1 && (i == 0 || path[i-1] == '/') {
// args splits the directive line into tokens. path = path[i+len(vendorPart):]
func (d directive) args() []string {
var args []string
start := -1
state := 0 // 0 = boundary, 1 = in token, 2 = in quote, 3 = quote backslash
for i, r := range d.line {
switch state {
case 0:
// Argument boundary.
switch {
case r == '"':
start = i
state = 2
case !unicode.IsSpace(r):
start = i
state = 1
} }
case 1: return path == "codename/goose"
// In token.
switch {
case unicode.IsSpace(r):
args = append(args, d.line[start:i])
start = -1
state = 0
case r == '"':
state = 2
}
case 2:
// In quotes.
switch {
case r == '"':
state = 1
case r == '\\':
state = 3
}
case 3:
// Quote backslash. Consumes one character and jumps back into "in quote" state.
state = 2
default:
panic("unreachable")
}
}
if start != -1 {
args = append(args, d.line[start:])
}
return args
}
// isInjectFile reports whether a given file is an injection template.
func isInjectFile(f *ast.File) bool {
// TODO(light): Better determination.
for _, cg := range f.Comments {
text := cg.Text()
if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") {
return true
}
}
return false
} }
// paramIndex returns the index of the parameter with the given name, or // paramIndex returns the index of the parameter with the given name, or

View File

@@ -1,37 +0,0 @@
package goose
import (
"testing"
)
func TestDirectiveArgs(t *testing.T) {
tests := []struct {
line string
args []string
}{
{"", []string{}},
{" \t ", []string{}},
{"foo", []string{"foo"}},
{"foo bar", []string{"foo", "bar"}},
{" foo \t bar ", []string{"foo", "bar"}},
{"foo \"bar \t baz\" fido", []string{"foo", "\"bar \t baz\"", "fido"}},
{"foo \"bar \t baz\".quux fido", []string{"foo", "\"bar \t baz\".quux", "fido"}},
}
eq := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
for _, test := range tests {
got := (directive{line: test.line}).args()
if !eq(got, test.args) {
t.Errorf("directive{line: %q}.args() = %q; want %q", test.line, got, test.args)
}
}
}

View File

@@ -1,6 +1,10 @@
package main package main
import "fmt" import (
"fmt"
"codename/goose"
)
func main() { func main() {
fmt.Println(injectFooBar()) fmt.Println(injectFooBar())
@@ -9,12 +13,14 @@ func main() {
type Foo int type Foo int
type FooBar int type FooBar int
//goose:provide Set var Set = goose.NewSet(
provideFoo,
provideFooBar)
func provideFoo() Foo { func provideFoo() Foo {
return 41 return 41
} }
//goose:provide Set
func provideFooBar(foo Foo) FooBar { func provideFooBar(foo Foo) FooBar {
return FooBar(foo) + 1 return FooBar(foo) + 1
} }

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectFooBar() FooBar func injectFooBar() FooBar {
panic(goose.Use(Set))
}

View File

@@ -1,6 +1,8 @@
package main package main
import "fmt" import (
"fmt"
)
func main() { func main() {
bar, cleanup := injectBar() bar, cleanup := injectBar()
@@ -12,14 +14,12 @@ func main() {
type Foo int type Foo int
type Bar int type Bar int
//goose:provide Foo
func provideFoo() (*Foo, func()) { func provideFoo() (*Foo, func()) {
foo := new(Foo) foo := new(Foo)
*foo = 42 *foo = 42
return foo, func() { *foo = 0 } return foo, func() { *foo = 0 }
} }
//goose:provide Bar
func provideBar(foo *Foo) (*Bar, func()) { func provideBar(foo *Foo) (*Bar, func()) {
bar := new(Bar) bar := new(Bar)
*bar = 77 *bar = 77

View File

@@ -2,7 +2,10 @@
package main package main
//goose:use Foo import (
//goose:use Bar "codename/goose"
)
func injectBar() (*Bar, func()) func injectBar() (*Bar, func()) {
panic(goose.Use(provideFoo, provideBar))
}

View File

@@ -0,0 +1,11 @@
package main
import (
"fmt"
)
func main() {
fmt.Println(injectedMessage())
}
var myFakeSet struct{}

View File

@@ -0,0 +1,11 @@
//+build gooseinject
package main
import (
"codename/goose"
)
func injectedMessage() string {
panic(goose.Use(myFakeSet))
}

View File

@@ -3,7 +3,8 @@ package main
import ( import (
"fmt" "fmt"
_ "foo" "codename/goose"
"foo"
) )
func main() { func main() {
@@ -16,11 +17,12 @@ func (b *Bar) Foo() string {
return string(*b) return string(*b)
} }
//goose:provide
func provideBar() *Bar { func provideBar() *Bar {
b := new(Bar) b := new(Bar)
*b = "Hello, World!" *b = "Hello, World!"
return b return b
} }
//goose:bind provideBar "foo".Fooer *Bar var Set = goose.NewSet(
provideBar,
goose.Bind(foo.Fooer(nil), (*Bar)(nil)))

View File

@@ -2,8 +2,11 @@
package main package main
import "foo" import (
"codename/goose"
"foo"
)
//goose:use provideBar func injectFooer() foo.Fooer {
panic(goose.Use(Set))
func injectFooer() foo.Fooer }

View File

@@ -1,6 +1,10 @@
package main package main
import "fmt" import (
"fmt"
"codename/goose"
)
func main() { func main() {
fmt.Println(injectFooBar(40)) fmt.Println(injectFooBar(40))
@@ -10,12 +14,14 @@ type Foo int
type Bar int type Bar int
type FooBar int type FooBar int
//goose:provide Set var Set = goose.NewSet(
provideBar,
provideFooBar)
func provideBar() Bar { func provideBar() Bar {
return 2 return 2
} }
//goose:provide Set
func provideFooBar(foo Foo, bar Bar) FooBar { func provideFooBar(foo Foo, bar Bar) FooBar {
return FooBar(foo) + FooBar(bar) return FooBar(foo) + FooBar(bar)
} }

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectFooBar(foo Foo) FooBar func injectFooBar(foo Foo) FooBar {
panic(goose.Use(Set))
}

View File

@@ -1,6 +1,10 @@
package main package main
import "fmt" import (
"fmt"
"codename/goose"
)
func main() { func main() {
// I'm on the fence as to whether this should be an error (versus an // I'm on the fence as to whether this should be an error (versus an
@@ -12,12 +16,14 @@ func main() {
type Foo int type Foo int
type Bar int type Bar int
//goose:provide Set var Set = goose.NewSet(
provideFoo,
provideBar)
func provideFoo() Foo { func provideFoo() Foo {
return -888 return -888
} }
//goose:provide Set
func provideBar(foo Foo) Bar { func provideBar(foo Foo) Bar {
return 2 return 2
} }

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectBar(foo Foo) Bar func injectBar(foo Foo) Bar {
panic(goose.Use(Set))
}

View File

@@ -1,6 +1,10 @@
package main package main
import "fmt" import (
"fmt"
"codename/goose"
)
func main() { func main() {
fmt.Println(injectFooer().Foo()) fmt.Println(injectFooer().Foo())
@@ -16,11 +20,12 @@ func (b *Bar) Foo() string {
return string(*b) return string(*b)
} }
//goose:provide
func provideBar() *Bar { func provideBar() *Bar {
b := new(Bar) b := new(Bar)
*b = "Hello, World!" *b = "Hello, World!"
return b return b
} }
//goose:bind provideBar Fooer *Bar var Set = goose.NewSet(
provideBar,
goose.Bind(Fooer(nil), (*Bar)(nil)))

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use provideBar import (
"codename/goose"
)
func injectFooer() Fooer func injectFooer() Fooer {
panic(goose.Use(Set))
}

View File

@@ -28,8 +28,6 @@ func (b *Bar) Foo() string {
return string(*b) return string(*b)
} }
//goose:provide
//goose:bind provideBar Fooer *Bar
func provideBar() *Bar { func provideBar() *Bar {
mu.Lock() mu.Lock()
provideBarCalls++ provideBarCalls++
@@ -44,7 +42,6 @@ var (
provideBarCalls int provideBarCalls int
) )
//goose:provide
func provideFooBar(fooer Fooer, bar *Bar) FooBar { func provideFooBar(fooer Fooer, bar *Bar) FooBar {
return FooBar{fooer, bar} return FooBar{fooer, bar}
} }

View File

@@ -2,7 +2,13 @@
package main package main
//goose:use provideBar import (
//goose:use provideFooBar "codename/goose"
)
func injectFooBar() FooBar func injectFooBar() FooBar {
panic(goose.Use(
provideBar,
provideFooBar,
goose.Bind(Fooer(nil), (*Bar)(nil))))
}

View File

@@ -1,14 +0,0 @@
package main
import "fmt"
func main() {
fmt.Println(injectedMessage())
}
//goose:provide Set
// provideMessage provides a friendly user greeting.
func provideMessage() string {
return "Hello, World!"
}

View File

@@ -1,5 +0,0 @@
//+build gooseinject
package main
func injectedMessage() string

View File

@@ -1,22 +0,0 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type FooBar int
//goose:provide Foo
func provideFoo() Foo {
return 41
}
//goose:provide FooBar
func provideFooBar(foo Foo) FooBar {
return FooBar(foo) + 1
}
//goose:import Set Foo FooBar

View File

@@ -1,7 +0,0 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -1 +0,0 @@
42

View File

@@ -1 +0,0 @@
foo

View File

@@ -1,20 +0,0 @@
package main
import "fmt"
func main() {
fmt.Println(injectFooBar())
}
type Foo int
type FooBar int
//goose:provide Foo
func provideFoo() Foo {
return 41
}
//goose:provide FooBar
func provideFooBar(foo Foo) FooBar {
return FooBar(foo) + 1
}

View File

@@ -1,7 +0,0 @@
//+build gooseinject
package main
//goose:use Foo FooBar
func injectFooBar() FooBar

View File

@@ -1 +0,0 @@
42

View File

@@ -1 +0,0 @@
foo

View File

@@ -17,8 +17,6 @@ func main() {
fmt.Println(c) fmt.Println(c)
} }
//goose:provide
func provide(ctx stdcontext.Context) (context, error) { func provide(ctx stdcontext.Context) (context, error) {
return context{}, nil return context{}, nil
} }

View File

@@ -4,8 +4,10 @@ package main
import ( import (
stdcontext "context" stdcontext "context"
"codename/goose"
) )
//goose:use provide func inject(context stdcontext.Context, err struct{}) (context, error) {
panic(goose.Use(provide))
func inject(context stdcontext.Context, err struct{}) (context, error) }

View File

@@ -6,8 +6,6 @@ func main() {
fmt.Println(injectedMessage()) fmt.Println(injectedMessage())
} }
//goose:provide
// provideMessage provides a friendly user greeting. // provideMessage provides a friendly user greeting.
func provideMessage() string { func provideMessage() string {
return "Hello, World!" return "Hello, World!"

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use provideMessage import (
"codename/goose"
)
func injectedMessage() string func injectedMessage() string {
panic(goose.Use(provideMessage))
}

View File

@@ -16,7 +16,6 @@ func (b Bar) Foo() string {
return string(b) return string(b)
} }
//goose:provide
func provideBar() Bar { func provideBar() Bar {
return "Hello, World!" return "Hello, World!"
} }

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use provideBar import (
"codename/goose"
)
func injectFooer() Fooer func injectFooer() Fooer {
panic(goose.Use(provideBar))
}

View File

@@ -17,8 +17,6 @@ func main() {
fmt.Println(c) fmt.Println(c)
} }
//goose:provide
func provide(ctx stdcontext.Context) (context, error) { func provide(ctx stdcontext.Context) (context, error) {
return context{}, nil return context{}, nil
} }

View File

@@ -4,11 +4,13 @@ package main
import ( import (
stdcontext "context" stdcontext "context"
"codename/goose"
) )
// The notable characteristic of this test is that there are no // The notable characteristic of this test is that there are no
// parameter names on the inject stub. // parameter names on the inject stub.
//goose:use provide func inject(stdcontext.Context, struct{}) (context, error) {
panic(goose.Use(provide))
func inject(stdcontext.Context, struct{}) (context, error) }

View File

@@ -25,14 +25,12 @@ type Foo int
type Bar int type Bar int
type Baz int type Baz int
//goose:provide Foo
func provideFoo() (*Foo, func()) { func provideFoo() (*Foo, func()) {
foo := new(Foo) foo := new(Foo)
*foo = 42 *foo = 42
return foo, func() { *foo = 0; cleanedFoo = true } return foo, func() { *foo = 0; cleanedFoo = true }
} }
//goose:provide Bar
func provideBar(foo *Foo) (*Bar, func(), error) { func provideBar(foo *Foo) (*Bar, func(), error) {
bar := new(Bar) bar := new(Bar)
*bar = 77 *bar = 77
@@ -45,7 +43,6 @@ func provideBar(foo *Foo) (*Bar, func(), error) {
}, nil }, nil
} }
//goose:provide Baz
func provideBaz(bar *Bar) (Baz, error) { func provideBaz(bar *Bar) (Baz, error) {
return 0, errors.New("bork!") return 0, errors.New("bork!")
} }

View File

@@ -2,8 +2,10 @@
package main package main
//goose:use Foo import (
//goose:use Bar "codename/goose"
//goose:use Baz )
func injectBaz() (Baz, func(), error) func injectBaz() (Baz, func(), error) {
panic(goose.Use(provideFoo, provideBar, provideBaz))
}

View File

@@ -2,7 +2,6 @@ package bar
type Bar int type Bar int
//goose:provide Bar
func ProvideBar() Bar { func ProvideBar() Bar {
return 1 return 1
} }

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"bar" "bar"
"codename/goose"
) )
func main() { func main() {
@@ -13,14 +14,15 @@ func main() {
type Foo int type Foo int
type FooBar int type FooBar int
//goose:provide Set var Set = goose.NewSet(
provideFoo,
bar.ProvideBar,
provideFooBar)
func provideFoo() Foo { func provideFoo() Foo {
return 41 return 41
} }
//goose:import Set "bar".Bar
//goose:provide Set
func provideFooBar(foo Foo, barVal bar.Bar) FooBar { func provideFooBar(foo Foo, barVal bar.Bar) FooBar {
return FooBar(foo) + FooBar(barVal) return FooBar(foo) + FooBar(barVal)
} }

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectFooBar() FooBar func injectFooBar() FooBar {
panic(goose.Use(Set))
}

View File

@@ -1,8 +1,12 @@
package main package main
import "errors" import (
import "fmt" "errors"
import "strings" "fmt"
"strings"
"codename/goose"
)
func main() { func main() {
foo, err := injectFoo() foo, err := injectFoo()
@@ -16,7 +20,8 @@ func main() {
type Foo int type Foo int
//goose:provide Set
func provideFoo() (Foo, error) { func provideFoo() (Foo, error) {
return 42, errors.New("there is no Foo") return 42, errors.New("there is no Foo")
} }
var Set = goose.NewSet(provideFoo)

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectFoo() (Foo, error) func injectFoo() (Foo, error) {
panic(goose.Use(Set))
}

View File

@@ -1,6 +1,10 @@
package main package main
import "fmt" import (
"fmt"
"codename/goose"
)
func main() { func main() {
fb := injectFooBar() fb := injectFooBar()
@@ -10,18 +14,20 @@ func main() {
type Foo int type Foo int
type Bar int type Bar int
//goose:provide Set
type FooBar struct { type FooBar struct {
Foo Foo Foo Foo
Bar Bar Bar Bar
} }
//goose:provide Set
func provideFoo() Foo { func provideFoo() Foo {
return 41 return 41
} }
//goose:provide Set
func provideBar() Bar { func provideBar() Bar {
return 1 return 1
} }
var Set = goose.NewSet(
FooBar{},
provideFoo,
provideBar)

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectFooBar() FooBar func injectFooBar() FooBar {
panic(goose.Use(Set))
}

View File

@@ -1,6 +1,10 @@
package main package main
import "fmt" import (
"fmt"
"codename/goose"
)
func main() { func main() {
fb := injectFooBar() fb := injectFooBar()
@@ -10,18 +14,20 @@ func main() {
type Foo int type Foo int
type Bar int type Bar int
//goose:provide Set
type FooBar struct { type FooBar struct {
Foo Foo Foo Foo
Bar Bar Bar Bar
} }
//goose:provide Set
func provideFoo() Foo { func provideFoo() Foo {
return 41 return 41
} }
//goose:provide Set
func provideBar() Bar { func provideBar() Bar {
return 1 return 1
} }
var Set = goose.NewSet(
FooBar{},
provideFoo,
provideBar)

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectFooBar() *FooBar func injectFooBar() *FooBar {
panic(goose.Use(Set))
}

View File

@@ -1,6 +1,10 @@
package main package main
import "fmt" import (
"fmt"
"codename/goose"
)
func main() { func main() {
fmt.Println(injectFooBar()) fmt.Println(injectFooBar())
@@ -10,17 +14,19 @@ type Foo int
type Bar int type Bar int
type FooBar int type FooBar int
//goose:provide Set
func provideFoo() Foo { func provideFoo() Foo {
return 40 return 40
} }
//goose:provide Set
func provideBar() Bar { func provideBar() Bar {
return 2 return 2
} }
//goose:provide Set
func provideFooBar(foo Foo, bar Bar) FooBar { func provideFooBar(foo Foo, bar Bar) FooBar {
return FooBar(foo) + FooBar(bar) return FooBar(foo) + FooBar(bar)
} }
var Set = goose.NewSet(
provideFoo,
provideBar,
provideFooBar)

View File

@@ -2,6 +2,10 @@
package main package main
//goose:use Set import (
"codename/goose"
)
func injectFooBar() FooBar func injectFooBar() FooBar {
panic(goose.Use(Set))
}

View File

@@ -3,9 +3,10 @@
package main package main
import ( import (
_ "bar" "bar"
"codename/goose"
) )
//goose:use "bar".Message func injectedMessage() string {
panic(goose.Use(bar.ProvideMessage))
func injectedMessage() string }

View File

@@ -1,8 +1,6 @@
// Package bar is the vendored copy of bar which contains the real provider. // Package bar is the vendored copy of bar which contains the real provider.
package bar package bar
//goose:provide Message
// ProvideMessage provides a friendly user greeting. // ProvideMessage provides a friendly user greeting.
func ProvideMessage() string { func ProvideMessage() string {
return "Hello, World!" return "Hello, World!"