diff --git a/README.md b/README.md index 54cd280..9d373a8 100644 --- a/README.md +++ b/README.md @@ -13,34 +13,27 @@ might have hand-written. ### Defining Providers The primary mechanism in goose is the **provider**: a function that can -produce a value, annotated with the special `goose:provide` directive. These -functions are otherwise ordinary Go code. +produce a value. These functions are ordinary Go code. ```go package foobarbaz type Foo int -// goose:provide - // ProvideFoo returns a Foo. func ProvideFoo() Foo { return 42 } ``` -Providers are always part of a **provider set**: if there is no provider set -named on the `//goose:provide` line, then the provider is added to the provider -set with the same name as the function (`ProvideFoo`, in this case). - Providers can specify dependencies with parameters: ```go package foobarbaz -type Bar int +// ... -// goose:provide SuperSet +type Bar int // ProvideBar returns a Bar: a negative Foo. func ProvideBar(foo Foo) Bar { @@ -58,9 +51,9 @@ import ( "errors" ) -type Baz int +// ... -// goose:provide SuperSet +type Baz int // ProvideBaz returns a value if Bar is not zero. func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) { @@ -71,23 +64,36 @@ func ProvideBaz(ctx context.Context, bar Bar) (Baz, error) { } ``` -Provider sets can import other provider sets. To add the `ProvideFoo` set to -`SuperSet`: +Providers can be grouped in **provider sets**. To add these providers to a new +set called `SuperSet`, use the `goose.NewSet` function: ```go -// goose:import SuperSet ProvideFoo +package foobarbaz + +import ( + // ... + "codename/goose" +) + +// ... + +var SuperSet = goose.NewSet(ProvideFoo, ProvideBar, ProvideBaz) ``` -You can also import provider sets in another package, provided that you have a -Go import for the package: +You can also add other provider sets into a provider set. ```go -// goose:import SuperSet "example.com/some/other/pkg".OtherSet -``` +package foobarbaz -A provider set reference is an optional import qualifier (either a package name -or a quoted import path, as seen above) ending with a dot, followed by the -provider set name. +import ( + // ... + "example.com/some/other/pkg" +) + +// ... + +var MegaSet = goose.NewSet(SuperSet, pkg.OtherSet) +``` ### Injectors @@ -95,32 +101,34 @@ An application wires up these providers with an **injector**: a function that calls providers in dependency order. With goose, you write the injector's signature, then goose generates the function's body. -An injector is declared by writing a function declaration without a body in a -file guarded by a `gooseinject` build tag. Let's say that the above providers -were defined in a package called `example.com/foobarbaz`. The following would -declare an injector to obtain a `Baz`: +An injector is declared by writing a function declaration whose body is a call +to `panic()` with a call to `goose.Use` as its argument. Let's say that the +above providers were defined in a package called `example.com/foobarbaz`. The +following would declare an injector to obtain a `Baz`: ```go -//+build gooseinject +// +build gooseinject + +// ^ build tag makes sure the stub is not built in the final build package main import ( "context" + "codename/goose" "example.com/foobarbaz" ) -// goose:use foobarbaz.SuperSet - -func initializeApp(ctx context.Context) (foobarbaz.Baz, error) +func initializeApp(ctx context.Context) (foobarbaz.Baz, error) { + panic(goose.Use(foobarbaz.MegaSet)) +} ``` Like providers, injectors can be parameterized on inputs (which then get sent to -providers) and can return errors. Each `goose:use` directive specifies a -provider set to use in the injection. An injector can have one or more -`goose:use` directives. `goose:use` directives use the same syntax as -`goose:import` to reference provider sets. +providers) and can return errors. Arguments to `goose.Use` are the same as +`goose.NewSet`: they form a provider set. This is the provider set that gets +used during code generation for that injector. You can generate the injector by invoking goose in the package directory: @@ -164,7 +172,7 @@ func initializeApp(ctx context.Context) (foobarbaz.Baz, error) { ``` As you can see, the output is very close to what a developer would write -themselves. Further, there is no dependency on goose at runtime: all of the +themselves. Further, there is little dependency on goose at runtime: all of the written code is just normal Go code, and can be used without goose. [`go generate`]: https://blog.golang.org/generate @@ -228,19 +236,21 @@ func (b *Bar) Foo() string { return string(*b) } -//goose:provide BarFooer -func provideBar() *Bar { +func ProvideBar() *Bar { b := new(Bar) *b = "Hello, World!" return b } -//goose:bind BarFooer Fooer *Bar +var BarFooer = goose.NewSet( + ProvideBar, + goose.Bind(Fooer(nil), (*Bar)(nil))) ``` -The syntax is provider set name, interface type, and finally the concrete type. -An interface binding does not necessarily need to have a provider in the same -set that provides the concrete type. +The first argument to `goose.Bind` is a nil value for the interface type and the +second argument is a zero value of the concrete type. An interface binding does +not necessarily need to have a provider in the same set that provides the +concrete type. [type identity]: https://golang.org/ref/spec#Type_identity [return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces @@ -256,32 +266,31 @@ following providers: type Foo int type Bar int -//goose:provide Foo - -func provideFoo() Foo { +func ProvideFoo() Foo { // ... } -//goose:provide Bar - -func provideBar() Bar { +func ProvideBar() Bar { // ... } -//goose:provide - type FooBar struct { Foo Foo Bar Bar } + +var Set = goose.NewSet( + ProvideFoo, + ProvideBar, + FooBar{}) ``` A generated injector for `FooBar` would look like this: ```go func injectFooBar() FooBar { - foo := provideFoo() - bar := provideBar() + foo := ProvideFoo() + bar := ProvideBar() fooBar := FooBar{ Foo: foo, Bar: bar, @@ -300,8 +309,6 @@ this to either return an aggregated cleanup function to the caller or to clean up the resource if a later provider returns an error. ```go -//goose:provide - func provideFile(log Logger, path Path) (*os.File, func(), error) { f, err := os.Open(string(path)) if err != nil { diff --git a/main.go b/cmd/goose/main.go similarity index 87% rename from main.go rename to cmd/goose/main.go index 07a4cce..66b6538 100644 --- a/main.go +++ b/cmd/goose/main.go @@ -13,6 +13,7 @@ import ( "path/filepath" "reflect" "sort" + "strconv" "strings" "codename/goose/internal/goose" @@ -71,9 +72,9 @@ func generate(pkg string) error { // show runs the show subcommand. // -// Given one or more packages, show will find all the declared provider -// sets and print what other provider sets it imports and what outputs -// it can produce, given possible inputs. +// Given one or more packages, show will find all the provider sets +// declared as top-level variables and print what other provider sets it +// imports and what outputs it can produce, given possible inputs. func show(pkgs ...string) error { wd, err := os.Getwd() if err != nil { @@ -89,11 +90,12 @@ func show(pkgs ...string) error { } sort.Slice(keys, func(i, j int) bool { if keys[i].ImportPath == keys[j].ImportPath { - return keys[i].Name < keys[j].Name + return keys[i].VarName < keys[j].VarName } return keys[i].ImportPath < keys[j].ImportPath }) // ANSI color codes. + // TODO(light): Possibly use github.com/fatih/color? const ( reset = "\x1b[0m" redBold = "\x1b[0;1;31m" @@ -116,7 +118,7 @@ func show(pkgs ...string) error { switch v := v.(type) { case *goose.Provider: out[types.TypeString(t, nil)] = v.Pos - case goose.IfaceBinding: + case *goose.IfaceBinding: out[types.TypeString(t, nil)] = v.Pos default: panic("unreachable") @@ -134,19 +136,19 @@ func show(pkgs ...string) error { type outGroup struct { name string inputs *typeutil.Map // values are not important - outputs *typeutil.Map // values are either *goose.Provider or goose.IfaceBinding + outputs *typeutil.Map // values are either *goose.Provider or *goose.IfaceBinding } // gather flattens a provider set into outputs grouped by the inputs // required to create them. As it flattens the provider set, it records -// the visited provider sets as imports. +// the visited named provider sets as imports. func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) { hash := typeutil.MakeHasher() // Map types to providers and bindings. pm := new(typeutil.Map) pm.SetHasher(hash) - next := []goose.ProviderSetID{key} - visited := make(map[goose.ProviderSetID]struct{}) + next := []*goose.ProviderSet{info.Sets[key]} + visited := make(map[*goose.ProviderSet]struct{}) imports = make(map[string]struct{}) for len(next) > 0 { curr := next[len(next)-1] @@ -155,18 +157,17 @@ func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports ma continue } visited[curr] = struct{}{} - if curr != key { - imports[curr.String()] = struct{}{} + if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) { + imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{} } - set := info.All[curr] - for _, p := range set.Providers { + for _, p := range curr.Providers { pm.Set(p.Out, p) } - for _, b := range set.Bindings { + for _, b := range curr.Bindings { pm.Set(b.Iface, b) } - for _, imp := range set.Imports { - next = append(next, imp.ProviderSetID) + for _, imp := range curr.Imports { + next = append(next, imp) } } @@ -238,7 +239,7 @@ func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports ma inputs: in, outputs: out, }) - case goose.IfaceBinding: + case *goose.IfaceBinding: i, ok := inputVisited.At(p.Provided).(int) if !ok { stk = append(stk, curr, p.Provided) @@ -327,3 +328,8 @@ func sortSet(set interface{}) []string { sort.Strings(a) return a } + +func formatProviderSetName(importPath, varName string) string { + // Since varName is an identifier, it doesn't make sense to quote. + return strconv.Quote(importPath) + "." + varName +} diff --git a/goose.go b/goose.go new file mode 100644 index 0000000..e8fc996 --- /dev/null +++ b/goose.go @@ -0,0 +1,38 @@ +// Package goose contains directives for goose code generation. +package goose + +// ProviderSet is a marker type that collects a group of providers. +type ProviderSet struct{} + +// NewSet creates a new provider set that includes the providers in +// its arguments. Each argument is either an exported function value, +// an exported struct (zero) value, or a call to Bind. +func NewSet(...interface{}) ProviderSet { + return ProviderSet{} +} + +// Use is placed in the body of an injector function to declare the +// providers to use. Its arguments are the same as NewSet. Its return +// value is an error message that can be sent to panic. +// +// Example: +// +// func injector(ctx context.Context) (*sql.DB, error) { +// panic(Use(otherpkg.Foo, myProviderFunc, goose.Bind())) +// } +func Use(...interface{}) string { + return "implementation not generated, run goose" +} + +// A Binding maps an interface to a concrete type. +type Binding struct{} + +// Bind declares that a concrete type should be used to satisfy a +// dependency on iface. +// +// Example: +// +// var MySet = goose.NewSet(goose.Bind(MyInterface(nil), new(MyStruct))) +func Bind(iface, to interface{}) Binding { + return Binding{} +} diff --git a/internal/goose/analyze.go b/internal/goose/analyze.go index d4399b7..d72a74e 100644 --- a/internal/goose/analyze.go +++ b/internal/goose/analyze.go @@ -42,7 +42,7 @@ type call struct { // solve finds the sequence of calls required to produce an output type // with an optional set of provided inputs. -func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symref) ([]call, error) { +func solve(fset *token.FileSet, out types.Type, given []types.Type, set *ProviderSet) ([]call, error) { for i, g := range given { for _, h := range given[:i] { if types.Identical(g, h) { @@ -50,7 +50,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr } } } - providers, err := buildProviderMap(mc, sets) + providers, err := buildProviderMap(fset, set) if err != nil { return nil, err } @@ -61,7 +61,7 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr for i, g := range given { if p := providers.At(g); p != nil { pp := p.(*Provider) - return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, mc.fset.Position(pp.Pos)) + return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, fset.Position(pp.Pos)) } index.Set(g, i) } @@ -135,88 +135,70 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr return calls, nil } -func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error) { - type nextEnt struct { - to symref - - from symref - pos token.Pos - } +func buildProviderMap(fset *token.FileSet, set *ProviderSet) (*typeutil.Map, error) { type binding struct { - IfaceBinding - pset symref - from symref + *IfaceBinding + set *ProviderSet } - pm := new(typeutil.Map) // to *providerInfo + providerMap := new(typeutil.Map) // to *Provider + setMap := new(typeutil.Map) // to *ProviderSet, for error messages var bindings []binding - visited := make(map[symref]struct{}) - var next []nextEnt - for _, ref := range sets { - next = append(next, nextEnt{to: ref}) - } + visited := make(map[*ProviderSet]struct{}) + next := []*ProviderSet{set} for len(next) > 0 { curr := next[0] copy(next, next[1:]) next = next[:len(next)-1] - if _, skip := visited[curr.to]; skip { + if _, skip := visited[curr]; skip { continue } - visited[curr.to] = struct{}{} - pset, err := mc.get(curr.to) - if err != nil { - if !curr.pos.IsValid() { - return nil, err + visited[curr] = struct{}{} + for _, p := range curr.Providers { + if providerMap.At(p.Out) != nil { + return nil, bindingConflictError(fset, p.Pos, p.Out, setMap.At(p.Out).(*ProviderSet)) } - return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err) + providerMap.Set(p.Out, p) + setMap.Set(p.Out, curr) } - for _, p := range pset.Providers { - if prev := pm.At(p.Out); prev != nil { - pos := mc.fset.Position(p.Pos) - typ := types.TypeString(p.Out, nil) - prevPos := mc.fset.Position(prev.(*Provider).Pos) - if curr.from.importPath == "" { - // Provider set is imported directly by injector. - return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) - } - return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos) - } - pm.Set(p.Out, p) - } - for _, b := range pset.Bindings { + for _, b := range curr.Bindings { bindings = append(bindings, binding{ IfaceBinding: b, - pset: curr.to, - from: curr.from, + set: curr, }) } - for _, imp := range pset.Imports { - next = append(next, nextEnt{to: imp.symref(), from: curr.to, pos: imp.Pos}) + for _, imp := range curr.Imports { + next = append(next, imp) } } + // Validate that bindings have their concrete type provided in the set. + // TODO(light): Move this validation up into provider set creation. for _, b := range bindings { - if prev := pm.At(b.Iface); prev != nil { - pos := mc.fset.Position(b.Pos) - typ := types.TypeString(b.Iface, nil) - // TODO(light): Error message for conflicting with another interface binding will point at provider instead of binding. - prevPos := mc.fset.Position(prev.(*Provider).Pos) - if b.from.importPath == "" { - // Provider set is imported directly by injector. - return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) - } - return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos) + if providerMap.At(b.Iface) != nil { + return nil, bindingConflictError(fset, b.Pos, b.Iface, setMap.At(b.Iface).(*ProviderSet)) } - concrete := pm.At(b.Provided) + concrete := providerMap.At(b.Provided) if concrete == nil { - pos := mc.fset.Position(b.Pos) + pos := fset.Position(b.Pos) typ := types.TypeString(b.Provided, nil) - if b.from.importPath == "" { - // Concrete provider is imported directly by injector. - return nil, fmt.Errorf("%v: no binding for %s", pos, typ) - } - return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from) + return nil, fmt.Errorf("%v: no binding for %s", pos, typ) } - pm.Set(b.Iface, concrete) + providerMap.Set(b.Iface, concrete) + setMap.Set(b.Iface, b.set) } - return pm, nil + return providerMap, nil +} + +// bindingConflictError creates a new error describing multiple bindings +// for the same output type. +func bindingConflictError(fset *token.FileSet, pos token.Pos, typ types.Type, prevSet *ProviderSet) error { + position := fset.Position(pos) + typString := types.TypeString(typ, nil) + if prevSet.Name == "" { + prevPosition := fset.Position(prevSet.Pos) + return fmt.Errorf("%v: multiple bindings for %s (previous binding at %v)", + position, typString, prevPosition) + } + return fmt.Errorf("%v: multiple bindings for %s (previous binding in %q.%s)", + position, typString, prevSet.PkgPath, prevSet.Name) } diff --git a/internal/goose/goose.go b/internal/goose/goose.go index 31743fe..0cafb0d 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -8,7 +8,7 @@ import ( "go/ast" "go/build" "go/format" - "go/parser" + "go/token" "go/types" "sort" "strconv" @@ -22,8 +22,24 @@ import ( // Generate performs dependency injection for a single package, // returning the gofmt'd Go source code. func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { - conf := newLoaderConfig(bctx, wd, true) + mainPkg, err := bctx.Import(pkg, wd, build.FindOnly) + if err != nil { + return nil, fmt.Errorf("load: %v", err) + } + // TODO(light): Stop errors from printing to stderr. + conf := &loader.Config{ + Build: new(build.Context), + Cwd: wd, + TypeCheckFuncBodies: func(path string) bool { + return path == mainPkg.ImportPath + }, + } + *conf.Build = *bctx + n := len(conf.Build.BuildTags) + // TODO(light): Only apply gooseinject build tag on main package. + conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject") conf.Import(pkg) + prog, err := conf.Load() if err != nil { return nil, fmt.Errorf("load: %v", err) @@ -34,47 +50,23 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { } pkgInfo := prog.InitialPackages()[0] g := newGen(prog, pkgInfo.Pkg.Path()) - r := newImportResolver(conf, prog.Fset) - mc := newProviderSetCache(prog, r) + oc := newObjectCache(prog) for _, f := range pkgInfo.Files { - if !isInjectFile(f) { - continue - } - fileScope := pkgInfo.Scopes[f] - groups := parseFile(prog.Fset, f) for _, decl := range f.Decls { fn, ok := decl.(*ast.FuncDecl) if !ok { continue } - var dg directiveGroup - for _, dg = range groups { - if dg.decl == decl { - break - } + useCall := isInjector(&pkgInfo.Info, fn) + if useCall == nil { + continue } - if dg.decl != decl { - dg = directiveGroup{} - } - var sets []symref - for _, d := range dg.dirs { - if d.kind != "use" { - return nil, fmt.Errorf("%v: cannot use %s directive on inject function", prog.Fset.Position(d.pos), d.kind) - } - args := d.args() - if len(args) == 0 { - return nil, fmt.Errorf("%v: goose:use must have at least one provider set reference", prog.Fset.Position(d.pos)) - } - for _, arg := range args { - ref, err := parseSymbolRef(r, arg, fileScope, g.currPackage, d.pos) - if err != nil { - return nil, fmt.Errorf("%v: %v", prog.Fset.Position(d.pos), err) - } - sets = append(sets, ref) - } + set, err := oc.processNewSet(pkgInfo, useCall) + if err != nil { + return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err) } sig := pkgInfo.ObjectOf(fn.Name).Type().(*types.Signature) - if err := g.inject(mc, fn.Name.Name, sig, sets); err != nil { + if err := g.inject(prog.Fset, fn.Name.Name, sig, set); err != nil { return nil, fmt.Errorf("%v: %v", prog.Fset.Position(fn.Pos()), err) } } @@ -89,23 +81,6 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { return fmtSrc, nil } -func newLoaderConfig(bctx *build.Context, wd string, inject bool) *loader.Config { - // TODO(light): Stop errors from printing to stderr. - conf := &loader.Config{ - Build: bctx, - ParserMode: parser.ParseComments, - Cwd: wd, - TypeCheckFuncBodies: func(string) bool { return false }, - } - if inject { - conf.Build = new(build.Context) - *conf.Build = *bctx - n := len(conf.Build.BuildTags) - conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject") - } - return conf -} - // gen is the generator state. type gen struct { currPackage string @@ -150,7 +125,7 @@ func (g *gen) frame() []byte { } // inject emits the code for an injector. -func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, sets []symref) error { +func (g *gen) inject(fset *token.FileSet, name string, sig *types.Signature, set *ProviderSet) error { results := sig.Results() var returnsCleanup, returnsErr bool switch results.Len() { @@ -184,7 +159,7 @@ func (g *gen) inject(mc *providerSetCache, name string, sig *types.Signature, se for i := 0; i < params.Len(); i++ { given[i] = params.At(i).Type() } - calls, err := solve(mc, outType, given, sets) + calls, err := solve(fset, outType, given, set) if err != nil { return err } diff --git a/internal/goose/goose_test.go b/internal/goose/goose_test.go index e745a3f..aac6d29 100644 --- a/internal/goose/goose_test.go +++ b/internal/goose/goose_test.go @@ -29,13 +29,19 @@ func TestGoose(t *testing.T) { if err != nil { t.Fatal(err) } + // The marker function package source is needed to have the test cases + // type check. loadTestCase places this file at the well-known import path. + gooseGo, err := ioutil.ReadFile(filepath.Join("..", "..", "goose.go")) + if err != nil { + t.Fatal(err) + } tests := make([]*testCase, 0, len(testdataEnts)) for _, ent := range testdataEnts { name := ent.Name() if !ent.IsDir() || strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") { continue } - test, err := loadTestCase(filepath.Join(testRoot, name)) + test, err := loadTestCase(filepath.Join(testRoot, name), gooseGo) if err != nil { t.Error(err) } @@ -227,7 +233,7 @@ type testCase struct { // out.txt file containing the expected output, or the magic string "ERROR" // if this test should cause generation to fail // ... any Go files found recursively placed under GOPATH/src/... -func loadTestCase(root string) (*testCase, error) { +func loadTestCase(root string, gooseGoSrc []byte) (*testCase, error) { name := filepath.Base(root) pkg, err := ioutil.ReadFile(filepath.Join(root, "pkg")) if err != nil { @@ -242,7 +248,9 @@ func loadTestCase(root string) (*testCase, error) { wantError = true out = nil } - goFiles := make(map[string][]byte) + goFiles := map[string][]byte{ + "codename/goose/goose.go": gooseGoSrc, + } err = filepath.Walk(root, func(src string, info os.FileInfo, err error) error { if err != nil { return err diff --git a/internal/goose/parse.go b/internal/goose/parse.go index 3f03180..efa1e6e 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -6,20 +6,28 @@ import ( "go/build" "go/token" "go/types" - "path/filepath" "strconv" "strings" - "unicode" + "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/loader" ) // A ProviderSet describes a set of providers. The zero value is an empty // ProviderSet. type ProviderSet struct { + // Pos is the position of the call to goose.NewSet or goose.Use that + // created the set. + Pos token.Pos + // PkgPath is the import path of the package that declared this set. + PkgPath string + // Name is the variable name of the set, if it came from a package + // variable. + Name string + Providers []*Provider - Bindings []IfaceBinding - Imports []ProviderSetImport + Bindings []*IfaceBinding + Imports []*ProviderSet } // An IfaceBinding declares that a type should be used to satisfy inputs @@ -35,12 +43,6 @@ type IfaceBinding struct { Pos token.Pos } -// A ProviderSetImport adds providers from one provider set into another. -type ProviderSetImport struct { - ProviderSetID - Pos token.Pos -} - // Provider records the signature of a provider. A provider is a // single Go object, either a function or a named struct type. type Provider struct { @@ -87,7 +89,12 @@ type ProviderInput struct { // Load finds all the provider sets in the given packages, as well as // the provider sets' transitive dependencies. func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) { - conf := newLoaderConfig(bctx, wd, false) + // TODO(light): Stop errors from printing to stderr. + conf := &loader.Config{ + Build: bctx, + Cwd: wd, + TypeCheckFuncBodies: func(string) bool { return false }, + } for _, p := range pkgs { conf.Import(p) } @@ -95,48 +102,26 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) { if err != nil { return nil, fmt.Errorf("load: %v", err) } - r := newImportResolver(conf, prog.Fset) - var next []string - initial := make(map[string]struct{}) - for _, pkgInfo := range prog.InitialPackages() { - path := pkgInfo.Pkg.Path() - next = append(next, path) - initial[path] = struct{}{} - } - visited := make(map[string]struct{}) info := &Info{ Fset: prog.Fset, Sets: make(map[ProviderSetID]*ProviderSet), - All: make(map[ProviderSetID]*ProviderSet), } - for len(next) > 0 { - curr := next[len(next)-1] - next = next[:len(next)-1] - if _, ok := visited[curr]; ok { - continue - } - visited[curr] = struct{}{} - pkgInfo := prog.Package(curr) - sets, err := findProviderSets(findContext{ - fset: prog.Fset, - pkg: pkgInfo.Pkg, - typeInfo: &pkgInfo.Info, - r: r, - }, pkgInfo.Files) - if err != nil { - return nil, fmt.Errorf("load: %v", err) - } - path := pkgInfo.Pkg.Path() - for name, set := range sets { - info.All[ProviderSetID{path, name}] = set - for _, imp := range set.Imports { - next = append(next, imp.ImportPath) + oc := newObjectCache(prog) + for _, pkgInfo := range prog.InitialPackages() { + scope := pkgInfo.Pkg.Scope() + for _, name := range scope.Names() { + item, err := oc.get(scope.Lookup(name)) + if err != nil { + continue } - } - if _, ok := initial[path]; ok { - for name, set := range sets { - info.Sets[ProviderSetID{path, name}] = set + pset, ok := item.(*ProviderSet) + if !ok { + continue } + // pset.Name may not equal name, since it could be an alias to + // another provider set. + id := ProviderSetID{ImportPath: pset.PkgPath, VarName: name} + info.Sets[id] = pset } } return info, nil @@ -148,257 +133,217 @@ type Info struct { // Sets contains all the provider sets in the initial packages. Sets map[ProviderSetID]*ProviderSet - - // All contains all the provider sets transitively depended on by the - // initial packages' provider sets. - All map[ProviderSetID]*ProviderSet } -// A ProviderSetID identifies a provider set. +// A ProviderSetID identifies a named provider set. type ProviderSetID struct { ImportPath string - Name string + VarName string } // String returns the ID as ""path/to/pkg".Foo". func (id ProviderSetID) String() string { - return id.symref().String() + return strconv.Quote(id.ImportPath) + "." + id.VarName } -func (id ProviderSetID) symref() symref { - return symref{importPath: id.ImportPath, name: id.Name} +// objectCache is a lazily evaluated mapping of objects to goose structures. +type objectCache struct { + prog *loader.Program + objects map[objRef]interface{} // *Provider or *ProviderSet } -type findContext struct { - fset *token.FileSet - pkg *types.Package - typeInfo *types.Info - r *importResolver +type objRef struct { + importPath string + name string } -// findProviderSets processes a package and extracts the provider sets declared in it. -func findProviderSets(fctx findContext, files []*ast.File) (map[string]*ProviderSet, error) { - sets := make(map[string]*ProviderSet) - for _, f := range files { - fileScope := fctx.typeInfo.Scopes[f] - if fileScope == nil { - return nil, fmt.Errorf("%s: no scope found for file (likely a bug)", fctx.fset.File(f.Pos()).Name()) - } - for _, dg := range parseFile(fctx.fset, f) { - if dg.decl != nil { - if err := processDeclDirectives(fctx, sets, fileScope, dg); err != nil { - return nil, err - } - } else { - for _, d := range dg.dirs { - if err := processUnassociatedDirective(fctx, sets, fileScope, d); err != nil { - return nil, err - } - } - } - } +func newObjectCache(prog *loader.Program) *objectCache { + return &objectCache{ + prog: prog, + objects: make(map[objRef]interface{}), } - return sets, nil } -// processUnassociatedDirective handles any directive that was not associated with a top-level declaration. -func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error { - switch d.kind { - case "provide": - return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) - case "use": - // Ignore, picked up by injector flow. - case "bind": - args := d.args() - if len(args) != 3 { - return fmt.Errorf("%v: invalid binding: expected TARGET IFACE TYPE", fctx.fset.Position(d.pos)) +// get converts a Go object into a goose structure. It may return a +// *Provider, a structProviderPair, an *IfaceBinding, or a *ProviderSet. +func (oc *objectCache) get(obj types.Object) (interface{}, error) { + ref := objRef{ + importPath: obj.Pkg().Path(), + name: obj.Name(), + } + if val, cached := oc.objects[ref]; cached { + if val == nil { + return nil, fmt.Errorf("%v is not a provider or a provider set", obj) } - ifaceRef, err := parseSymbolRef(fctx.r, args[1], scope, fctx.pkg.Path(), d.pos) - if err != nil { - return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) + return val, nil + } + switch obj := obj.(type) { + case *types.Var: + spec := oc.varDecl(obj) + if len(spec.Values) == 0 { + return nil, fmt.Errorf("%v is not a provider or a provider set", obj) } - ifaceObj, err := ifaceRef.resolveObject(fctx.pkg) - if err != nil { - return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) - } - ifaceDecl, ok := ifaceObj.(*types.TypeName) - if !ok { - return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), ifaceRef) - } - iface := ifaceDecl.Type() - methodSet, ok := iface.Underlying().(*types.Interface) - if !ok { - return fmt.Errorf("%v: %v does not name an interface type", fctx.fset.Position(d.pos), ifaceRef) - } - - providedRef, err := parseSymbolRef(fctx.r, strings.TrimPrefix(args[2], "*"), scope, fctx.pkg.Path(), d.pos) - if err != nil { - return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) - } - providedObj, err := providedRef.resolveObject(fctx.pkg) - if err != nil { - return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) - } - providedDecl, ok := providedObj.(*types.TypeName) - if !ok { - return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), providedRef) - } - provided := providedDecl.Type() - if types.Identical(provided, iface) { - return fmt.Errorf("%v: cannot bind interface to itself", fctx.fset.Position(d.pos)) - } - if strings.HasPrefix(args[2], "*") { - provided = types.NewPointer(provided) - } - if !types.Implements(provided, methodSet) { - return fmt.Errorf("%v: %s does not implement %s", fctx.fset.Position(d.pos), types.TypeString(provided, nil), types.TypeString(iface, nil)) - } - - name := args[0] - if pset := sets[name]; pset != nil { - pset.Bindings = append(pset.Bindings, IfaceBinding{ - Iface: iface, - Provided: provided, - }) - } else { - sets[name] = &ProviderSet{ - Bindings: []IfaceBinding{{ - Iface: iface, - Provided: provided, - }}, + var i int + for i = range spec.Names { + if spec.Names[i].Name == obj.Name() { + break } } - case "import": - args := d.args() - if len(args) < 2 { - return fmt.Errorf("%v: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos)) - } - name := args[0] - for _, spec := range args[1:] { - ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos) - if err != nil { - return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err) - } - if findImport(fctx.pkg, ref.importPath) == nil { - return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath) - } - if mod := sets[name]; mod != nil { - found := false - for _, other := range mod.Imports { - if ref == other.symref() { - found = true - break - } - } - if !found { - mod.Imports = append(mod.Imports, ProviderSetImport{ - ProviderSetID: ProviderSetID{ - ImportPath: ref.importPath, - Name: ref.name, - }, - Pos: d.pos, - }) - } - } else { - sets[name] = &ProviderSet{ - Imports: []ProviderSetImport{{ - ProviderSetID: ProviderSetID{ - ImportPath: ref.importPath, - Name: ref.name, - }, - Pos: d.pos, - }}, - } - } + return oc.processExpr(oc.prog.Package(obj.Pkg().Path()), spec.Values[i]) + case *types.Func: + p, err := processFuncProvider(oc.prog.Fset, obj) + if err != nil { + oc.objects[ref] = nil + return nil, err } + oc.objects[ref] = p + return p, nil default: - return fmt.Errorf("%v: unknown directive %s", fctx.fset.Position(d.pos), d.kind) + oc.objects[ref] = nil + return nil, fmt.Errorf("%v is not a provider or a provider set", obj) + } +} + +// varDecl finds the declaration that defines the given variable. +func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { + // TODO(light): Walk files to build object -> declaration mapping, if more performant. + // Recommended by https://golang.org/s/types-tutorial + pkg := oc.prog.Package(obj.Pkg().Path()) + pos := obj.Pos() + for _, f := range pkg.Files { + tokenFile := oc.prog.Fset.File(f.Pos()) + if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { + path, _ := astutil.PathEnclosingInterval(f, pos, pos) + for _, node := range path { + if spec, ok := node.(*ast.ValueSpec); ok { + return spec + } + } + } } return nil } -// processDeclDirectives processes the directives associated with a top-level declaration. -func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, dg directiveGroup) error { - p, err := dg.single(fctx.fset, "provide") - if err != nil { - return err +// processExpr converts an expression into a goose structure. It may +// return a *Provider, a structProviderPair, an *IfaceBinding, or a +// *ProviderSet. +func (oc *objectCache) processExpr(pkg *loader.PackageInfo, expr ast.Expr) (interface{}, error) { + exprPos := oc.prog.Fset.Position(expr.Pos()) + expr = astutil.Unparen(expr) + if obj := qualifiedIdentObject(&pkg.Info, expr); obj != nil { + item, err := oc.get(obj) + if err != nil { + return nil, fmt.Errorf("%v: %v", exprPos, err) + } + return item, nil } - if !p.isValid() { + if call, ok := expr.(*ast.CallExpr); ok { + fnObj := qualifiedIdentObject(&pkg.Info, call.Fun) + if fnObj == nil || !isGooseImport(fnObj.Pkg().Path()) { + return nil, fmt.Errorf("%v: unknown pattern", exprPos) + } + switch fnObj.Name() { + case "NewSet": + pset, err := oc.processNewSet(pkg, call) + if err != nil { + return nil, fmt.Errorf("%v: %v", exprPos, err) + } + return pset, nil + case "Bind": + b, err := processBind(oc.prog.Fset, &pkg.Info, call) + if err != nil { + return nil, fmt.Errorf("%v: %v", exprPos, err) + } + return b, nil + default: + return nil, fmt.Errorf("%v: unknown pattern", exprPos) + } + } + if tn := structArgType(&pkg.Info, expr); tn != nil { + p, err := processStructProvider(oc.prog.Fset, tn) + if err != nil { + return nil, fmt.Errorf("%v: %v", exprPos, err) + } + ptrp := new(Provider) + *ptrp = *p + ptrp.Out = types.NewPointer(p.Out) + return structProviderPair{p, ptrp}, nil + } + return nil, fmt.Errorf("%v: unknown pattern", exprPos) +} + +type structProviderPair struct { + provider *Provider + ptrProvider *Provider +} + +func (oc *objectCache) processNewSet(pkg *loader.PackageInfo, call *ast.CallExpr) (*ProviderSet, error) { + // Assumes that call.Fun is goose.NewSet or goose.Use. + + pset := &ProviderSet{ + Pos: call.Pos(), + PkgPath: pkg.Pkg.Path(), + } + for _, arg := range call.Args { + item, err := oc.processExpr(pkg, arg) + if err != nil { + return nil, err + } + switch item := item.(type) { + case *Provider: + pset.Providers = append(pset.Providers, item) + case *ProviderSet: + pset.Imports = append(pset.Imports, item) + case *IfaceBinding: + pset.Bindings = append(pset.Bindings, item) + case structProviderPair: + pset.Providers = append(pset.Providers, item.provider, item.ptrProvider) + default: + panic("unknown item type") + } + } + return pset, nil +} + +// structArgType attempts to interpret an expression as a simple struct type. +// It assumes any parentheses have been stripped. +func structArgType(info *types.Info, expr ast.Expr) *types.TypeName { + lit, ok := expr.(*ast.CompositeLit) + if !ok { return nil } - var providerSetName string - if args := p.args(); len(args) == 1 { - // TODO(light): Validate identifier. - providerSetName = args[0] - } else if len(args) > 1 { - return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(p.pos)) + tn, ok := qualifiedIdentObject(info, lit.Type).(*types.TypeName) + if !ok { + return nil } - switch decl := dg.decl.(type) { - case *ast.FuncDecl: - fn := fctx.typeInfo.ObjectOf(decl.Name).(*types.Func) - provider, err := processFuncProvider(fctx, fn) - if err != nil { - return err - } - if providerSetName == "" { - providerSetName = fn.Name() - } - if mod := sets[providerSetName]; mod != nil { - for _, other := range mod.Providers { - if types.Identical(other.Out, provider.Out) { - return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos)) - } - } - mod.Providers = append(mod.Providers, provider) - } else { - sets[providerSetName] = &ProviderSet{ - Providers: []*Provider{provider}, - } - } - case *ast.GenDecl: - if decl.Tok != token.TYPE { - return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) - } - if len(decl.Specs) != 1 { - // TODO(light): Tighten directive extraction to associate with particular specs. - return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) - } - typeName := fctx.typeInfo.ObjectOf(decl.Specs[0].(*ast.TypeSpec).Name).(*types.TypeName) - if _, ok := typeName.Type().(*types.Named).Underlying().(*types.Struct); !ok { - return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) - } - provider, err := processStructProvider(fctx, typeName) - if err != nil { - return err - } - if providerSetName == "" { - providerSetName = typeName.Name() - } - ptrProvider := new(Provider) - *ptrProvider = *provider - ptrProvider.Out = types.NewPointer(provider.Out) - if mod := sets[providerSetName]; mod != nil { - for _, other := range mod.Providers { - if types.Identical(other.Out, provider.Out) { - return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos)) - } - if types.Identical(other.Out, ptrProvider.Out) { - return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.Out, nil), fctx.fset.Position(other.Pos)) - } - } - mod.Providers = append(mod.Providers, provider, ptrProvider) - } else { - sets[providerSetName] = &ProviderSet{ - Providers: []*Provider{provider, ptrProvider}, - } - } - default: - return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) + if _, isStruct := tn.Type().Underlying().(*types.Struct); !isStruct { + return nil } - return nil + return tn } -func processFuncProvider(fctx findContext, fn *types.Func) (*Provider, error) { +// qualifiedIdentObject finds the object for an identifier or a +// qualified identifier, or nil if the object could not be found. +func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object { + switch expr := expr.(type) { + case *ast.Ident: + return info.ObjectOf(expr) + case *ast.SelectorExpr: + pkgName, ok := expr.X.(*ast.Ident) + if !ok { + return nil + } + if _, ok := info.ObjectOf(pkgName).(*types.PkgName); !ok { + return nil + } + return info.ObjectOf(expr.Sel) + default: + return nil + } +} + +// processFuncProvider creates a provider for a function declaration. +func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, error) { sig := fn.Type().(*types.Signature) fpos := fn.Pos() @@ -414,23 +359,23 @@ func processFuncProvider(fctx findContext, fn *types.Func) (*Provider, error) { case types.Identical(t, cleanupType): hasCleanup, hasErr = true, false default: - return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fctx.fset.Position(fpos), fn.Name()) + return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be error or func()", fset.Position(fpos), fn.Name()) } case 3: if t := r.At(1).Type(); !types.Identical(t, cleanupType) { - return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fctx.fset.Position(fpos), fn.Name()) + return nil, fmt.Errorf("%v: wrong signature for provider %s: second return type must be func()", fset.Position(fpos), fn.Name()) } if t := r.At(2).Type(); !types.Identical(t, errorType) { - return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fctx.fset.Position(fpos), fn.Name()) + return nil, fmt.Errorf("%v: wrong signature for provider %s: third return type must be error", fset.Position(fpos), fn.Name()) } hasCleanup, hasErr = true, true default: - return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fctx.fset.Position(fpos), fn.Name()) + return nil, fmt.Errorf("%v: wrong signature for provider %s: must have one return value and optional error", fset.Position(fpos), fn.Name()) } out := r.At(0).Type() params := sig.Params() provider := &Provider{ - ImportPath: fctx.pkg.Path(), + ImportPath: fn.Pkg().Path(), Name: fn.Name(), Pos: fn.Pos(), Args: make([]ProviderInput, params.Len()), @@ -444,20 +389,25 @@ func processFuncProvider(fctx findContext, fn *types.Func) (*Provider, error) { } for j := 0; j < i; j++ { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { - return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil)) + return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil)) } } } return provider, nil } -func processStructProvider(fctx findContext, typeName *types.TypeName) (*Provider, error) { +// processStructProvider creates a provider for a named struct type. +// It only produces the non-pointer variant. +func processStructProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, error) { out := typeName.Type() - st := out.Underlying().(*types.Struct) + st, ok := out.Underlying().(*types.Struct) + if !ok { + return nil, fmt.Errorf("%v does not name a struct", typeName) + } pos := typeName.Pos() provider := &Provider{ - ImportPath: fctx.pkg.Path(), + ImportPath: typeName.Pkg().Path(), Name: typeName.Name(), Pos: pos, Args: make([]ProviderInput, st.NumFields()), @@ -473,332 +423,93 @@ func processStructProvider(fctx findContext, typeName *types.TypeName) (*Provide provider.Fields[i] = f.Name() for j := 0; j < i; j++ { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { - return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.Args[j].Type, nil)) + return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fset.Position(pos), types.TypeString(provider.Args[j].Type, nil)) } } } return provider, nil } -// providerSetCache is a lazily evaluated index of provider sets. -type providerSetCache struct { - sets map[string]map[string]*ProviderSet - fset *token.FileSet - prog *loader.Program - r *importResolver -} +// processBind creates an interface binding from a goose.Bind call. +func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) { + // Assumes that call.Fun is goose.Bind. -func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCache { - return &providerSetCache{ - fset: prog.Fset, - prog: prog, - r: r, + if len(call.Args) != 2 { + return nil, fmt.Errorf("%v: call to Bind takes exactly two arguments", fset.Position(call.Pos())) } -} - -func (mc *providerSetCache) get(ref symref) (*ProviderSet, error) { - if mods, cached := mc.sets[ref.importPath]; cached { - mod := mods[ref.name] - if mod == nil { - return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath) - } - return mod, nil - } - if mc.sets == nil { - mc.sets = make(map[string]map[string]*ProviderSet) - } - pkg := mc.prog.Package(ref.importPath) - mods, err := findProviderSets(findContext{ - fset: mc.fset, - pkg: pkg.Pkg, - typeInfo: &pkg.Info, - r: mc.r, - }, pkg.Files) - if err != nil { - mc.sets[ref.importPath] = nil - return nil, err - } - mc.sets[ref.importPath] = mods - mod := mods[ref.name] - if mod == nil { - return nil, fmt.Errorf("no such provider set %s in package %q", ref.name, ref.importPath) - } - return mod, nil -} - -// A symref is a parsed reference to a symbol (either a provider set or a Go object). -type symref struct { - importPath string - name string -} - -func parseSymbolRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (symref, error) { - // TODO(light): Verify that provider set name is an identifier before returning. - - i := strings.LastIndexByte(ref, '.') - if i == -1 { - return symref{importPath: pkg, name: ref}, nil - } - imp, name := ref[:i], ref[i+1:] - if strings.HasPrefix(imp, `"`) { - path, err := strconv.Unquote(imp) - if err != nil { - return symref{}, fmt.Errorf("parse symbol reference %q: bad import path", ref) - } - path, err = r.resolve(pos, path) - if err != nil { - return symref{}, fmt.Errorf("parse symbol reference %q: %v", ref, err) - } - return symref{importPath: path, name: name}, nil - } - _, obj := s.LookupParent(imp, pos) - if obj == nil { - return symref{}, fmt.Errorf("parse symbol reference %q: unknown identifier %s", ref, imp) - } - pn, ok := obj.(*types.PkgName) + // TODO(light): Verify that arguments are simple expressions. + iface := info.TypeOf(call.Args[0]) + methodSet, ok := iface.Underlying().(*types.Interface) if !ok { - return symref{}, fmt.Errorf("parse symbol reference %q: %s does not name a package", ref, imp) + return nil, fmt.Errorf("%v: first argument to bind must be of interface type; found %s", fset.Position(call.Pos()), types.TypeString(iface, nil)) } - return symref{importPath: pn.Imported().Path(), name: name}, nil + provided := info.TypeOf(call.Args[1]) + if types.Identical(iface, provided) { + return nil, fmt.Errorf("%v: cannot bind interface to itself", fset.Position(call.Pos())) + } + if !types.Implements(provided, methodSet) { + return nil, fmt.Errorf("%v: %s does not implement %s", fset.Position(call.Pos()), types.TypeString(provided, nil), types.TypeString(iface, nil)) + } + return &IfaceBinding{ + Pos: call.Pos(), + Iface: iface, + Provided: provided, + }, nil } -func (ref symref) String() string { - return strconv.Quote(ref.importPath) + "." + ref.name -} - -func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) { - imp := findImport(pkg, ref.importPath) - if imp == nil { - return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref) +// isInjector checks whether a given function declaration is an +// injector template, returning the goose.Use call. It returns nil if +// the function is not an injector template. +func isInjector(info *types.Info, fn *ast.FuncDecl) *ast.CallExpr { + if fn.Body == nil { + return nil } - obj := imp.Scope().Lookup(ref.name) - if obj == nil { - return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name) - } - return obj, nil -} - -type importResolver struct { - fset *token.FileSet - bctx *build.Context - findPackage func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error) -} - -func newImportResolver(c *loader.Config, fset *token.FileSet) *importResolver { - r := &importResolver{ - fset: fset, - bctx: c.Build, - findPackage: c.FindPackage, - } - if r.bctx == nil { - r.bctx = &build.Default - } - if r.findPackage == nil { - r.findPackage = (*build.Context).Import - } - return r -} - -func (r *importResolver) resolve(pos token.Pos, path string) (string, error) { - dir := filepath.Dir(r.fset.File(pos).Name()) - pkg, err := r.findPackage(r.bctx, path, dir, build.FindOnly) - if err != nil { - return "", err - } - return pkg.ImportPath, nil -} - -func findImport(pkg *types.Package, path string) *types.Package { - if pkg.Path() == path { - return pkg - } - for _, imp := range pkg.Imports() { - if imp.Path() == path { - return imp - } - } - return nil -} - -// A directive is a parsed goose comment. -type directive struct { - pos token.Pos - kind string - line string -} - -// A directiveGroup is a set of directives associated with a particular -// declaration. -type directiveGroup struct { - decl ast.Decl - dirs []directive -} - -// parseFile extracts the directives from a file, grouped by declaration. -func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup { - cmap := ast.NewCommentMap(fset, f, f.Comments) - // Reserve first group for directives that don't associate with a - // declaration, like import. - groups := make([]directiveGroup, 1, len(f.Decls)+1) - // Walk declarations and add to groups. - for _, decl := range f.Decls { - grp := directiveGroup{decl: decl} - ast.Inspect(decl, func(node ast.Node) bool { - if g := cmap[node]; len(g) > 0 { - for _, cg := range g { - start := len(grp.dirs) - grp.dirs = extractDirectives(grp.dirs, cg) - - // Move directives that don't associate into the unassociated group. - n := 0 - for i := start; i < len(grp.dirs); i++ { - if k := grp.dirs[i].kind; k == "provide" || k == "use" { - grp.dirs[start+n] = grp.dirs[i] - n++ - } else { - groups[0].dirs = append(groups[0].dirs, grp.dirs[i]) - } - } - grp.dirs = grp.dirs[:start+n] - } - delete(cmap, node) + var only *ast.ExprStmt + for _, stmt := range fn.Body.List { + switch stmt := stmt.(type) { + case *ast.ExprStmt: + if only != nil { + return nil } - return true - }) - if len(grp.dirs) > 0 { - groups = append(groups, grp) - } - } - // Place remaining directives into the unassociated group. - unassoc := &groups[0] - for _, g := range cmap { - for _, cg := range g { - unassoc.dirs = extractDirectives(unassoc.dirs, cg) - } - } - if len(unassoc.dirs) == 0 { - return groups[1:] - } - return groups -} - -func extractDirectives(d []directive, cg *ast.CommentGroup) []directive { - const prefix = "goose:" - text := cg.Text() - for len(text) > 0 { - text = strings.TrimLeft(text, " \t\r\n") - if !strings.HasPrefix(text, prefix) { - break - } - line := text[len(prefix):] - // Text() is always newline terminated. - i := strings.IndexByte(line, '\n') - line, text = line[:i], line[i+1:] - if i := strings.IndexByte(line, ' '); i != -1 { - d = append(d, directive{ - kind: line[:i], - line: strings.TrimSpace(line[i+1:]), - pos: cg.Pos(), // TODO(light): More precise position. - }) - } else { - d = append(d, directive{ - kind: line, - pos: cg.Pos(), // TODO(light): More precise position. - }) - } - } - return d -} - -// single finds at most one directive that matches the given kind. -func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) { - var found directive - ok := false - for _, d := range dg.dirs { - if d.kind != kind { - continue - } - if ok { - switch decl := dg.decl.(type) { - case *ast.FuncDecl: - return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, decl.Name.Name) - case *ast.GenDecl: - if decl.Tok == token.TYPE && len(decl.Specs) == 1 { - name := decl.Specs[0].(*ast.TypeSpec).Name.Name - return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, name) - } - return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind) - default: - return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind) - } - } - found, ok = d, true - } - return found, nil -} - -func (d directive) isValid() bool { - return d.kind != "" -} - -// args splits the directive line into tokens. -func (d directive) args() []string { - var args []string - start := -1 - state := 0 // 0 = boundary, 1 = in token, 2 = in quote, 3 = quote backslash - for i, r := range d.line { - switch state { - case 0: - // Argument boundary. - switch { - case r == '"': - start = i - state = 2 - case !unicode.IsSpace(r): - start = i - state = 1 - } - case 1: - // In token. - switch { - case unicode.IsSpace(r): - args = append(args, d.line[start:i]) - start = -1 - state = 0 - case r == '"': - state = 2 - } - case 2: - // In quotes. - switch { - case r == '"': - state = 1 - case r == '\\': - state = 3 - } - case 3: - // Quote backslash. Consumes one character and jumps back into "in quote" state. - state = 2 + only = stmt + case *ast.EmptyStmt: + // Do nothing. default: - panic("unreachable") + return nil } } - if start != -1 { - args = append(args, d.line[start:]) + panicCall, ok := only.X.(*ast.CallExpr) + if !ok { + return nil } - return args + panicIdent, ok := panicCall.Fun.(*ast.Ident) + if !ok { + return nil + } + if info.ObjectOf(panicIdent) != types.Universe.Lookup("panic") { + return nil + } + if len(panicCall.Args) != 1 { + return nil + } + useCall, ok := panicCall.Args[0].(*ast.CallExpr) + if !ok { + return nil + } + useObj := qualifiedIdentObject(info, useCall.Fun) + if !isGooseImport(useObj.Pkg().Path()) || useObj.Name() != "Use" { + return nil + } + return useCall } -// isInjectFile reports whether a given file is an injection template. -func isInjectFile(f *ast.File) bool { - // TODO(light): Better determination. - for _, cg := range f.Comments { - text := cg.Text() - if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") { - return true - } +func isGooseImport(path string) bool { + // TODO(light): This is depending on details of the current loader. + const vendorPart = "vendor/" + if i := strings.LastIndex(path, vendorPart); i != -1 && (i == 0 || path[i-1] == '/') { + path = path[i+len(vendorPart):] } - return false + return path == "codename/goose" } // paramIndex returns the index of the parameter with the given name, or diff --git a/internal/goose/parse_test.go b/internal/goose/parse_test.go deleted file mode 100644 index 7ebef72..0000000 --- a/internal/goose/parse_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package goose - -import ( - "testing" -) - -func TestDirectiveArgs(t *testing.T) { - tests := []struct { - line string - args []string - }{ - {"", []string{}}, - {" \t ", []string{}}, - {"foo", []string{"foo"}}, - {"foo bar", []string{"foo", "bar"}}, - {" foo \t bar ", []string{"foo", "bar"}}, - {"foo \"bar \t baz\" fido", []string{"foo", "\"bar \t baz\"", "fido"}}, - {"foo \"bar \t baz\".quux fido", []string{"foo", "\"bar \t baz\".quux", "fido"}}, - } - eq := func(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true - } - for _, test := range tests { - got := (directive{line: test.line}).args() - if !eq(got, test.args) { - t.Errorf("directive{line: %q}.args() = %q; want %q", test.line, got, test.args) - } - } -} diff --git a/internal/goose/testdata/Chain/foo/foo.go b/internal/goose/testdata/Chain/foo/foo.go index 502e200..c11b4bf 100644 --- a/internal/goose/testdata/Chain/foo/foo.go +++ b/internal/goose/testdata/Chain/foo/foo.go @@ -1,6 +1,10 @@ package main -import "fmt" +import ( + "fmt" + + "codename/goose" +) func main() { fmt.Println(injectFooBar()) @@ -9,12 +13,14 @@ func main() { type Foo int type FooBar int -//goose:provide Set +var Set = goose.NewSet( + provideFoo, + provideFooBar) + func provideFoo() Foo { return 41 } -//goose:provide Set func provideFooBar(foo Foo) FooBar { return FooBar(foo) + 1 } diff --git a/internal/goose/testdata/Chain/foo/foo_goose.go b/internal/goose/testdata/Chain/foo/foo_goose.go index 73f5093..47946a7 100644 --- a/internal/goose/testdata/Chain/foo/foo_goose.go +++ b/internal/goose/testdata/Chain/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectFooBar() FooBar +func injectFooBar() FooBar { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/Cleanup/foo/foo.go b/internal/goose/testdata/Cleanup/foo/foo.go index e086644..625182a 100644 --- a/internal/goose/testdata/Cleanup/foo/foo.go +++ b/internal/goose/testdata/Cleanup/foo/foo.go @@ -1,6 +1,8 @@ package main -import "fmt" +import ( + "fmt" +) func main() { bar, cleanup := injectBar() @@ -12,14 +14,12 @@ func main() { type Foo int type Bar int -//goose:provide Foo func provideFoo() (*Foo, func()) { foo := new(Foo) *foo = 42 return foo, func() { *foo = 0 } } -//goose:provide Bar func provideBar(foo *Foo) (*Bar, func()) { bar := new(Bar) *bar = 77 diff --git a/internal/goose/testdata/Cleanup/foo/goose.go b/internal/goose/testdata/Cleanup/foo/goose.go index b56a73e..fe97cab 100644 --- a/internal/goose/testdata/Cleanup/foo/goose.go +++ b/internal/goose/testdata/Cleanup/foo/goose.go @@ -2,7 +2,10 @@ package main -//goose:use Foo -//goose:use Bar +import ( + "codename/goose" +) -func injectBar() (*Bar, func()) +func injectBar() (*Bar, func()) { + panic(goose.Use(provideFoo, provideBar)) +} diff --git a/internal/goose/testdata/EmptyVar/foo/foo.go b/internal/goose/testdata/EmptyVar/foo/foo.go new file mode 100644 index 0000000..b251127 --- /dev/null +++ b/internal/goose/testdata/EmptyVar/foo/foo.go @@ -0,0 +1,11 @@ +package main + +import ( + "fmt" +) + +func main() { + fmt.Println(injectedMessage()) +} + +var myFakeSet struct{} diff --git a/internal/goose/testdata/EmptyVar/foo/goose.go b/internal/goose/testdata/EmptyVar/foo/goose.go new file mode 100644 index 0000000..3dddda0 --- /dev/null +++ b/internal/goose/testdata/EmptyVar/foo/goose.go @@ -0,0 +1,11 @@ +//+build gooseinject + +package main + +import ( + "codename/goose" +) + +func injectedMessage() string { + panic(goose.Use(myFakeSet)) +} diff --git a/internal/goose/testdata/MissingUse/out.txt b/internal/goose/testdata/EmptyVar/out.txt similarity index 100% rename from internal/goose/testdata/MissingUse/out.txt rename to internal/goose/testdata/EmptyVar/out.txt diff --git a/internal/goose/testdata/MissingUse/pkg b/internal/goose/testdata/EmptyVar/pkg similarity index 100% rename from internal/goose/testdata/MissingUse/pkg rename to internal/goose/testdata/EmptyVar/pkg diff --git a/internal/goose/testdata/ImportedInterfaceBinding/bar/bar.go b/internal/goose/testdata/ImportedInterfaceBinding/bar/bar.go index 9970572..1009a9a 100644 --- a/internal/goose/testdata/ImportedInterfaceBinding/bar/bar.go +++ b/internal/goose/testdata/ImportedInterfaceBinding/bar/bar.go @@ -3,7 +3,8 @@ package main import ( "fmt" - _ "foo" + "codename/goose" + "foo" ) func main() { @@ -16,11 +17,12 @@ func (b *Bar) Foo() string { return string(*b) } -//goose:provide func provideBar() *Bar { b := new(Bar) *b = "Hello, World!" return b } -//goose:bind provideBar "foo".Fooer *Bar +var Set = goose.NewSet( + provideBar, + goose.Bind(foo.Fooer(nil), (*Bar)(nil))) diff --git a/internal/goose/testdata/ImportedInterfaceBinding/bar/goose.go b/internal/goose/testdata/ImportedInterfaceBinding/bar/goose.go index 46812cc..8ffb415 100644 --- a/internal/goose/testdata/ImportedInterfaceBinding/bar/goose.go +++ b/internal/goose/testdata/ImportedInterfaceBinding/bar/goose.go @@ -2,8 +2,11 @@ package main -import "foo" +import ( + "codename/goose" + "foo" +) -//goose:use provideBar - -func injectFooer() foo.Fooer +func injectFooer() foo.Fooer { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/InjectInput/foo/foo.go b/internal/goose/testdata/InjectInput/foo/foo.go index bd2a6d4..16b8990 100644 --- a/internal/goose/testdata/InjectInput/foo/foo.go +++ b/internal/goose/testdata/InjectInput/foo/foo.go @@ -1,6 +1,10 @@ package main -import "fmt" +import ( + "fmt" + + "codename/goose" +) func main() { fmt.Println(injectFooBar(40)) @@ -10,12 +14,14 @@ type Foo int type Bar int type FooBar int -//goose:provide Set +var Set = goose.NewSet( + provideBar, + provideFooBar) + func provideBar() Bar { return 2 } -//goose:provide Set func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) } diff --git a/internal/goose/testdata/InjectInput/foo/foo_goose.go b/internal/goose/testdata/InjectInput/foo/foo_goose.go index cabd74b..996c4d7 100644 --- a/internal/goose/testdata/InjectInput/foo/foo_goose.go +++ b/internal/goose/testdata/InjectInput/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectFooBar(foo Foo) FooBar +func injectFooBar(foo Foo) FooBar { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/InjectInputConflict/foo/foo.go b/internal/goose/testdata/InjectInputConflict/foo/foo.go index 6afb76d..499ce16 100644 --- a/internal/goose/testdata/InjectInputConflict/foo/foo.go +++ b/internal/goose/testdata/InjectInputConflict/foo/foo.go @@ -1,6 +1,10 @@ package main -import "fmt" +import ( + "fmt" + + "codename/goose" +) func main() { // I'm on the fence as to whether this should be an error (versus an @@ -12,12 +16,14 @@ func main() { type Foo int type Bar int -//goose:provide Set +var Set = goose.NewSet( + provideFoo, + provideBar) + func provideFoo() Foo { return -888 } -//goose:provide Set func provideBar(foo Foo) Bar { return 2 } diff --git a/internal/goose/testdata/InjectInputConflict/foo/foo_goose.go b/internal/goose/testdata/InjectInputConflict/foo/foo_goose.go index 282ae3f..2e388f2 100644 --- a/internal/goose/testdata/InjectInputConflict/foo/foo_goose.go +++ b/internal/goose/testdata/InjectInputConflict/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectBar(foo Foo) Bar +func injectBar(foo Foo) Bar { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/InterfaceBinding/foo/foo.go b/internal/goose/testdata/InterfaceBinding/foo/foo.go index 50c523c..500cc5f 100644 --- a/internal/goose/testdata/InterfaceBinding/foo/foo.go +++ b/internal/goose/testdata/InterfaceBinding/foo/foo.go @@ -1,6 +1,10 @@ package main -import "fmt" +import ( + "fmt" + + "codename/goose" +) func main() { fmt.Println(injectFooer().Foo()) @@ -16,11 +20,12 @@ func (b *Bar) Foo() string { return string(*b) } -//goose:provide func provideBar() *Bar { b := new(Bar) *b = "Hello, World!" return b } -//goose:bind provideBar Fooer *Bar +var Set = goose.NewSet( + provideBar, + goose.Bind(Fooer(nil), (*Bar)(nil))) diff --git a/internal/goose/testdata/InterfaceBinding/foo/foo_goose.go b/internal/goose/testdata/InterfaceBinding/foo/foo_goose.go index 38876cb..493ca41 100644 --- a/internal/goose/testdata/InterfaceBinding/foo/foo_goose.go +++ b/internal/goose/testdata/InterfaceBinding/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use provideBar +import ( + "codename/goose" +) -func injectFooer() Fooer +func injectFooer() Fooer { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/InterfaceBindingReuse/foo/foo.go b/internal/goose/testdata/InterfaceBindingReuse/foo/foo.go index bb097eb..d77375c 100644 --- a/internal/goose/testdata/InterfaceBindingReuse/foo/foo.go +++ b/internal/goose/testdata/InterfaceBindingReuse/foo/foo.go @@ -28,8 +28,6 @@ func (b *Bar) Foo() string { return string(*b) } -//goose:provide -//goose:bind provideBar Fooer *Bar func provideBar() *Bar { mu.Lock() provideBarCalls++ @@ -44,7 +42,6 @@ var ( provideBarCalls int ) -//goose:provide func provideFooBar(fooer Fooer, bar *Bar) FooBar { return FooBar{fooer, bar} } diff --git a/internal/goose/testdata/InterfaceBindingReuse/foo/foo_goose.go b/internal/goose/testdata/InterfaceBindingReuse/foo/foo_goose.go index 48a1c01..c8179fc 100644 --- a/internal/goose/testdata/InterfaceBindingReuse/foo/foo_goose.go +++ b/internal/goose/testdata/InterfaceBindingReuse/foo/foo_goose.go @@ -2,7 +2,13 @@ package main -//goose:use provideBar -//goose:use provideFooBar +import ( + "codename/goose" +) -func injectFooBar() FooBar +func injectFooBar() FooBar { + panic(goose.Use( + provideBar, + provideFooBar, + goose.Bind(Fooer(nil), (*Bar)(nil)))) +} diff --git a/internal/goose/testdata/MissingUse/foo/foo.go b/internal/goose/testdata/MissingUse/foo/foo.go deleted file mode 100644 index 65ab36b..0000000 --- a/internal/goose/testdata/MissingUse/foo/foo.go +++ /dev/null @@ -1,14 +0,0 @@ -package main - -import "fmt" - -func main() { - fmt.Println(injectedMessage()) -} - -//goose:provide Set - -// provideMessage provides a friendly user greeting. -func provideMessage() string { - return "Hello, World!" -} diff --git a/internal/goose/testdata/MissingUse/foo/foo_goose.go b/internal/goose/testdata/MissingUse/foo/foo_goose.go deleted file mode 100644 index fea1a42..0000000 --- a/internal/goose/testdata/MissingUse/foo/foo_goose.go +++ /dev/null @@ -1,5 +0,0 @@ -//+build gooseinject - -package main - -func injectedMessage() string diff --git a/internal/goose/testdata/MultiImport/foo/foo.go b/internal/goose/testdata/MultiImport/foo/foo.go deleted file mode 100644 index 1721d3a..0000000 --- a/internal/goose/testdata/MultiImport/foo/foo.go +++ /dev/null @@ -1,22 +0,0 @@ -package main - -import "fmt" - -func main() { - fmt.Println(injectFooBar()) -} - -type Foo int -type FooBar int - -//goose:provide Foo -func provideFoo() Foo { - return 41 -} - -//goose:provide FooBar -func provideFooBar(foo Foo) FooBar { - return FooBar(foo) + 1 -} - -//goose:import Set Foo FooBar diff --git a/internal/goose/testdata/MultiImport/foo/goose.go b/internal/goose/testdata/MultiImport/foo/goose.go deleted file mode 100644 index 73f5093..0000000 --- a/internal/goose/testdata/MultiImport/foo/goose.go +++ /dev/null @@ -1,7 +0,0 @@ -//+build gooseinject - -package main - -//goose:use Set - -func injectFooBar() FooBar diff --git a/internal/goose/testdata/MultiImport/out.txt b/internal/goose/testdata/MultiImport/out.txt deleted file mode 100644 index d81cc07..0000000 --- a/internal/goose/testdata/MultiImport/out.txt +++ /dev/null @@ -1 +0,0 @@ -42 diff --git a/internal/goose/testdata/MultiImport/pkg b/internal/goose/testdata/MultiImport/pkg deleted file mode 100644 index 257cc56..0000000 --- a/internal/goose/testdata/MultiImport/pkg +++ /dev/null @@ -1 +0,0 @@ -foo diff --git a/internal/goose/testdata/MultiUse/foo/foo.go b/internal/goose/testdata/MultiUse/foo/foo.go deleted file mode 100644 index 232c06a..0000000 --- a/internal/goose/testdata/MultiUse/foo/foo.go +++ /dev/null @@ -1,20 +0,0 @@ -package main - -import "fmt" - -func main() { - fmt.Println(injectFooBar()) -} - -type Foo int -type FooBar int - -//goose:provide Foo -func provideFoo() Foo { - return 41 -} - -//goose:provide FooBar -func provideFooBar(foo Foo) FooBar { - return FooBar(foo) + 1 -} diff --git a/internal/goose/testdata/MultiUse/foo/goose.go b/internal/goose/testdata/MultiUse/foo/goose.go deleted file mode 100644 index d273f88..0000000 --- a/internal/goose/testdata/MultiUse/foo/goose.go +++ /dev/null @@ -1,7 +0,0 @@ -//+build gooseinject - -package main - -//goose:use Foo FooBar - -func injectFooBar() FooBar diff --git a/internal/goose/testdata/MultiUse/out.txt b/internal/goose/testdata/MultiUse/out.txt deleted file mode 100644 index d81cc07..0000000 --- a/internal/goose/testdata/MultiUse/out.txt +++ /dev/null @@ -1 +0,0 @@ -42 diff --git a/internal/goose/testdata/MultiUse/pkg b/internal/goose/testdata/MultiUse/pkg deleted file mode 100644 index 257cc56..0000000 --- a/internal/goose/testdata/MultiUse/pkg +++ /dev/null @@ -1 +0,0 @@ -foo diff --git a/internal/goose/testdata/NamingWorstCase/foo/foo.go b/internal/goose/testdata/NamingWorstCase/foo/foo.go index 4558109..be9e535 100644 --- a/internal/goose/testdata/NamingWorstCase/foo/foo.go +++ b/internal/goose/testdata/NamingWorstCase/foo/foo.go @@ -17,8 +17,6 @@ func main() { fmt.Println(c) } -//goose:provide - func provide(ctx stdcontext.Context) (context, error) { return context{}, nil } diff --git a/internal/goose/testdata/NamingWorstCase/foo/goose.go b/internal/goose/testdata/NamingWorstCase/foo/goose.go index 6c276e3..d1a8dee 100644 --- a/internal/goose/testdata/NamingWorstCase/foo/goose.go +++ b/internal/goose/testdata/NamingWorstCase/foo/goose.go @@ -4,8 +4,10 @@ package main import ( stdcontext "context" + + "codename/goose" ) -//goose:use provide - -func inject(context stdcontext.Context, err struct{}) (context, error) +func inject(context stdcontext.Context, err struct{}) (context, error) { + panic(goose.Use(provide)) +} diff --git a/internal/goose/testdata/NiladicIdentity/foo/foo.go b/internal/goose/testdata/NiladicIdentity/foo/foo.go index ac80c0a..e6ac32d 100644 --- a/internal/goose/testdata/NiladicIdentity/foo/foo.go +++ b/internal/goose/testdata/NiladicIdentity/foo/foo.go @@ -6,8 +6,6 @@ func main() { fmt.Println(injectedMessage()) } -//goose:provide - // provideMessage provides a friendly user greeting. func provideMessage() string { return "Hello, World!" diff --git a/internal/goose/testdata/NiladicIdentity/foo/foo_goose.go b/internal/goose/testdata/NiladicIdentity/foo/foo_goose.go index d63fbe3..b45ef45 100644 --- a/internal/goose/testdata/NiladicIdentity/foo/foo_goose.go +++ b/internal/goose/testdata/NiladicIdentity/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use provideMessage +import ( + "codename/goose" +) -func injectedMessage() string +func injectedMessage() string { + panic(goose.Use(provideMessage)) +} diff --git a/internal/goose/testdata/NoImplicitInterface/foo/foo.go b/internal/goose/testdata/NoImplicitInterface/foo/foo.go index 01585ff..a50cda4 100644 --- a/internal/goose/testdata/NoImplicitInterface/foo/foo.go +++ b/internal/goose/testdata/NoImplicitInterface/foo/foo.go @@ -16,7 +16,6 @@ func (b Bar) Foo() string { return string(b) } -//goose:provide func provideBar() Bar { return "Hello, World!" } diff --git a/internal/goose/testdata/NoImplicitInterface/foo/foo_goose.go b/internal/goose/testdata/NoImplicitInterface/foo/foo_goose.go index 38876cb..5b898e0 100644 --- a/internal/goose/testdata/NoImplicitInterface/foo/foo_goose.go +++ b/internal/goose/testdata/NoImplicitInterface/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use provideBar +import ( + "codename/goose" +) -func injectFooer() Fooer +func injectFooer() Fooer { + panic(goose.Use(provideBar)) +} diff --git a/internal/goose/testdata/NoInjectParamNames/foo/foo.go b/internal/goose/testdata/NoInjectParamNames/foo/foo.go index 4558109..be9e535 100644 --- a/internal/goose/testdata/NoInjectParamNames/foo/foo.go +++ b/internal/goose/testdata/NoInjectParamNames/foo/foo.go @@ -17,8 +17,6 @@ func main() { fmt.Println(c) } -//goose:provide - func provide(ctx stdcontext.Context) (context, error) { return context{}, nil } diff --git a/internal/goose/testdata/NoInjectParamNames/foo/goose.go b/internal/goose/testdata/NoInjectParamNames/foo/goose.go index e94e92f..81fec99 100644 --- a/internal/goose/testdata/NoInjectParamNames/foo/goose.go +++ b/internal/goose/testdata/NoInjectParamNames/foo/goose.go @@ -4,11 +4,13 @@ package main import ( stdcontext "context" + + "codename/goose" ) // The notable characteristic of this test is that there are no // parameter names on the inject stub. -//goose:use provide - -func inject(stdcontext.Context, struct{}) (context, error) +func inject(stdcontext.Context, struct{}) (context, error) { + panic(goose.Use(provide)) +} diff --git a/internal/goose/testdata/PartialCleanup/foo/foo.go b/internal/goose/testdata/PartialCleanup/foo/foo.go index 4aec1b2..c73ff95 100644 --- a/internal/goose/testdata/PartialCleanup/foo/foo.go +++ b/internal/goose/testdata/PartialCleanup/foo/foo.go @@ -25,14 +25,12 @@ type Foo int type Bar int type Baz int -//goose:provide Foo func provideFoo() (*Foo, func()) { foo := new(Foo) *foo = 42 return foo, func() { *foo = 0; cleanedFoo = true } } -//goose:provide Bar func provideBar(foo *Foo) (*Bar, func(), error) { bar := new(Bar) *bar = 77 @@ -45,7 +43,6 @@ func provideBar(foo *Foo) (*Bar, func(), error) { }, nil } -//goose:provide Baz func provideBaz(bar *Bar) (Baz, error) { return 0, errors.New("bork!") } diff --git a/internal/goose/testdata/PartialCleanup/foo/goose.go b/internal/goose/testdata/PartialCleanup/foo/goose.go index c9fe147..77029b4 100644 --- a/internal/goose/testdata/PartialCleanup/foo/goose.go +++ b/internal/goose/testdata/PartialCleanup/foo/goose.go @@ -2,8 +2,10 @@ package main -//goose:use Foo -//goose:use Bar -//goose:use Baz +import ( + "codename/goose" +) -func injectBaz() (Baz, func(), error) +func injectBaz() (Baz, func(), error) { + panic(goose.Use(provideFoo, provideBar, provideBaz)) +} diff --git a/internal/goose/testdata/PkgImport/bar/bar.go b/internal/goose/testdata/PkgImport/bar/bar.go index c5a1de4..c604501 100644 --- a/internal/goose/testdata/PkgImport/bar/bar.go +++ b/internal/goose/testdata/PkgImport/bar/bar.go @@ -2,7 +2,6 @@ package bar type Bar int -//goose:provide Bar func ProvideBar() Bar { return 1 } diff --git a/internal/goose/testdata/PkgImport/foo/foo.go b/internal/goose/testdata/PkgImport/foo/foo.go index 35e73a1..c8a3c94 100644 --- a/internal/goose/testdata/PkgImport/foo/foo.go +++ b/internal/goose/testdata/PkgImport/foo/foo.go @@ -4,6 +4,7 @@ import ( "fmt" "bar" + "codename/goose" ) func main() { @@ -13,14 +14,15 @@ func main() { type Foo int type FooBar int -//goose:provide Set +var Set = goose.NewSet( + provideFoo, + bar.ProvideBar, + provideFooBar) + func provideFoo() Foo { return 41 } -//goose:import Set "bar".Bar - -//goose:provide Set func provideFooBar(foo Foo, barVal bar.Bar) FooBar { return FooBar(foo) + FooBar(barVal) } diff --git a/internal/goose/testdata/PkgImport/foo/foo_goose.go b/internal/goose/testdata/PkgImport/foo/foo_goose.go index 73f5093..47946a7 100644 --- a/internal/goose/testdata/PkgImport/foo/foo_goose.go +++ b/internal/goose/testdata/PkgImport/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectFooBar() FooBar +func injectFooBar() FooBar { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/ReturnError/foo/foo.go b/internal/goose/testdata/ReturnError/foo/foo.go index cf415b1..b58e500 100644 --- a/internal/goose/testdata/ReturnError/foo/foo.go +++ b/internal/goose/testdata/ReturnError/foo/foo.go @@ -1,8 +1,12 @@ package main -import "errors" -import "fmt" -import "strings" +import ( + "errors" + "fmt" + "strings" + + "codename/goose" +) func main() { foo, err := injectFoo() @@ -16,7 +20,8 @@ func main() { type Foo int -//goose:provide Set func provideFoo() (Foo, error) { return 42, errors.New("there is no Foo") } + +var Set = goose.NewSet(provideFoo) diff --git a/internal/goose/testdata/ReturnError/foo/foo_goose.go b/internal/goose/testdata/ReturnError/foo/foo_goose.go index 5aa8775..169a8b0 100644 --- a/internal/goose/testdata/ReturnError/foo/foo_goose.go +++ b/internal/goose/testdata/ReturnError/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectFoo() (Foo, error) +func injectFoo() (Foo, error) { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/Struct/foo/foo.go b/internal/goose/testdata/Struct/foo/foo.go index d53f1c4..9f05bb8 100644 --- a/internal/goose/testdata/Struct/foo/foo.go +++ b/internal/goose/testdata/Struct/foo/foo.go @@ -1,6 +1,10 @@ package main -import "fmt" +import ( + "fmt" + + "codename/goose" +) func main() { fb := injectFooBar() @@ -10,18 +14,20 @@ func main() { type Foo int type Bar int -//goose:provide Set type FooBar struct { Foo Foo Bar Bar } -//goose:provide Set func provideFoo() Foo { return 41 } -//goose:provide Set func provideBar() Bar { return 1 } + +var Set = goose.NewSet( + FooBar{}, + provideFoo, + provideBar) diff --git a/internal/goose/testdata/Struct/foo/goose.go b/internal/goose/testdata/Struct/foo/goose.go index 73f5093..47946a7 100644 --- a/internal/goose/testdata/Struct/foo/goose.go +++ b/internal/goose/testdata/Struct/foo/goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectFooBar() FooBar +func injectFooBar() FooBar { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/StructPointer/foo/foo.go b/internal/goose/testdata/StructPointer/foo/foo.go index d53f1c4..9f05bb8 100644 --- a/internal/goose/testdata/StructPointer/foo/foo.go +++ b/internal/goose/testdata/StructPointer/foo/foo.go @@ -1,6 +1,10 @@ package main -import "fmt" +import ( + "fmt" + + "codename/goose" +) func main() { fb := injectFooBar() @@ -10,18 +14,20 @@ func main() { type Foo int type Bar int -//goose:provide Set type FooBar struct { Foo Foo Bar Bar } -//goose:provide Set func provideFoo() Foo { return 41 } -//goose:provide Set func provideBar() Bar { return 1 } + +var Set = goose.NewSet( + FooBar{}, + provideFoo, + provideBar) diff --git a/internal/goose/testdata/StructPointer/foo/goose.go b/internal/goose/testdata/StructPointer/foo/goose.go index d99308a..af10467 100644 --- a/internal/goose/testdata/StructPointer/foo/goose.go +++ b/internal/goose/testdata/StructPointer/foo/goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectFooBar() *FooBar +func injectFooBar() *FooBar { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/TwoDeps/foo/foo.go b/internal/goose/testdata/TwoDeps/foo/foo.go index 3284c2c..f18ce30 100644 --- a/internal/goose/testdata/TwoDeps/foo/foo.go +++ b/internal/goose/testdata/TwoDeps/foo/foo.go @@ -1,6 +1,10 @@ package main -import "fmt" +import ( + "fmt" + + "codename/goose" +) func main() { fmt.Println(injectFooBar()) @@ -10,17 +14,19 @@ type Foo int type Bar int type FooBar int -//goose:provide Set func provideFoo() Foo { return 40 } -//goose:provide Set func provideBar() Bar { return 2 } -//goose:provide Set func provideFooBar(foo Foo, bar Bar) FooBar { return FooBar(foo) + FooBar(bar) } + +var Set = goose.NewSet( + provideFoo, + provideBar, + provideFooBar) diff --git a/internal/goose/testdata/TwoDeps/foo/foo_goose.go b/internal/goose/testdata/TwoDeps/foo/foo_goose.go index 73f5093..47946a7 100644 --- a/internal/goose/testdata/TwoDeps/foo/foo_goose.go +++ b/internal/goose/testdata/TwoDeps/foo/foo_goose.go @@ -2,6 +2,10 @@ package main -//goose:use Set +import ( + "codename/goose" +) -func injectFooBar() FooBar +func injectFooBar() FooBar { + panic(goose.Use(Set)) +} diff --git a/internal/goose/testdata/Vendor/foo/goose.go b/internal/goose/testdata/Vendor/foo/goose.go index d522980..7619775 100644 --- a/internal/goose/testdata/Vendor/foo/goose.go +++ b/internal/goose/testdata/Vendor/foo/goose.go @@ -3,9 +3,10 @@ package main import ( - _ "bar" + "bar" + "codename/goose" ) -//goose:use "bar".Message - -func injectedMessage() string +func injectedMessage() string { + panic(goose.Use(bar.ProvideMessage)) +} diff --git a/internal/goose/testdata/Vendor/foo/vendor/bar/bar.go b/internal/goose/testdata/Vendor/foo/vendor/bar/bar.go index 84a1069..e7359c5 100644 --- a/internal/goose/testdata/Vendor/foo/vendor/bar/bar.go +++ b/internal/goose/testdata/Vendor/foo/vendor/bar/bar.go @@ -1,8 +1,6 @@ // Package bar is the vendored copy of bar which contains the real provider. package bar -//goose:provide Message - // ProvideMessage provides a friendly user greeting. func ProvideMessage() string { return "Hello, World!"