goose: add show command

Lists provider sets in packages given on the command line, including
outputs grouped by what is needed to obtain them.

The goose package now exports the loading phase as an API.

Example output: https://paste.googleplex.com/5509965720584192

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-04 14:42:56 -07:00
parent 2044e2213b
commit cfc6111ea5
4 changed files with 581 additions and 189 deletions

View File

@@ -60,8 +60,8 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
index := new(typeutil.Map) index := new(typeutil.Map)
for i, g := range given { for i, g := range given {
if p := providers.At(g); p != nil { if p := providers.At(g); p != nil {
pp := p.(*providerInfo) pp := p.(*Provider)
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.name, mc.fset.Position(pp.pos)) return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, mc.fset.Position(pp.Pos))
} }
index.Set(g, i) index.Set(g, i)
} }
@@ -70,49 +70,49 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
// using a depth-first search. The graph may contain cycles, which // using a depth-first search. The graph may contain cycles, which
// should trigger an error. // should trigger an error.
var calls []call var calls []call
var visit func(trail []providerInput) error var visit func(trail []ProviderInput) error
visit = func(trail []providerInput) error { visit = func(trail []ProviderInput) error {
typ := trail[len(trail)-1].typ typ := trail[len(trail)-1].Type
if index.At(typ) != nil { if index.At(typ) != nil {
return nil return nil
} }
for _, in := range trail[:len(trail)-1] { for _, in := range trail[:len(trail)-1] {
if types.Identical(typ, in.typ) { if types.Identical(typ, in.Type) {
// TODO(light): describe cycle // TODO(light): describe cycle
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil)) return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
} }
} }
p, _ := providers.At(typ).(*providerInfo) p, _ := providers.At(typ).(*Provider)
if p == nil { if p == nil {
if trail[len(trail)-1].optional { if trail[len(trail)-1].Optional {
return nil return nil
} }
if len(trail) == 1 { if len(trail) == 1 {
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
} }
// TODO(light): give name of provider // TODO(light): give name of provider
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil)) return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil))
} }
if !types.Identical(p.out, typ) { if !types.Identical(p.Out, typ) {
// Interface binding. Don't create a call ourselves. // Interface binding. Don't create a call ourselves.
if err := visit(append(trail, providerInput{typ: p.out})); err != nil { if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
return err return err
} }
index.Set(typ, index.At(p.out)) index.Set(typ, index.At(p.Out))
return nil return nil
} }
for _, a := range p.args { for _, a := range p.Args {
// TODO(light): this will discard grown trail arrays. // TODO(light): this will discard grown trail arrays.
if err := visit(append(trail, a)); err != nil { if err := visit(append(trail, a)); err != nil {
return err return err
} }
} }
args := make([]int, len(p.args)) args := make([]int, len(p.Args))
ins := make([]types.Type, len(p.args)) ins := make([]types.Type, len(p.Args))
for i := range p.args { for i := range p.Args {
ins[i] = p.args[i].typ ins[i] = p.Args[i].Type
if x := index.At(p.args[i].typ); x != nil { if x := index.At(p.Args[i].Type); x != nil {
args[i] = x.(int) args[i] = x.(int)
} else { } else {
args[i] = -1 args[i] = -1
@@ -120,19 +120,19 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
} }
index.Set(typ, len(given)+len(calls)) index.Set(typ, len(given)+len(calls))
calls = append(calls, call{ calls = append(calls, call{
importPath: p.importPath, importPath: p.ImportPath,
name: p.name, name: p.Name,
args: args, args: args,
isStruct: p.isStruct, isStruct: p.IsStruct,
fieldNames: p.fields, fieldNames: p.Fields,
ins: ins, ins: ins,
out: typ, out: typ,
hasCleanup: p.hasCleanup, hasCleanup: p.HasCleanup,
hasErr: p.hasErr, hasErr: p.HasErr,
}) })
return nil return nil
} }
if err := visit([]providerInput{{typ: out}}); err != nil { if err := visit([]ProviderInput{{Type: out}}); err != nil {
return nil, err return nil, err
} }
return calls, nil return calls, nil
@@ -146,7 +146,7 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error
pos token.Pos pos token.Pos
} }
type binding struct { type binding struct {
ifaceBinding IfaceBinding
pset symref pset symref
from symref from symref
} }
@@ -173,53 +173,53 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error
} }
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err) return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
} }
for _, p := range pset.providers { for _, p := range pset.Providers {
if prev := pm.At(p.out); prev != nil { if prev := pm.At(p.Out); prev != nil {
pos := mc.fset.Position(p.pos) pos := mc.fset.Position(p.Pos)
typ := types.TypeString(p.out, nil) typ := types.TypeString(p.Out, nil)
prevPos := mc.fset.Position(prev.(*providerInfo).pos) prevPos := mc.fset.Position(prev.(*Provider).Pos)
if curr.from.importPath == "" { if curr.from.importPath == "" {
// Provider set is imported directly by injector. // Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
} }
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos) return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
} }
pm.Set(p.out, p) pm.Set(p.Out, p)
} }
for _, b := range pset.bindings { for _, b := range pset.Bindings {
bindings = append(bindings, binding{ bindings = append(bindings, binding{
ifaceBinding: b, IfaceBinding: b,
pset: curr.to, pset: curr.to,
from: curr.from, from: curr.from,
}) })
} }
for _, imp := range pset.imports { for _, imp := range pset.Imports {
next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos}) next = append(next, nextEnt{to: imp.symref(), from: curr.to, pos: imp.Pos})
} }
} }
for _, b := range bindings { for _, b := range bindings {
if prev := pm.At(b.iface); prev != nil { if prev := pm.At(b.Iface); prev != nil {
pos := mc.fset.Position(b.pos) pos := mc.fset.Position(b.Pos)
typ := types.TypeString(b.iface, nil) typ := types.TypeString(b.Iface, nil)
// TODO(light): error message for conflicting with another interface binding will point at provider instead of binding. // TODO(light): error message for conflicting with another interface binding will point at provider instead of binding.
prevPos := mc.fset.Position(prev.(*providerInfo).pos) prevPos := mc.fset.Position(prev.(*Provider).Pos)
if b.from.importPath == "" { if b.from.importPath == "" {
// Provider set is imported directly by injector. // Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
} }
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos) return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos)
} }
concrete := pm.At(b.provided) concrete := pm.At(b.Provided)
if concrete == nil { if concrete == nil {
pos := mc.fset.Position(b.pos) pos := mc.fset.Position(b.Pos)
typ := types.TypeString(b.provided, nil) typ := types.TypeString(b.Provided, nil)
if b.from.importPath == "" { if b.from.importPath == "" {
// Concrete provider is imported directly by injector. // Concrete provider is imported directly by injector.
return nil, fmt.Errorf("%v: no binding for %s", pos, typ) return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
} }
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from) return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
} }
pm.Set(b.iface, concrete) pm.Set(b.Iface, concrete)
} }
return pm, nil return pm, nil
} }

View File

@@ -22,17 +22,7 @@ import (
// Generate performs dependency injection for a single package, // Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code. // returning the gofmt'd Go source code.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
// TODO(light): allow errors conf := newLoaderConfig(bctx, wd, true)
// TODO(light): stop errors from printing to stderr
conf := &loader.Config{
Build: new(build.Context),
ParserMode: parser.ParseComments,
Cwd: wd,
TypeCheckFuncBodies: func(string) bool { return false },
}
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
conf.Import(pkg) conf.Import(pkg)
prog, err := conf.Load() prog, err := conf.Load()
if err != nil { if err != nil {
@@ -99,6 +89,24 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
return fmtSrc, nil return fmtSrc, nil
} }
func newLoaderConfig(bctx *build.Context, wd string, inject bool) *loader.Config {
// TODO(light): allow errors
// TODO(light): stop errors from printing to stderr
conf := &loader.Config{
Build: bctx,
ParserMode: parser.ParseComments,
Cwd: wd,
TypeCheckFuncBodies: func(string) bool { return false },
}
if inject {
conf.Build = new(build.Context)
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
}
return conf
}
// gen is the generator state. // gen is the generator state.
type gen struct { type gen struct {
currPackage string currPackage string

View File

@@ -14,72 +14,158 @@ import (
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
) )
// A providerSet describes a set of providers. The zero value is an empty // A ProviderSet describes a set of providers. The zero value is an empty
// providerSet. // ProviderSet.
type providerSet struct { type ProviderSet struct {
providers []*providerInfo Providers []*Provider
bindings []ifaceBinding Bindings []IfaceBinding
imports []providerSetImport Imports []ProviderSetImport
} }
// An ifaceBinding declares that a type should be used to satisfy inputs // An IfaceBinding declares that a type should be used to satisfy inputs
// of the given interface type. // of the given interface type.
// type IfaceBinding struct {
// provided is always a type that is assignable to iface. // Iface is the interface type, which is what can be injected.
type ifaceBinding struct { Iface types.Type
// iface is the interface type, which is what can be injected.
iface types.Type
// provided is always a type that is assignable to Iface. // Provided is always a type that is assignable to Iface.
provided types.Type Provided types.Type
// pos is the position where the binding was declared. // Pos is the position where the binding was declared.
pos token.Pos Pos token.Pos
} }
type providerSetImport struct { // A ProviderSetImport adds providers from one provider set into another.
symref type ProviderSetImport struct {
pos token.Pos ProviderSetID
Pos token.Pos
} }
// providerInfo records the signature of a provider. // Provider records the signature of a provider. A provider is a
type providerInfo struct { // single Go object, either a function or a named struct type.
// importPath is the package path that the Go object resides in. type Provider struct {
importPath string // ImportPath is the package path that the Go object resides in.
ImportPath string
// name is the name of the Go object. // Name is the name of the Go object.
name string Name string
// pos is the source position of the func keyword or type spec // Pos is the source position of the func keyword or type spec
// defining this provider. // defining this provider.
pos token.Pos Pos token.Pos
// args is the list of data dependencies this provider has. // Args is the list of data dependencies this provider has.
args []providerInput Args []ProviderInput
// isStruct is true if this provider is a named struct type. // IsStruct is true if this provider is a named struct type.
// Otherwise it's a function. // Otherwise it's a function.
isStruct bool IsStruct bool
// fields lists the field names to populate. This will map 1:1 with // Fields lists the field names to populate. This will map 1:1 with
// elements in Args. // elements in Args.
fields []string Fields []string
// out is the type this provider produces. // Out is the type this provider produces.
out types.Type Out types.Type
// hasCleanup reports whether the provider function returns a cleanup // HasCleanup reports whether the provider function returns a cleanup
// function. (Always false for structs.) // function. (Always false for structs.)
hasCleanup bool HasCleanup bool
// hasErr reports whether the provider function can return an error. // HasErr reports whether the provider function can return an error.
// (Always false for structs.) // (Always false for structs.)
hasErr bool HasErr bool
} }
type providerInput struct { // ProviderInput describes an incoming edge in the provider graph.
typ types.Type type ProviderInput struct {
optional bool Type types.Type
Optional bool
}
// Load finds all the provider sets in the given packages, as well as
// the provider sets' transitive dependencies.
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
conf := newLoaderConfig(bctx, wd, false)
for _, p := range pkgs {
conf.Import(p)
}
prog, err := conf.Load()
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
r := newImportResolver(conf, prog.Fset)
var next []string
initial := make(map[string]struct{})
for _, pkgInfo := range prog.InitialPackages() {
path := pkgInfo.Pkg.Path()
next = append(next, path)
initial[path] = struct{}{}
}
visited := make(map[string]struct{})
info := &Info{
Fset: prog.Fset,
Sets: make(map[ProviderSetID]*ProviderSet),
All: make(map[ProviderSetID]*ProviderSet),
}
for len(next) > 0 {
curr := next[len(next)-1]
next = next[:len(next)-1]
if _, ok := visited[curr]; ok {
continue
}
visited[curr] = struct{}{}
pkgInfo := prog.Package(curr)
sets, err := findProviderSets(findContext{
fset: prog.Fset,
pkg: pkgInfo.Pkg,
typeInfo: &pkgInfo.Info,
r: r,
}, pkgInfo.Files)
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
path := pkgInfo.Pkg.Path()
for name, set := range sets {
info.All[ProviderSetID{path, name}] = set
for _, imp := range set.Imports {
next = append(next, imp.ImportPath)
}
}
if _, ok := initial[path]; ok {
for name, set := range sets {
info.Sets[ProviderSetID{path, name}] = set
}
}
}
return info, nil
}
// Info holds the result of Load.
type Info struct {
Fset *token.FileSet
// Sets contains all the provider sets in the initial packages.
Sets map[ProviderSetID]*ProviderSet
// All contains all the provider sets transitively depended on by the
// initial packages' provider sets.
All map[ProviderSetID]*ProviderSet
}
// A ProviderSetID identifies a provider set.
type ProviderSetID struct {
ImportPath string
Name string
}
// String returns the ID as ""path/to/pkg".Foo".
func (id ProviderSetID) String() string {
return id.symref().String()
}
func (id ProviderSetID) symref() symref {
return symref{importPath: id.ImportPath, name: id.Name}
} }
type findContext struct { type findContext struct {
@@ -90,8 +176,8 @@ type findContext struct {
} }
// findProviderSets processes a package and extracts the provider sets declared in it. // findProviderSets processes a package and extracts the provider sets declared in it.
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) { func findProviderSets(fctx findContext, files []*ast.File) (map[string]*ProviderSet, error) {
sets := make(map[string]*providerSet) sets := make(map[string]*ProviderSet)
for _, f := range files { for _, f := range files {
fileScope := fctx.typeInfo.Scopes[f] fileScope := fctx.typeInfo.Scopes[f]
if fileScope == nil { if fileScope == nil {
@@ -115,7 +201,7 @@ func findProviderSets(fctx findContext, files []*ast.File) (map[string]*provider
} }
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration. // processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error { func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error {
switch d.kind { switch d.kind {
case "provide", "optional": case "provide", "optional":
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
@@ -169,15 +255,15 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
name := args[0] name := args[0]
if pset := sets[name]; pset != nil { if pset := sets[name]; pset != nil {
pset.bindings = append(pset.bindings, ifaceBinding{ pset.Bindings = append(pset.Bindings, IfaceBinding{
iface: iface, Iface: iface,
provided: provided, Provided: provided,
}) })
} else { } else {
sets[name] = &providerSet{ sets[name] = &ProviderSet{
bindings: []ifaceBinding{{ Bindings: []IfaceBinding{{
iface: iface, Iface: iface,
provided: provided, Provided: provided,
}}, }},
} }
} }
@@ -197,18 +283,30 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
} }
if mod := sets[name]; mod != nil { if mod := sets[name]; mod != nil {
found := false found := false
for _, other := range mod.imports { for _, other := range mod.Imports {
if ref == other.symref { if ref == other.symref() {
found = true found = true
break break
} }
} }
if !found { if !found {
mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos}) mod.Imports = append(mod.Imports, ProviderSetImport{
ProviderSetID: ProviderSetID{
ImportPath: ref.importPath,
Name: ref.name,
},
Pos: d.pos,
})
} }
} else { } else {
sets[name] = &providerSet{ sets[name] = &ProviderSet{
imports: []providerSetImport{{symref: ref, pos: d.pos}}, Imports: []ProviderSetImport{{
ProviderSetID: ProviderSetID{
ImportPath: ref.importPath,
Name: ref.name,
},
Pos: d.pos,
}},
} }
} }
} }
@@ -219,7 +317,7 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
} }
// processDeclDirectives processes the directives associated with a top-level declaration. // processDeclDirectives processes the directives associated with a top-level declaration.
func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error { func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, dg directiveGroup) error {
p, err := dg.single(fctx.fset, "provide") p, err := dg.single(fctx.fset, "provide")
if err != nil { if err != nil {
return err return err
@@ -258,15 +356,15 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
providerSetName = fn.Name() providerSetName = fn.Name()
} }
if mod := sets[providerSetName]; mod != nil { if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers { for _, other := range mod.Providers {
if types.Identical(other.out, provider.out) { if types.Identical(other.Out, provider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
} }
} }
mod.providers = append(mod.providers, provider) mod.Providers = append(mod.Providers, provider)
} else { } else {
sets[providerSetName] = &providerSet{ sets[providerSetName] = &ProviderSet{
providers: []*providerInfo{provider}, Providers: []*Provider{provider},
} }
} }
case *ast.GenDecl: case *ast.GenDecl:
@@ -288,22 +386,22 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
if providerSetName == "" { if providerSetName == "" {
providerSetName = typeName.Name() providerSetName = typeName.Name()
} }
ptrProvider := new(providerInfo) ptrProvider := new(Provider)
*ptrProvider = *provider *ptrProvider = *provider
ptrProvider.out = types.NewPointer(provider.out) ptrProvider.Out = types.NewPointer(provider.Out)
if mod := sets[providerSetName]; mod != nil { if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers { for _, other := range mod.Providers {
if types.Identical(other.out, provider.out) { if types.Identical(other.Out, provider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
} }
if types.Identical(other.out, ptrProvider.out) { if types.Identical(other.Out, ptrProvider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.out, nil), fctx.fset.Position(other.pos)) return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.Out, nil), fctx.fset.Position(other.Pos))
} }
} }
mod.providers = append(mod.providers, provider, ptrProvider) mod.Providers = append(mod.Providers, provider, ptrProvider)
} else { } else {
sets[providerSetName] = &providerSet{ sets[providerSetName] = &ProviderSet{
providers: []*providerInfo{provider, ptrProvider}, Providers: []*Provider{provider, ptrProvider},
} }
} }
default: default:
@@ -312,7 +410,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
return nil return nil
} }
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*providerInfo, error) { func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*Provider, error) {
sig := fn.Type().(*types.Signature) sig := fn.Type().(*types.Signature)
optionals := make([]bool, sig.Params().Len()) optionals := make([]bool, sig.Params().Len())
@@ -352,30 +450,30 @@ func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[stri
} }
out := r.At(0).Type() out := r.At(0).Type()
params := sig.Params() params := sig.Params()
provider := &providerInfo{ provider := &Provider{
importPath: fctx.pkg.Path(), ImportPath: fctx.pkg.Path(),
name: fn.Name(), Name: fn.Name(),
pos: fn.Pos(), Pos: fn.Pos(),
args: make([]providerInput, params.Len()), Args: make([]ProviderInput, params.Len()),
out: out, Out: out,
hasCleanup: hasCleanup, HasCleanup: hasCleanup,
hasErr: hasErr, HasErr: hasErr,
} }
for i := 0; i < params.Len(); i++ { for i := 0; i < params.Len(); i++ {
provider.args[i] = providerInput{ provider.Args[i] = ProviderInput{
typ: params.At(i).Type(), Type: params.At(i).Type(),
optional: optionals[i], Optional: optionals[i],
} }
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if types.Identical(provider.args[i].typ, provider.args[j].typ) { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil)) return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil))
} }
} }
} }
return provider, nil return provider, nil
} }
func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*providerInfo, error) { func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*Provider, error) {
out := typeName.Type() out := typeName.Type()
st := out.Underlying().(*types.Struct) st := out.Underlying().(*types.Struct)
for arg, dpos := range optionals { for arg, dpos := range optionals {
@@ -392,26 +490,26 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals
} }
pos := typeName.Pos() pos := typeName.Pos()
provider := &providerInfo{ provider := &Provider{
importPath: fctx.pkg.Path(), ImportPath: fctx.pkg.Path(),
name: typeName.Name(), Name: typeName.Name(),
pos: pos, Pos: pos,
args: make([]providerInput, st.NumFields()), Args: make([]ProviderInput, st.NumFields()),
fields: make([]string, st.NumFields()), Fields: make([]string, st.NumFields()),
isStruct: true, IsStruct: true,
out: out, Out: out,
} }
for i := 0; i < st.NumFields(); i++ { for i := 0; i < st.NumFields(); i++ {
f := st.Field(i) f := st.Field(i)
_, optional := optionals[f.Name()] _, optional := optionals[f.Name()]
provider.args[i] = providerInput{ provider.Args[i] = ProviderInput{
typ: f.Type(), Type: f.Type(),
optional: optional, Optional: optional,
} }
provider.fields[i] = f.Name() provider.Fields[i] = f.Name()
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if types.Identical(provider.args[i].typ, provider.args[j].typ) { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.args[j].typ, nil)) return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.Args[j].Type, nil))
} }
} }
} }
@@ -420,7 +518,7 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals
// providerSetCache is a lazily evaluated index of provider sets. // providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct { type providerSetCache struct {
sets map[string]map[string]*providerSet sets map[string]map[string]*ProviderSet
fset *token.FileSet fset *token.FileSet
prog *loader.Program prog *loader.Program
r *importResolver r *importResolver
@@ -434,7 +532,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa
} }
} }
func (mc *providerSetCache) get(ref symref) (*providerSet, error) { func (mc *providerSetCache) get(ref symref) (*ProviderSet, error) {
if mods, cached := mc.sets[ref.importPath]; cached { if mods, cached := mc.sets[ref.importPath]; cached {
mod := mods[ref.name] mod := mods[ref.name]
if mod == nil { if mod == nil {
@@ -443,7 +541,7 @@ func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
return mod, nil return mod, nil
} }
if mc.sets == nil { if mc.sets == nil {
mc.sets = make(map[string]map[string]*providerSet) mc.sets = make(map[string]map[string]*ProviderSet)
} }
pkg := mc.prog.Package(ref.importPath) pkg := mc.prog.Package(ref.importPath)
mods, err := findProviderSets(findContext{ mods, err := findProviderSets(findContext{

342
main.go
View File

@@ -6,47 +6,333 @@ package main
import ( import (
"fmt" "fmt"
"go/build" "go/build"
"go/token"
"go/types"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"sort"
"strings"
"codename/goose/internal/goose" "codename/goose/internal/goose"
"golang.org/x/tools/go/types/typeutil"
) )
func main() { func main() {
var pkg string var err error
switch len(os.Args) { switch {
case 1: case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen":
pkg = "." err = generate(".")
case 2: case len(os.Args) == 2 && os.Args[1] == "show":
pkg = os.Args[1] err = show(".")
case len(os.Args) == 2:
err = generate(os.Args[1])
case len(os.Args) > 2 && os.Args[1] == "show":
err = show(os.Args[2:]...)
case len(os.Args) == 3 && os.Args[1] == "gen":
err = generate(os.Args[2])
default: default:
fmt.Fprintln(os.Stderr, "goose: usage: goose [PKG]") fmt.Fprintln(os.Stderr, "goose: usage: goose [gen] [PKG] | goose show [...]")
os.Exit(64) os.Exit(64)
} }
wd, err := os.Getwd()
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err) fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1) os.Exit(1)
} }
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly) }
if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err) // generate runs the gen subcommand. Given a package, gen will create
os.Exit(1) // the goose_gen.go file.
} func generate(pkg string) error {
out, err := goose.Generate(&build.Default, wd, pkg) wd, err := os.Getwd()
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err) return err
os.Exit(1) }
} pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
if len(out) == 0 { if err != nil {
// No Goose directives, don't write anything. return err
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg) }
return out, err := goose.Generate(&build.Default, wd, pkg)
} if err != nil {
p := filepath.Join(pkgInfo.Dir, "goose_gen.go") return err
if err := ioutil.WriteFile(p, out, 0666); err != nil { }
fmt.Fprintln(os.Stderr, "goose:", err) if len(out) == 0 {
os.Exit(1) // No Goose directives, don't write anything.
} fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
return nil
}
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
if err := ioutil.WriteFile(p, out, 0666); err != nil {
return err
}
return nil
}
// show runs the show subcommand.
//
// Given one or more packages, show will find all the declared provider
// sets and print what other provider sets it imports and what outputs
// it can produce, given possible inputs.
func show(pkgs ...string) error {
wd, err := os.Getwd()
if err != nil {
return err
}
info, err := goose.Load(&build.Default, wd, pkgs)
if err != nil {
return err
}
keys := make([]goose.ProviderSetID, 0, len(info.Sets))
for k := range info.Sets {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].ImportPath == keys[j].ImportPath {
return keys[i].Name < keys[j].Name
}
return keys[i].ImportPath < keys[j].ImportPath
})
// ANSI color codes.
const (
reset = "\x1b[0m"
redBold = "\x1b[0;1;31m"
blue = "\x1b[0;34m"
green = "\x1b[0;32m"
)
for i, k := range keys {
if i > 0 {
fmt.Println()
}
outGroups, imports := gather(info, k)
fmt.Printf("%s%s%s\n", redBold, k, reset)
for _, imp := range sortSet(imports) {
fmt.Printf("\t%s\n", imp)
}
for i := range outGroups {
fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset)
out := make(map[string]token.Pos, outGroups[i].outputs.Len())
outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) {
switch v := v.(type) {
case *goose.Provider:
out[types.TypeString(t, nil)] = v.Pos
case goose.IfaceBinding:
out[types.TypeString(t, nil)] = v.Pos
default:
panic("unreachable")
}
})
for _, t := range sortSet(out) {
fmt.Printf("\t%s%s%s\n", green, t, reset)
fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t]))
}
}
}
return nil
}
type outGroup struct {
name string
inputs *typeutil.Map // values are not important
outputs *typeutil.Map // values are either *goose.Provider or goose.IfaceBinding
}
// gather flattens a provider set into outputs grouped by the inputs
// required to create them. As it flattens the provider set, it records
// the visited provider sets as imports.
func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
hash := typeutil.MakeHasher()
// Map types to providers and bindings.
pm := new(typeutil.Map)
pm.SetHasher(hash)
next := []goose.ProviderSetID{key}
visited := make(map[goose.ProviderSetID]struct{})
imports = make(map[string]struct{})
for len(next) > 0 {
curr := next[len(next)-1]
next = next[:len(next)-1]
if _, found := visited[curr]; found {
continue
}
visited[curr] = struct{}{}
if curr != key {
imports[curr.String()] = struct{}{}
}
set := info.All[curr]
for _, p := range set.Providers {
pm.Set(p.Out, p)
}
for _, b := range set.Bindings {
pm.Set(b.Iface, b)
}
for _, imp := range set.Imports {
next = append(next, imp.ProviderSetID)
}
}
// Depth-first search to build groups.
var groups []outGroup
inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input.
inputVisited.SetHasher(hash)
pmKeys := pm.Keys()
var stk []types.Type
for _, k := range pmKeys {
// Start a DFS by picking a random unvisited node.
if inputVisited.At(k) == nil {
stk = append(stk, k)
}
// Run DFS
dfs:
for len(stk) > 0 {
curr := stk[len(stk)-1]
stk = stk[:len(stk)-1]
if inputVisited.At(curr) != nil {
continue
}
switch p := pm.At(curr).(type) {
case nil:
// This is an input.
inputVisited.Set(curr, -1)
case *goose.Provider:
// Try to see if any args haven't been visited.
allPresent := true
for _, arg := range p.Args {
if arg.Optional {
continue
}
if inputVisited.At(arg.Type) == nil {
allPresent = false
}
}
if !allPresent {
stk = append(stk, curr)
for _, arg := range p.Args {
if arg.Optional {
continue
}
if inputVisited.At(arg.Type) == nil {
stk = append(stk, arg.Type)
}
}
continue dfs
}
// Build up set of input types, match to a group.
in := new(typeutil.Map)
in.SetHasher(hash)
for _, arg := range p.Args {
if arg.Optional {
continue
}
i := inputVisited.At(arg.Type).(int)
if i == -1 {
in.Set(arg.Type, true)
} else {
mergeTypeSets(in, groups[i].inputs)
}
}
for i := range groups {
if sameTypeKeys(groups[i].inputs, in) {
groups[i].outputs.Set(p.Out, p)
inputVisited.Set(p.Out, i)
continue dfs
}
}
out := new(typeutil.Map)
out.SetHasher(hash)
out.Set(p.Out, p)
inputVisited.Set(p.Out, len(groups))
groups = append(groups, outGroup{
inputs: in,
outputs: out,
})
case goose.IfaceBinding:
i, ok := inputVisited.At(p.Provided).(int)
if !ok {
stk = append(stk, curr, p.Provided)
continue dfs
}
if i != -1 {
groups[i].outputs.Set(p.Iface, p)
inputVisited.Set(p.Iface, i)
continue dfs
}
// Binding must be provided. Find or add a group.
for i := range groups {
if groups[i].inputs.Len() != 1 {
continue
}
if groups[i].inputs.At(p.Provided) != nil {
groups[i].outputs.Set(p.Iface, p)
inputVisited.Set(p.Iface, i)
continue dfs
}
}
in := new(typeutil.Map)
in.SetHasher(hash)
in.Set(p.Provided, true)
out := new(typeutil.Map)
out.SetHasher(hash)
out.Set(p.Iface, p)
groups = append(groups, outGroup{
inputs: in,
outputs: out,
})
default:
panic("unreachable")
}
}
}
// Name and sort groups
for i := range groups {
if groups[i].inputs.Len() == 0 {
groups[i].name = "no inputs"
continue
}
instr := make([]string, 0, groups[i].inputs.Len())
groups[i].inputs.Iterate(func(k types.Type, _ interface{}) {
instr = append(instr, types.TypeString(k, nil))
})
sort.Strings(instr)
groups[i].name = strings.Join(instr, ", ")
}
sort.Slice(groups, func(i, j int) bool {
if groups[i].inputs.Len() == groups[j].inputs.Len() {
return groups[i].name < groups[j].name
}
return groups[i].inputs.Len() < groups[j].inputs.Len()
})
return groups, imports
}
func mergeTypeSets(dst, src *typeutil.Map) {
src.Iterate(func(k types.Type, _ interface{}) {
dst.Set(k, true)
})
}
func sameTypeKeys(a, b *typeutil.Map) bool {
if a.Len() != b.Len() {
return false
}
same := true
a.Iterate(func(k types.Type, _ interface{}) {
if b.At(k) == nil {
same = false
}
})
return same
}
func sortSet(set interface{}) []string {
rv := reflect.ValueOf(set)
a := make([]string, 0, rv.Len())
keys := rv.MapKeys()
for _, k := range keys {
a = append(a, k.String())
}
sort.Strings(a)
return a
} }