goose: add show command

Lists provider sets in packages given on the command line, including
outputs grouped by what is needed to obtain them.

The goose package now exports the loading phase as an API.

Example output: https://paste.googleplex.com/5509965720584192

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-04 14:42:56 -07:00
parent 2044e2213b
commit cfc6111ea5
4 changed files with 581 additions and 189 deletions

View File

@@ -60,8 +60,8 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
index := new(typeutil.Map)
for i, g := range given {
if p := providers.At(g); p != nil {
pp := p.(*providerInfo)
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.name, mc.fset.Position(pp.pos))
pp := p.(*Provider)
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, mc.fset.Position(pp.Pos))
}
index.Set(g, i)
}
@@ -70,49 +70,49 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
// using a depth-first search. The graph may contain cycles, which
// should trigger an error.
var calls []call
var visit func(trail []providerInput) error
visit = func(trail []providerInput) error {
typ := trail[len(trail)-1].typ
var visit func(trail []ProviderInput) error
visit = func(trail []ProviderInput) error {
typ := trail[len(trail)-1].Type
if index.At(typ) != nil {
return nil
}
for _, in := range trail[:len(trail)-1] {
if types.Identical(typ, in.typ) {
if types.Identical(typ, in.Type) {
// TODO(light): describe cycle
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
}
}
p, _ := providers.At(typ).(*providerInfo)
p, _ := providers.At(typ).(*Provider)
if p == nil {
if trail[len(trail)-1].optional {
if trail[len(trail)-1].Optional {
return nil
}
if len(trail) == 1 {
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
}
// TODO(light): give name of provider
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil))
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil))
}
if !types.Identical(p.out, typ) {
if !types.Identical(p.Out, typ) {
// Interface binding. Don't create a call ourselves.
if err := visit(append(trail, providerInput{typ: p.out})); err != nil {
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
return err
}
index.Set(typ, index.At(p.out))
index.Set(typ, index.At(p.Out))
return nil
}
for _, a := range p.args {
for _, a := range p.Args {
// TODO(light): this will discard grown trail arrays.
if err := visit(append(trail, a)); err != nil {
return err
}
}
args := make([]int, len(p.args))
ins := make([]types.Type, len(p.args))
for i := range p.args {
ins[i] = p.args[i].typ
if x := index.At(p.args[i].typ); x != nil {
args := make([]int, len(p.Args))
ins := make([]types.Type, len(p.Args))
for i := range p.Args {
ins[i] = p.Args[i].Type
if x := index.At(p.Args[i].Type); x != nil {
args[i] = x.(int)
} else {
args[i] = -1
@@ -120,19 +120,19 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
}
index.Set(typ, len(given)+len(calls))
calls = append(calls, call{
importPath: p.importPath,
name: p.name,
importPath: p.ImportPath,
name: p.Name,
args: args,
isStruct: p.isStruct,
fieldNames: p.fields,
isStruct: p.IsStruct,
fieldNames: p.Fields,
ins: ins,
out: typ,
hasCleanup: p.hasCleanup,
hasErr: p.hasErr,
hasCleanup: p.HasCleanup,
hasErr: p.HasErr,
})
return nil
}
if err := visit([]providerInput{{typ: out}}); err != nil {
if err := visit([]ProviderInput{{Type: out}}); err != nil {
return nil, err
}
return calls, nil
@@ -146,7 +146,7 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error
pos token.Pos
}
type binding struct {
ifaceBinding
IfaceBinding
pset symref
from symref
}
@@ -173,53 +173,53 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error
}
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
}
for _, p := range pset.providers {
if prev := pm.At(p.out); prev != nil {
pos := mc.fset.Position(p.pos)
typ := types.TypeString(p.out, nil)
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
for _, p := range pset.Providers {
if prev := pm.At(p.Out); prev != nil {
pos := mc.fset.Position(p.Pos)
typ := types.TypeString(p.Out, nil)
prevPos := mc.fset.Position(prev.(*Provider).Pos)
if curr.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
}
pm.Set(p.out, p)
pm.Set(p.Out, p)
}
for _, b := range pset.bindings {
for _, b := range pset.Bindings {
bindings = append(bindings, binding{
ifaceBinding: b,
IfaceBinding: b,
pset: curr.to,
from: curr.from,
})
}
for _, imp := range pset.imports {
next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos})
for _, imp := range pset.Imports {
next = append(next, nextEnt{to: imp.symref(), from: curr.to, pos: imp.Pos})
}
}
for _, b := range bindings {
if prev := pm.At(b.iface); prev != nil {
pos := mc.fset.Position(b.pos)
typ := types.TypeString(b.iface, nil)
if prev := pm.At(b.Iface); prev != nil {
pos := mc.fset.Position(b.Pos)
typ := types.TypeString(b.Iface, nil)
// TODO(light): error message for conflicting with another interface binding will point at provider instead of binding.
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
prevPos := mc.fset.Position(prev.(*Provider).Pos)
if b.from.importPath == "" {
// Provider set is imported directly by injector.
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
}
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos)
}
concrete := pm.At(b.provided)
concrete := pm.At(b.Provided)
if concrete == nil {
pos := mc.fset.Position(b.pos)
typ := types.TypeString(b.provided, nil)
pos := mc.fset.Position(b.Pos)
typ := types.TypeString(b.Provided, nil)
if b.from.importPath == "" {
// Concrete provider is imported directly by injector.
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
}
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
}
pm.Set(b.iface, concrete)
pm.Set(b.Iface, concrete)
}
return pm, nil
}

View File

@@ -22,17 +22,7 @@ import (
// Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
// TODO(light): allow errors
// TODO(light): stop errors from printing to stderr
conf := &loader.Config{
Build: new(build.Context),
ParserMode: parser.ParseComments,
Cwd: wd,
TypeCheckFuncBodies: func(string) bool { return false },
}
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
conf := newLoaderConfig(bctx, wd, true)
conf.Import(pkg)
prog, err := conf.Load()
if err != nil {
@@ -99,6 +89,24 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
return fmtSrc, nil
}
func newLoaderConfig(bctx *build.Context, wd string, inject bool) *loader.Config {
// TODO(light): allow errors
// TODO(light): stop errors from printing to stderr
conf := &loader.Config{
Build: bctx,
ParserMode: parser.ParseComments,
Cwd: wd,
TypeCheckFuncBodies: func(string) bool { return false },
}
if inject {
conf.Build = new(build.Context)
*conf.Build = *bctx
n := len(conf.Build.BuildTags)
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
}
return conf
}
// gen is the generator state.
type gen struct {
currPackage string

View File

@@ -14,72 +14,158 @@ import (
"golang.org/x/tools/go/loader"
)
// A providerSet describes a set of providers. The zero value is an empty
// providerSet.
type providerSet struct {
providers []*providerInfo
bindings []ifaceBinding
imports []providerSetImport
// A ProviderSet describes a set of providers. The zero value is an empty
// ProviderSet.
type ProviderSet struct {
Providers []*Provider
Bindings []IfaceBinding
Imports []ProviderSetImport
}
// An ifaceBinding declares that a type should be used to satisfy inputs
// An IfaceBinding declares that a type should be used to satisfy inputs
// of the given interface type.
//
// provided is always a type that is assignable to iface.
type ifaceBinding struct {
// iface is the interface type, which is what can be injected.
iface types.Type
type IfaceBinding struct {
// Iface is the interface type, which is what can be injected.
Iface types.Type
// provided is always a type that is assignable to Iface.
provided types.Type
// Provided is always a type that is assignable to Iface.
Provided types.Type
// pos is the position where the binding was declared.
pos token.Pos
// Pos is the position where the binding was declared.
Pos token.Pos
}
type providerSetImport struct {
symref
pos token.Pos
// A ProviderSetImport adds providers from one provider set into another.
type ProviderSetImport struct {
ProviderSetID
Pos token.Pos
}
// providerInfo records the signature of a provider.
type providerInfo struct {
// importPath is the package path that the Go object resides in.
importPath string
// Provider records the signature of a provider. A provider is a
// single Go object, either a function or a named struct type.
type Provider struct {
// ImportPath is the package path that the Go object resides in.
ImportPath string
// name is the name of the Go object.
name string
// Name is the name of the Go object.
Name string
// pos is the source position of the func keyword or type spec
// Pos is the source position of the func keyword or type spec
// defining this provider.
pos token.Pos
Pos token.Pos
// args is the list of data dependencies this provider has.
args []providerInput
// Args is the list of data dependencies this provider has.
Args []ProviderInput
// isStruct is true if this provider is a named struct type.
// IsStruct is true if this provider is a named struct type.
// Otherwise it's a function.
isStruct bool
IsStruct bool
// fields lists the field names to populate. This will map 1:1 with
// Fields lists the field names to populate. This will map 1:1 with
// elements in Args.
fields []string
Fields []string
// out is the type this provider produces.
out types.Type
// Out is the type this provider produces.
Out types.Type
// hasCleanup reports whether the provider function returns a cleanup
// HasCleanup reports whether the provider function returns a cleanup
// function. (Always false for structs.)
hasCleanup bool
HasCleanup bool
// hasErr reports whether the provider function can return an error.
// HasErr reports whether the provider function can return an error.
// (Always false for structs.)
hasErr bool
HasErr bool
}
type providerInput struct {
typ types.Type
optional bool
// ProviderInput describes an incoming edge in the provider graph.
type ProviderInput struct {
Type types.Type
Optional bool
}
// Load finds all the provider sets in the given packages, as well as
// the provider sets' transitive dependencies.
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
conf := newLoaderConfig(bctx, wd, false)
for _, p := range pkgs {
conf.Import(p)
}
prog, err := conf.Load()
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
r := newImportResolver(conf, prog.Fset)
var next []string
initial := make(map[string]struct{})
for _, pkgInfo := range prog.InitialPackages() {
path := pkgInfo.Pkg.Path()
next = append(next, path)
initial[path] = struct{}{}
}
visited := make(map[string]struct{})
info := &Info{
Fset: prog.Fset,
Sets: make(map[ProviderSetID]*ProviderSet),
All: make(map[ProviderSetID]*ProviderSet),
}
for len(next) > 0 {
curr := next[len(next)-1]
next = next[:len(next)-1]
if _, ok := visited[curr]; ok {
continue
}
visited[curr] = struct{}{}
pkgInfo := prog.Package(curr)
sets, err := findProviderSets(findContext{
fset: prog.Fset,
pkg: pkgInfo.Pkg,
typeInfo: &pkgInfo.Info,
r: r,
}, pkgInfo.Files)
if err != nil {
return nil, fmt.Errorf("load: %v", err)
}
path := pkgInfo.Pkg.Path()
for name, set := range sets {
info.All[ProviderSetID{path, name}] = set
for _, imp := range set.Imports {
next = append(next, imp.ImportPath)
}
}
if _, ok := initial[path]; ok {
for name, set := range sets {
info.Sets[ProviderSetID{path, name}] = set
}
}
}
return info, nil
}
// Info holds the result of Load.
type Info struct {
Fset *token.FileSet
// Sets contains all the provider sets in the initial packages.
Sets map[ProviderSetID]*ProviderSet
// All contains all the provider sets transitively depended on by the
// initial packages' provider sets.
All map[ProviderSetID]*ProviderSet
}
// A ProviderSetID identifies a provider set.
type ProviderSetID struct {
ImportPath string
Name string
}
// String returns the ID as ""path/to/pkg".Foo".
func (id ProviderSetID) String() string {
return id.symref().String()
}
func (id ProviderSetID) symref() symref {
return symref{importPath: id.ImportPath, name: id.Name}
}
type findContext struct {
@@ -90,8 +176,8 @@ type findContext struct {
}
// findProviderSets processes a package and extracts the provider sets declared in it.
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) {
sets := make(map[string]*providerSet)
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*ProviderSet, error) {
sets := make(map[string]*ProviderSet)
for _, f := range files {
fileScope := fctx.typeInfo.Scopes[f]
if fileScope == nil {
@@ -115,7 +201,7 @@ func findProviderSets(fctx findContext, files []*ast.File) (map[string]*provider
}
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error {
switch d.kind {
case "provide", "optional":
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
@@ -169,15 +255,15 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
name := args[0]
if pset := sets[name]; pset != nil {
pset.bindings = append(pset.bindings, ifaceBinding{
iface: iface,
provided: provided,
pset.Bindings = append(pset.Bindings, IfaceBinding{
Iface: iface,
Provided: provided,
})
} else {
sets[name] = &providerSet{
bindings: []ifaceBinding{{
iface: iface,
provided: provided,
sets[name] = &ProviderSet{
Bindings: []IfaceBinding{{
Iface: iface,
Provided: provided,
}},
}
}
@@ -197,18 +283,30 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
}
if mod := sets[name]; mod != nil {
found := false
for _, other := range mod.imports {
if ref == other.symref {
for _, other := range mod.Imports {
if ref == other.symref() {
found = true
break
}
}
if !found {
mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos})
mod.Imports = append(mod.Imports, ProviderSetImport{
ProviderSetID: ProviderSetID{
ImportPath: ref.importPath,
Name: ref.name,
},
Pos: d.pos,
})
}
} else {
sets[name] = &providerSet{
imports: []providerSetImport{{symref: ref, pos: d.pos}},
sets[name] = &ProviderSet{
Imports: []ProviderSetImport{{
ProviderSetID: ProviderSetID{
ImportPath: ref.importPath,
Name: ref.name,
},
Pos: d.pos,
}},
}
}
}
@@ -219,7 +317,7 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
}
// processDeclDirectives processes the directives associated with a top-level declaration.
func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error {
func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, dg directiveGroup) error {
p, err := dg.single(fctx.fset, "provide")
if err != nil {
return err
@@ -258,15 +356,15 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
providerSetName = fn.Name()
}
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
for _, other := range mod.Providers {
if types.Identical(other.Out, provider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
}
}
mod.providers = append(mod.providers, provider)
mod.Providers = append(mod.Providers, provider)
} else {
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider},
sets[providerSetName] = &ProviderSet{
Providers: []*Provider{provider},
}
}
case *ast.GenDecl:
@@ -288,22 +386,22 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
if providerSetName == "" {
providerSetName = typeName.Name()
}
ptrProvider := new(providerInfo)
ptrProvider := new(Provider)
*ptrProvider = *provider
ptrProvider.out = types.NewPointer(provider.out)
ptrProvider.Out = types.NewPointer(provider.Out)
if mod := sets[providerSetName]; mod != nil {
for _, other := range mod.providers {
if types.Identical(other.out, provider.out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
for _, other := range mod.Providers {
if types.Identical(other.Out, provider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
}
if types.Identical(other.out, ptrProvider.out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.out, nil), fctx.fset.Position(other.pos))
if types.Identical(other.Out, ptrProvider.Out) {
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.Out, nil), fctx.fset.Position(other.Pos))
}
}
mod.providers = append(mod.providers, provider, ptrProvider)
mod.Providers = append(mod.Providers, provider, ptrProvider)
} else {
sets[providerSetName] = &providerSet{
providers: []*providerInfo{provider, ptrProvider},
sets[providerSetName] = &ProviderSet{
Providers: []*Provider{provider, ptrProvider},
}
}
default:
@@ -312,7 +410,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
return nil
}
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*providerInfo, error) {
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*Provider, error) {
sig := fn.Type().(*types.Signature)
optionals := make([]bool, sig.Params().Len())
@@ -352,30 +450,30 @@ func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[stri
}
out := r.At(0).Type()
params := sig.Params()
provider := &providerInfo{
importPath: fctx.pkg.Path(),
name: fn.Name(),
pos: fn.Pos(),
args: make([]providerInput, params.Len()),
out: out,
hasCleanup: hasCleanup,
hasErr: hasErr,
provider := &Provider{
ImportPath: fctx.pkg.Path(),
Name: fn.Name(),
Pos: fn.Pos(),
Args: make([]ProviderInput, params.Len()),
Out: out,
HasCleanup: hasCleanup,
HasErr: hasErr,
}
for i := 0; i < params.Len(); i++ {
provider.args[i] = providerInput{
typ: params.At(i).Type(),
optional: optionals[i],
provider.Args[i] = ProviderInput{
Type: params.At(i).Type(),
Optional: optionals[i],
}
for j := 0; j < i; j++ {
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil))
}
}
}
return provider, nil
}
func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*providerInfo, error) {
func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*Provider, error) {
out := typeName.Type()
st := out.Underlying().(*types.Struct)
for arg, dpos := range optionals {
@@ -392,26 +490,26 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals
}
pos := typeName.Pos()
provider := &providerInfo{
importPath: fctx.pkg.Path(),
name: typeName.Name(),
pos: pos,
args: make([]providerInput, st.NumFields()),
fields: make([]string, st.NumFields()),
isStruct: true,
out: out,
provider := &Provider{
ImportPath: fctx.pkg.Path(),
Name: typeName.Name(),
Pos: pos,
Args: make([]ProviderInput, st.NumFields()),
Fields: make([]string, st.NumFields()),
IsStruct: true,
Out: out,
}
for i := 0; i < st.NumFields(); i++ {
f := st.Field(i)
_, optional := optionals[f.Name()]
provider.args[i] = providerInput{
typ: f.Type(),
optional: optional,
provider.Args[i] = ProviderInput{
Type: f.Type(),
Optional: optional,
}
provider.fields[i] = f.Name()
provider.Fields[i] = f.Name()
for j := 0; j < i; j++ {
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.args[j].typ, nil))
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.Args[j].Type, nil))
}
}
}
@@ -420,7 +518,7 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals
// providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct {
sets map[string]map[string]*providerSet
sets map[string]map[string]*ProviderSet
fset *token.FileSet
prog *loader.Program
r *importResolver
@@ -434,7 +532,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa
}
}
func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
func (mc *providerSetCache) get(ref symref) (*ProviderSet, error) {
if mods, cached := mc.sets[ref.importPath]; cached {
mod := mods[ref.name]
if mod == nil {
@@ -443,7 +541,7 @@ func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
return mod, nil
}
if mc.sets == nil {
mc.sets = make(map[string]map[string]*providerSet)
mc.sets = make(map[string]map[string]*ProviderSet)
}
pkg := mc.prog.Package(ref.importPath)
mods, err := findProviderSets(findContext{

342
main.go
View File

@@ -6,47 +6,333 @@ package main
import (
"fmt"
"go/build"
"go/token"
"go/types"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"codename/goose/internal/goose"
"golang.org/x/tools/go/types/typeutil"
)
func main() {
var pkg string
switch len(os.Args) {
case 1:
pkg = "."
case 2:
pkg = os.Args[1]
var err error
switch {
case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen":
err = generate(".")
case len(os.Args) == 2 && os.Args[1] == "show":
err = show(".")
case len(os.Args) == 2:
err = generate(os.Args[1])
case len(os.Args) > 2 && os.Args[1] == "show":
err = show(os.Args[2:]...)
case len(os.Args) == 3 && os.Args[1] == "gen":
err = generate(os.Args[2])
default:
fmt.Fprintln(os.Stderr, "goose: usage: goose [PKG]")
fmt.Fprintln(os.Stderr, "goose: usage: goose [gen] [PKG] | goose show [...]")
os.Exit(64)
}
wd, err := os.Getwd()
if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
out, err := goose.Generate(&build.Default, wd, pkg)
if err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
if len(out) == 0 {
// No Goose directives, don't write anything.
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
return
}
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
if err := ioutil.WriteFile(p, out, 0666); err != nil {
fmt.Fprintln(os.Stderr, "goose:", err)
os.Exit(1)
}
}
// generate runs the gen subcommand. Given a package, gen will create
// the goose_gen.go file.
func generate(pkg string) error {
wd, err := os.Getwd()
if err != nil {
return err
}
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
if err != nil {
return err
}
out, err := goose.Generate(&build.Default, wd, pkg)
if err != nil {
return err
}
if len(out) == 0 {
// No Goose directives, don't write anything.
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
return nil
}
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
if err := ioutil.WriteFile(p, out, 0666); err != nil {
return err
}
return nil
}
// show runs the show subcommand.
//
// Given one or more packages, show will find all the declared provider
// sets and print what other provider sets it imports and what outputs
// it can produce, given possible inputs.
func show(pkgs ...string) error {
wd, err := os.Getwd()
if err != nil {
return err
}
info, err := goose.Load(&build.Default, wd, pkgs)
if err != nil {
return err
}
keys := make([]goose.ProviderSetID, 0, len(info.Sets))
for k := range info.Sets {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].ImportPath == keys[j].ImportPath {
return keys[i].Name < keys[j].Name
}
return keys[i].ImportPath < keys[j].ImportPath
})
// ANSI color codes.
const (
reset = "\x1b[0m"
redBold = "\x1b[0;1;31m"
blue = "\x1b[0;34m"
green = "\x1b[0;32m"
)
for i, k := range keys {
if i > 0 {
fmt.Println()
}
outGroups, imports := gather(info, k)
fmt.Printf("%s%s%s\n", redBold, k, reset)
for _, imp := range sortSet(imports) {
fmt.Printf("\t%s\n", imp)
}
for i := range outGroups {
fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset)
out := make(map[string]token.Pos, outGroups[i].outputs.Len())
outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) {
switch v := v.(type) {
case *goose.Provider:
out[types.TypeString(t, nil)] = v.Pos
case goose.IfaceBinding:
out[types.TypeString(t, nil)] = v.Pos
default:
panic("unreachable")
}
})
for _, t := range sortSet(out) {
fmt.Printf("\t%s%s%s\n", green, t, reset)
fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t]))
}
}
}
return nil
}
type outGroup struct {
name string
inputs *typeutil.Map // values are not important
outputs *typeutil.Map // values are either *goose.Provider or goose.IfaceBinding
}
// gather flattens a provider set into outputs grouped by the inputs
// required to create them. As it flattens the provider set, it records
// the visited provider sets as imports.
func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
hash := typeutil.MakeHasher()
// Map types to providers and bindings.
pm := new(typeutil.Map)
pm.SetHasher(hash)
next := []goose.ProviderSetID{key}
visited := make(map[goose.ProviderSetID]struct{})
imports = make(map[string]struct{})
for len(next) > 0 {
curr := next[len(next)-1]
next = next[:len(next)-1]
if _, found := visited[curr]; found {
continue
}
visited[curr] = struct{}{}
if curr != key {
imports[curr.String()] = struct{}{}
}
set := info.All[curr]
for _, p := range set.Providers {
pm.Set(p.Out, p)
}
for _, b := range set.Bindings {
pm.Set(b.Iface, b)
}
for _, imp := range set.Imports {
next = append(next, imp.ProviderSetID)
}
}
// Depth-first search to build groups.
var groups []outGroup
inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input.
inputVisited.SetHasher(hash)
pmKeys := pm.Keys()
var stk []types.Type
for _, k := range pmKeys {
// Start a DFS by picking a random unvisited node.
if inputVisited.At(k) == nil {
stk = append(stk, k)
}
// Run DFS
dfs:
for len(stk) > 0 {
curr := stk[len(stk)-1]
stk = stk[:len(stk)-1]
if inputVisited.At(curr) != nil {
continue
}
switch p := pm.At(curr).(type) {
case nil:
// This is an input.
inputVisited.Set(curr, -1)
case *goose.Provider:
// Try to see if any args haven't been visited.
allPresent := true
for _, arg := range p.Args {
if arg.Optional {
continue
}
if inputVisited.At(arg.Type) == nil {
allPresent = false
}
}
if !allPresent {
stk = append(stk, curr)
for _, arg := range p.Args {
if arg.Optional {
continue
}
if inputVisited.At(arg.Type) == nil {
stk = append(stk, arg.Type)
}
}
continue dfs
}
// Build up set of input types, match to a group.
in := new(typeutil.Map)
in.SetHasher(hash)
for _, arg := range p.Args {
if arg.Optional {
continue
}
i := inputVisited.At(arg.Type).(int)
if i == -1 {
in.Set(arg.Type, true)
} else {
mergeTypeSets(in, groups[i].inputs)
}
}
for i := range groups {
if sameTypeKeys(groups[i].inputs, in) {
groups[i].outputs.Set(p.Out, p)
inputVisited.Set(p.Out, i)
continue dfs
}
}
out := new(typeutil.Map)
out.SetHasher(hash)
out.Set(p.Out, p)
inputVisited.Set(p.Out, len(groups))
groups = append(groups, outGroup{
inputs: in,
outputs: out,
})
case goose.IfaceBinding:
i, ok := inputVisited.At(p.Provided).(int)
if !ok {
stk = append(stk, curr, p.Provided)
continue dfs
}
if i != -1 {
groups[i].outputs.Set(p.Iface, p)
inputVisited.Set(p.Iface, i)
continue dfs
}
// Binding must be provided. Find or add a group.
for i := range groups {
if groups[i].inputs.Len() != 1 {
continue
}
if groups[i].inputs.At(p.Provided) != nil {
groups[i].outputs.Set(p.Iface, p)
inputVisited.Set(p.Iface, i)
continue dfs
}
}
in := new(typeutil.Map)
in.SetHasher(hash)
in.Set(p.Provided, true)
out := new(typeutil.Map)
out.SetHasher(hash)
out.Set(p.Iface, p)
groups = append(groups, outGroup{
inputs: in,
outputs: out,
})
default:
panic("unreachable")
}
}
}
// Name and sort groups
for i := range groups {
if groups[i].inputs.Len() == 0 {
groups[i].name = "no inputs"
continue
}
instr := make([]string, 0, groups[i].inputs.Len())
groups[i].inputs.Iterate(func(k types.Type, _ interface{}) {
instr = append(instr, types.TypeString(k, nil))
})
sort.Strings(instr)
groups[i].name = strings.Join(instr, ", ")
}
sort.Slice(groups, func(i, j int) bool {
if groups[i].inputs.Len() == groups[j].inputs.Len() {
return groups[i].name < groups[j].name
}
return groups[i].inputs.Len() < groups[j].inputs.Len()
})
return groups, imports
}
func mergeTypeSets(dst, src *typeutil.Map) {
src.Iterate(func(k types.Type, _ interface{}) {
dst.Set(k, true)
})
}
func sameTypeKeys(a, b *typeutil.Map) bool {
if a.Len() != b.Len() {
return false
}
same := true
a.Iterate(func(k types.Type, _ interface{}) {
if b.At(k) == nil {
same = false
}
})
return same
}
func sortSet(set interface{}) []string {
rv := reflect.ValueOf(set)
a := make([]string, 0, rv.Len())
keys := rv.MapKeys()
for _, k := range keys {
a = append(a, k.String())
}
sort.Strings(a)
return a
}