goose: add show command
Lists provider sets in packages given on the command line, including outputs grouped by what is needed to obtain them. The goose package now exports the loading phase as an API. Example output: https://paste.googleplex.com/5509965720584192 Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
@@ -60,8 +60,8 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
|
||||
index := new(typeutil.Map)
|
||||
for i, g := range given {
|
||||
if p := providers.At(g); p != nil {
|
||||
pp := p.(*providerInfo)
|
||||
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.name, mc.fset.Position(pp.pos))
|
||||
pp := p.(*Provider)
|
||||
return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, mc.fset.Position(pp.Pos))
|
||||
}
|
||||
index.Set(g, i)
|
||||
}
|
||||
@@ -70,49 +70,49 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
|
||||
// using a depth-first search. The graph may contain cycles, which
|
||||
// should trigger an error.
|
||||
var calls []call
|
||||
var visit func(trail []providerInput) error
|
||||
visit = func(trail []providerInput) error {
|
||||
typ := trail[len(trail)-1].typ
|
||||
var visit func(trail []ProviderInput) error
|
||||
visit = func(trail []ProviderInput) error {
|
||||
typ := trail[len(trail)-1].Type
|
||||
if index.At(typ) != nil {
|
||||
return nil
|
||||
}
|
||||
for _, in := range trail[:len(trail)-1] {
|
||||
if types.Identical(typ, in.typ) {
|
||||
if types.Identical(typ, in.Type) {
|
||||
// TODO(light): describe cycle
|
||||
return fmt.Errorf("cycle for %s", types.TypeString(typ, nil))
|
||||
}
|
||||
}
|
||||
|
||||
p, _ := providers.At(typ).(*providerInfo)
|
||||
p, _ := providers.At(typ).(*Provider)
|
||||
if p == nil {
|
||||
if trail[len(trail)-1].optional {
|
||||
if trail[len(trail)-1].Optional {
|
||||
return nil
|
||||
}
|
||||
if len(trail) == 1 {
|
||||
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
|
||||
}
|
||||
// TODO(light): give name of provider
|
||||
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil))
|
||||
return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil))
|
||||
}
|
||||
if !types.Identical(p.out, typ) {
|
||||
if !types.Identical(p.Out, typ) {
|
||||
// Interface binding. Don't create a call ourselves.
|
||||
if err := visit(append(trail, providerInput{typ: p.out})); err != nil {
|
||||
if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil {
|
||||
return err
|
||||
}
|
||||
index.Set(typ, index.At(p.out))
|
||||
index.Set(typ, index.At(p.Out))
|
||||
return nil
|
||||
}
|
||||
for _, a := range p.args {
|
||||
for _, a := range p.Args {
|
||||
// TODO(light): this will discard grown trail arrays.
|
||||
if err := visit(append(trail, a)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
args := make([]int, len(p.args))
|
||||
ins := make([]types.Type, len(p.args))
|
||||
for i := range p.args {
|
||||
ins[i] = p.args[i].typ
|
||||
if x := index.At(p.args[i].typ); x != nil {
|
||||
args := make([]int, len(p.Args))
|
||||
ins := make([]types.Type, len(p.Args))
|
||||
for i := range p.Args {
|
||||
ins[i] = p.Args[i].Type
|
||||
if x := index.At(p.Args[i].Type); x != nil {
|
||||
args[i] = x.(int)
|
||||
} else {
|
||||
args[i] = -1
|
||||
@@ -120,19 +120,19 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
|
||||
}
|
||||
index.Set(typ, len(given)+len(calls))
|
||||
calls = append(calls, call{
|
||||
importPath: p.importPath,
|
||||
name: p.name,
|
||||
importPath: p.ImportPath,
|
||||
name: p.Name,
|
||||
args: args,
|
||||
isStruct: p.isStruct,
|
||||
fieldNames: p.fields,
|
||||
isStruct: p.IsStruct,
|
||||
fieldNames: p.Fields,
|
||||
ins: ins,
|
||||
out: typ,
|
||||
hasCleanup: p.hasCleanup,
|
||||
hasErr: p.hasErr,
|
||||
hasCleanup: p.HasCleanup,
|
||||
hasErr: p.HasErr,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
if err := visit([]providerInput{{typ: out}}); err != nil {
|
||||
if err := visit([]ProviderInput{{Type: out}}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return calls, nil
|
||||
@@ -146,7 +146,7 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error
|
||||
pos token.Pos
|
||||
}
|
||||
type binding struct {
|
||||
ifaceBinding
|
||||
IfaceBinding
|
||||
pset symref
|
||||
from symref
|
||||
}
|
||||
@@ -173,53 +173,53 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error
|
||||
}
|
||||
return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err)
|
||||
}
|
||||
for _, p := range pset.providers {
|
||||
if prev := pm.At(p.out); prev != nil {
|
||||
pos := mc.fset.Position(p.pos)
|
||||
typ := types.TypeString(p.out, nil)
|
||||
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
||||
for _, p := range pset.Providers {
|
||||
if prev := pm.At(p.Out); prev != nil {
|
||||
pos := mc.fset.Position(p.Pos)
|
||||
typ := types.TypeString(p.Out, nil)
|
||||
prevPos := mc.fset.Position(prev.(*Provider).Pos)
|
||||
if curr.from.importPath == "" {
|
||||
// Provider set is imported directly by injector.
|
||||
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
||||
}
|
||||
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos)
|
||||
}
|
||||
pm.Set(p.out, p)
|
||||
pm.Set(p.Out, p)
|
||||
}
|
||||
for _, b := range pset.bindings {
|
||||
for _, b := range pset.Bindings {
|
||||
bindings = append(bindings, binding{
|
||||
ifaceBinding: b,
|
||||
IfaceBinding: b,
|
||||
pset: curr.to,
|
||||
from: curr.from,
|
||||
})
|
||||
}
|
||||
for _, imp := range pset.imports {
|
||||
next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos})
|
||||
for _, imp := range pset.Imports {
|
||||
next = append(next, nextEnt{to: imp.symref(), from: curr.to, pos: imp.Pos})
|
||||
}
|
||||
}
|
||||
for _, b := range bindings {
|
||||
if prev := pm.At(b.iface); prev != nil {
|
||||
pos := mc.fset.Position(b.pos)
|
||||
typ := types.TypeString(b.iface, nil)
|
||||
if prev := pm.At(b.Iface); prev != nil {
|
||||
pos := mc.fset.Position(b.Pos)
|
||||
typ := types.TypeString(b.Iface, nil)
|
||||
// TODO(light): error message for conflicting with another interface binding will point at provider instead of binding.
|
||||
prevPos := mc.fset.Position(prev.(*providerInfo).pos)
|
||||
prevPos := mc.fset.Position(prev.(*Provider).Pos)
|
||||
if b.from.importPath == "" {
|
||||
// Provider set is imported directly by injector.
|
||||
return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos)
|
||||
}
|
||||
return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos)
|
||||
}
|
||||
concrete := pm.At(b.provided)
|
||||
concrete := pm.At(b.Provided)
|
||||
if concrete == nil {
|
||||
pos := mc.fset.Position(b.pos)
|
||||
typ := types.TypeString(b.provided, nil)
|
||||
pos := mc.fset.Position(b.Pos)
|
||||
typ := types.TypeString(b.Provided, nil)
|
||||
if b.from.importPath == "" {
|
||||
// Concrete provider is imported directly by injector.
|
||||
return nil, fmt.Errorf("%v: no binding for %s", pos, typ)
|
||||
}
|
||||
return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from)
|
||||
}
|
||||
pm.Set(b.iface, concrete)
|
||||
pm.Set(b.Iface, concrete)
|
||||
}
|
||||
return pm, nil
|
||||
}
|
||||
|
||||
@@ -22,17 +22,7 @@ import (
|
||||
// Generate performs dependency injection for a single package,
|
||||
// returning the gofmt'd Go source code.
|
||||
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
// TODO(light): allow errors
|
||||
// TODO(light): stop errors from printing to stderr
|
||||
conf := &loader.Config{
|
||||
Build: new(build.Context),
|
||||
ParserMode: parser.ParseComments,
|
||||
Cwd: wd,
|
||||
TypeCheckFuncBodies: func(string) bool { return false },
|
||||
}
|
||||
*conf.Build = *bctx
|
||||
n := len(conf.Build.BuildTags)
|
||||
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
|
||||
conf := newLoaderConfig(bctx, wd, true)
|
||||
conf.Import(pkg)
|
||||
prog, err := conf.Load()
|
||||
if err != nil {
|
||||
@@ -99,6 +89,24 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
|
||||
return fmtSrc, nil
|
||||
}
|
||||
|
||||
func newLoaderConfig(bctx *build.Context, wd string, inject bool) *loader.Config {
|
||||
// TODO(light): allow errors
|
||||
// TODO(light): stop errors from printing to stderr
|
||||
conf := &loader.Config{
|
||||
Build: bctx,
|
||||
ParserMode: parser.ParseComments,
|
||||
Cwd: wd,
|
||||
TypeCheckFuncBodies: func(string) bool { return false },
|
||||
}
|
||||
if inject {
|
||||
conf.Build = new(build.Context)
|
||||
*conf.Build = *bctx
|
||||
n := len(conf.Build.BuildTags)
|
||||
conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject")
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
// gen is the generator state.
|
||||
type gen struct {
|
||||
currPackage string
|
||||
|
||||
@@ -14,72 +14,158 @@ import (
|
||||
"golang.org/x/tools/go/loader"
|
||||
)
|
||||
|
||||
// A providerSet describes a set of providers. The zero value is an empty
|
||||
// providerSet.
|
||||
type providerSet struct {
|
||||
providers []*providerInfo
|
||||
bindings []ifaceBinding
|
||||
imports []providerSetImport
|
||||
// A ProviderSet describes a set of providers. The zero value is an empty
|
||||
// ProviderSet.
|
||||
type ProviderSet struct {
|
||||
Providers []*Provider
|
||||
Bindings []IfaceBinding
|
||||
Imports []ProviderSetImport
|
||||
}
|
||||
|
||||
// An ifaceBinding declares that a type should be used to satisfy inputs
|
||||
// An IfaceBinding declares that a type should be used to satisfy inputs
|
||||
// of the given interface type.
|
||||
//
|
||||
// provided is always a type that is assignable to iface.
|
||||
type ifaceBinding struct {
|
||||
// iface is the interface type, which is what can be injected.
|
||||
iface types.Type
|
||||
type IfaceBinding struct {
|
||||
// Iface is the interface type, which is what can be injected.
|
||||
Iface types.Type
|
||||
|
||||
// provided is always a type that is assignable to Iface.
|
||||
provided types.Type
|
||||
// Provided is always a type that is assignable to Iface.
|
||||
Provided types.Type
|
||||
|
||||
// pos is the position where the binding was declared.
|
||||
pos token.Pos
|
||||
// Pos is the position where the binding was declared.
|
||||
Pos token.Pos
|
||||
}
|
||||
|
||||
type providerSetImport struct {
|
||||
symref
|
||||
pos token.Pos
|
||||
// A ProviderSetImport adds providers from one provider set into another.
|
||||
type ProviderSetImport struct {
|
||||
ProviderSetID
|
||||
Pos token.Pos
|
||||
}
|
||||
|
||||
// providerInfo records the signature of a provider.
|
||||
type providerInfo struct {
|
||||
// importPath is the package path that the Go object resides in.
|
||||
importPath string
|
||||
// Provider records the signature of a provider. A provider is a
|
||||
// single Go object, either a function or a named struct type.
|
||||
type Provider struct {
|
||||
// ImportPath is the package path that the Go object resides in.
|
||||
ImportPath string
|
||||
|
||||
// name is the name of the Go object.
|
||||
name string
|
||||
// Name is the name of the Go object.
|
||||
Name string
|
||||
|
||||
// pos is the source position of the func keyword or type spec
|
||||
// Pos is the source position of the func keyword or type spec
|
||||
// defining this provider.
|
||||
pos token.Pos
|
||||
Pos token.Pos
|
||||
|
||||
// args is the list of data dependencies this provider has.
|
||||
args []providerInput
|
||||
// Args is the list of data dependencies this provider has.
|
||||
Args []ProviderInput
|
||||
|
||||
// isStruct is true if this provider is a named struct type.
|
||||
// IsStruct is true if this provider is a named struct type.
|
||||
// Otherwise it's a function.
|
||||
isStruct bool
|
||||
IsStruct bool
|
||||
|
||||
// fields lists the field names to populate. This will map 1:1 with
|
||||
// Fields lists the field names to populate. This will map 1:1 with
|
||||
// elements in Args.
|
||||
fields []string
|
||||
Fields []string
|
||||
|
||||
// out is the type this provider produces.
|
||||
out types.Type
|
||||
// Out is the type this provider produces.
|
||||
Out types.Type
|
||||
|
||||
// hasCleanup reports whether the provider function returns a cleanup
|
||||
// HasCleanup reports whether the provider function returns a cleanup
|
||||
// function. (Always false for structs.)
|
||||
hasCleanup bool
|
||||
HasCleanup bool
|
||||
|
||||
// hasErr reports whether the provider function can return an error.
|
||||
// HasErr reports whether the provider function can return an error.
|
||||
// (Always false for structs.)
|
||||
hasErr bool
|
||||
HasErr bool
|
||||
}
|
||||
|
||||
type providerInput struct {
|
||||
typ types.Type
|
||||
optional bool
|
||||
// ProviderInput describes an incoming edge in the provider graph.
|
||||
type ProviderInput struct {
|
||||
Type types.Type
|
||||
Optional bool
|
||||
}
|
||||
|
||||
// Load finds all the provider sets in the given packages, as well as
|
||||
// the provider sets' transitive dependencies.
|
||||
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
|
||||
conf := newLoaderConfig(bctx, wd, false)
|
||||
for _, p := range pkgs {
|
||||
conf.Import(p)
|
||||
}
|
||||
prog, err := conf.Load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load: %v", err)
|
||||
}
|
||||
r := newImportResolver(conf, prog.Fset)
|
||||
var next []string
|
||||
initial := make(map[string]struct{})
|
||||
for _, pkgInfo := range prog.InitialPackages() {
|
||||
path := pkgInfo.Pkg.Path()
|
||||
next = append(next, path)
|
||||
initial[path] = struct{}{}
|
||||
}
|
||||
visited := make(map[string]struct{})
|
||||
info := &Info{
|
||||
Fset: prog.Fset,
|
||||
Sets: make(map[ProviderSetID]*ProviderSet),
|
||||
All: make(map[ProviderSetID]*ProviderSet),
|
||||
}
|
||||
for len(next) > 0 {
|
||||
curr := next[len(next)-1]
|
||||
next = next[:len(next)-1]
|
||||
if _, ok := visited[curr]; ok {
|
||||
continue
|
||||
}
|
||||
visited[curr] = struct{}{}
|
||||
pkgInfo := prog.Package(curr)
|
||||
sets, err := findProviderSets(findContext{
|
||||
fset: prog.Fset,
|
||||
pkg: pkgInfo.Pkg,
|
||||
typeInfo: &pkgInfo.Info,
|
||||
r: r,
|
||||
}, pkgInfo.Files)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load: %v", err)
|
||||
}
|
||||
path := pkgInfo.Pkg.Path()
|
||||
for name, set := range sets {
|
||||
info.All[ProviderSetID{path, name}] = set
|
||||
for _, imp := range set.Imports {
|
||||
next = append(next, imp.ImportPath)
|
||||
}
|
||||
}
|
||||
if _, ok := initial[path]; ok {
|
||||
for name, set := range sets {
|
||||
info.Sets[ProviderSetID{path, name}] = set
|
||||
}
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// Info holds the result of Load.
|
||||
type Info struct {
|
||||
Fset *token.FileSet
|
||||
|
||||
// Sets contains all the provider sets in the initial packages.
|
||||
Sets map[ProviderSetID]*ProviderSet
|
||||
|
||||
// All contains all the provider sets transitively depended on by the
|
||||
// initial packages' provider sets.
|
||||
All map[ProviderSetID]*ProviderSet
|
||||
}
|
||||
|
||||
// A ProviderSetID identifies a provider set.
|
||||
type ProviderSetID struct {
|
||||
ImportPath string
|
||||
Name string
|
||||
}
|
||||
|
||||
// String returns the ID as ""path/to/pkg".Foo".
|
||||
func (id ProviderSetID) String() string {
|
||||
return id.symref().String()
|
||||
}
|
||||
|
||||
func (id ProviderSetID) symref() symref {
|
||||
return symref{importPath: id.ImportPath, name: id.Name}
|
||||
}
|
||||
|
||||
type findContext struct {
|
||||
@@ -90,8 +176,8 @@ type findContext struct {
|
||||
}
|
||||
|
||||
// findProviderSets processes a package and extracts the provider sets declared in it.
|
||||
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) {
|
||||
sets := make(map[string]*providerSet)
|
||||
func findProviderSets(fctx findContext, files []*ast.File) (map[string]*ProviderSet, error) {
|
||||
sets := make(map[string]*ProviderSet)
|
||||
for _, f := range files {
|
||||
fileScope := fctx.typeInfo.Scopes[f]
|
||||
if fileScope == nil {
|
||||
@@ -115,7 +201,7 @@ func findProviderSets(fctx findContext, files []*ast.File) (map[string]*provider
|
||||
}
|
||||
|
||||
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
|
||||
func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error {
|
||||
func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error {
|
||||
switch d.kind {
|
||||
case "provide", "optional":
|
||||
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
|
||||
@@ -169,15 +255,15 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
|
||||
|
||||
name := args[0]
|
||||
if pset := sets[name]; pset != nil {
|
||||
pset.bindings = append(pset.bindings, ifaceBinding{
|
||||
iface: iface,
|
||||
provided: provided,
|
||||
pset.Bindings = append(pset.Bindings, IfaceBinding{
|
||||
Iface: iface,
|
||||
Provided: provided,
|
||||
})
|
||||
} else {
|
||||
sets[name] = &providerSet{
|
||||
bindings: []ifaceBinding{{
|
||||
iface: iface,
|
||||
provided: provided,
|
||||
sets[name] = &ProviderSet{
|
||||
Bindings: []IfaceBinding{{
|
||||
Iface: iface,
|
||||
Provided: provided,
|
||||
}},
|
||||
}
|
||||
}
|
||||
@@ -197,18 +283,30 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
|
||||
}
|
||||
if mod := sets[name]; mod != nil {
|
||||
found := false
|
||||
for _, other := range mod.imports {
|
||||
if ref == other.symref {
|
||||
for _, other := range mod.Imports {
|
||||
if ref == other.symref() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos})
|
||||
mod.Imports = append(mod.Imports, ProviderSetImport{
|
||||
ProviderSetID: ProviderSetID{
|
||||
ImportPath: ref.importPath,
|
||||
Name: ref.name,
|
||||
},
|
||||
Pos: d.pos,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
sets[name] = &providerSet{
|
||||
imports: []providerSetImport{{symref: ref, pos: d.pos}},
|
||||
sets[name] = &ProviderSet{
|
||||
Imports: []ProviderSetImport{{
|
||||
ProviderSetID: ProviderSetID{
|
||||
ImportPath: ref.importPath,
|
||||
Name: ref.name,
|
||||
},
|
||||
Pos: d.pos,
|
||||
}},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -219,7 +317,7 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
|
||||
}
|
||||
|
||||
// processDeclDirectives processes the directives associated with a top-level declaration.
|
||||
func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error {
|
||||
func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, dg directiveGroup) error {
|
||||
p, err := dg.single(fctx.fset, "provide")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -258,15 +356,15 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
||||
providerSetName = fn.Name()
|
||||
}
|
||||
if mod := sets[providerSetName]; mod != nil {
|
||||
for _, other := range mod.providers {
|
||||
if types.Identical(other.out, provider.out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
|
||||
for _, other := range mod.Providers {
|
||||
if types.Identical(other.Out, provider.Out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
|
||||
}
|
||||
}
|
||||
mod.providers = append(mod.providers, provider)
|
||||
mod.Providers = append(mod.Providers, provider)
|
||||
} else {
|
||||
sets[providerSetName] = &providerSet{
|
||||
providers: []*providerInfo{provider},
|
||||
sets[providerSetName] = &ProviderSet{
|
||||
Providers: []*Provider{provider},
|
||||
}
|
||||
}
|
||||
case *ast.GenDecl:
|
||||
@@ -288,22 +386,22 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
||||
if providerSetName == "" {
|
||||
providerSetName = typeName.Name()
|
||||
}
|
||||
ptrProvider := new(providerInfo)
|
||||
ptrProvider := new(Provider)
|
||||
*ptrProvider = *provider
|
||||
ptrProvider.out = types.NewPointer(provider.out)
|
||||
ptrProvider.Out = types.NewPointer(provider.Out)
|
||||
if mod := sets[providerSetName]; mod != nil {
|
||||
for _, other := range mod.providers {
|
||||
if types.Identical(other.out, provider.out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos))
|
||||
for _, other := range mod.Providers {
|
||||
if types.Identical(other.Out, provider.Out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos))
|
||||
}
|
||||
if types.Identical(other.out, ptrProvider.out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.out, nil), fctx.fset.Position(other.pos))
|
||||
if types.Identical(other.Out, ptrProvider.Out) {
|
||||
return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.Out, nil), fctx.fset.Position(other.Pos))
|
||||
}
|
||||
}
|
||||
mod.providers = append(mod.providers, provider, ptrProvider)
|
||||
mod.Providers = append(mod.Providers, provider, ptrProvider)
|
||||
} else {
|
||||
sets[providerSetName] = &providerSet{
|
||||
providers: []*providerInfo{provider, ptrProvider},
|
||||
sets[providerSetName] = &ProviderSet{
|
||||
Providers: []*Provider{provider, ptrProvider},
|
||||
}
|
||||
}
|
||||
default:
|
||||
@@ -312,7 +410,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope
|
||||
return nil
|
||||
}
|
||||
|
||||
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*providerInfo, error) {
|
||||
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*Provider, error) {
|
||||
sig := fn.Type().(*types.Signature)
|
||||
|
||||
optionals := make([]bool, sig.Params().Len())
|
||||
@@ -352,30 +450,30 @@ func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[stri
|
||||
}
|
||||
out := r.At(0).Type()
|
||||
params := sig.Params()
|
||||
provider := &providerInfo{
|
||||
importPath: fctx.pkg.Path(),
|
||||
name: fn.Name(),
|
||||
pos: fn.Pos(),
|
||||
args: make([]providerInput, params.Len()),
|
||||
out: out,
|
||||
hasCleanup: hasCleanup,
|
||||
hasErr: hasErr,
|
||||
provider := &Provider{
|
||||
ImportPath: fctx.pkg.Path(),
|
||||
Name: fn.Name(),
|
||||
Pos: fn.Pos(),
|
||||
Args: make([]ProviderInput, params.Len()),
|
||||
Out: out,
|
||||
HasCleanup: hasCleanup,
|
||||
HasErr: hasErr,
|
||||
}
|
||||
for i := 0; i < params.Len(); i++ {
|
||||
provider.args[i] = providerInput{
|
||||
typ: params.At(i).Type(),
|
||||
optional: optionals[i],
|
||||
provider.Args[i] = ProviderInput{
|
||||
Type: params.At(i).Type(),
|
||||
Optional: optionals[i],
|
||||
}
|
||||
for j := 0; j < i; j++ {
|
||||
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
|
||||
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil))
|
||||
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
|
||||
return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil))
|
||||
}
|
||||
}
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*providerInfo, error) {
|
||||
func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*Provider, error) {
|
||||
out := typeName.Type()
|
||||
st := out.Underlying().(*types.Struct)
|
||||
for arg, dpos := range optionals {
|
||||
@@ -392,26 +490,26 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals
|
||||
}
|
||||
|
||||
pos := typeName.Pos()
|
||||
provider := &providerInfo{
|
||||
importPath: fctx.pkg.Path(),
|
||||
name: typeName.Name(),
|
||||
pos: pos,
|
||||
args: make([]providerInput, st.NumFields()),
|
||||
fields: make([]string, st.NumFields()),
|
||||
isStruct: true,
|
||||
out: out,
|
||||
provider := &Provider{
|
||||
ImportPath: fctx.pkg.Path(),
|
||||
Name: typeName.Name(),
|
||||
Pos: pos,
|
||||
Args: make([]ProviderInput, st.NumFields()),
|
||||
Fields: make([]string, st.NumFields()),
|
||||
IsStruct: true,
|
||||
Out: out,
|
||||
}
|
||||
for i := 0; i < st.NumFields(); i++ {
|
||||
f := st.Field(i)
|
||||
_, optional := optionals[f.Name()]
|
||||
provider.args[i] = providerInput{
|
||||
typ: f.Type(),
|
||||
optional: optional,
|
||||
provider.Args[i] = ProviderInput{
|
||||
Type: f.Type(),
|
||||
Optional: optional,
|
||||
}
|
||||
provider.fields[i] = f.Name()
|
||||
provider.Fields[i] = f.Name()
|
||||
for j := 0; j < i; j++ {
|
||||
if types.Identical(provider.args[i].typ, provider.args[j].typ) {
|
||||
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.args[j].typ, nil))
|
||||
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
|
||||
return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.Args[j].Type, nil))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -420,7 +518,7 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals
|
||||
|
||||
// providerSetCache is a lazily evaluated index of provider sets.
|
||||
type providerSetCache struct {
|
||||
sets map[string]map[string]*providerSet
|
||||
sets map[string]map[string]*ProviderSet
|
||||
fset *token.FileSet
|
||||
prog *loader.Program
|
||||
r *importResolver
|
||||
@@ -434,7 +532,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
|
||||
func (mc *providerSetCache) get(ref symref) (*ProviderSet, error) {
|
||||
if mods, cached := mc.sets[ref.importPath]; cached {
|
||||
mod := mods[ref.name]
|
||||
if mod == nil {
|
||||
@@ -443,7 +541,7 @@ func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
|
||||
return mod, nil
|
||||
}
|
||||
if mc.sets == nil {
|
||||
mc.sets = make(map[string]map[string]*providerSet)
|
||||
mc.sets = make(map[string]map[string]*ProviderSet)
|
||||
}
|
||||
pkg := mc.prog.Package(ref.importPath)
|
||||
mods, err := findProviderSets(findContext{
|
||||
|
||||
342
main.go
342
main.go
@@ -6,47 +6,333 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"go/build"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"codename/goose/internal/goose"
|
||||
"golang.org/x/tools/go/types/typeutil"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var pkg string
|
||||
switch len(os.Args) {
|
||||
case 1:
|
||||
pkg = "."
|
||||
case 2:
|
||||
pkg = os.Args[1]
|
||||
var err error
|
||||
switch {
|
||||
case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen":
|
||||
err = generate(".")
|
||||
case len(os.Args) == 2 && os.Args[1] == "show":
|
||||
err = show(".")
|
||||
case len(os.Args) == 2:
|
||||
err = generate(os.Args[1])
|
||||
case len(os.Args) > 2 && os.Args[1] == "show":
|
||||
err = show(os.Args[2:]...)
|
||||
case len(os.Args) == 3 && os.Args[1] == "gen":
|
||||
err = generate(os.Args[2])
|
||||
default:
|
||||
fmt.Fprintln(os.Stderr, "goose: usage: goose [PKG]")
|
||||
fmt.Fprintln(os.Stderr, "goose: usage: goose [gen] [PKG] | goose show [...]")
|
||||
os.Exit(64)
|
||||
}
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
out, err := goose.Generate(&build.Default, wd, pkg)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
// No Goose directives, don't write anything.
|
||||
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
|
||||
return
|
||||
}
|
||||
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
|
||||
if err := ioutil.WriteFile(p, out, 0666); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "goose:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// generate runs the gen subcommand. Given a package, gen will create
|
||||
// the goose_gen.go file.
|
||||
func generate(pkg string) error {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := goose.Generate(&build.Default, wd, pkg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(out) == 0 {
|
||||
// No Goose directives, don't write anything.
|
||||
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
|
||||
return nil
|
||||
}
|
||||
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
|
||||
if err := ioutil.WriteFile(p, out, 0666); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// show runs the show subcommand.
|
||||
//
|
||||
// Given one or more packages, show will find all the declared provider
|
||||
// sets and print what other provider sets it imports and what outputs
|
||||
// it can produce, given possible inputs.
|
||||
func show(pkgs ...string) error {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info, err := goose.Load(&build.Default, wd, pkgs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keys := make([]goose.ProviderSetID, 0, len(info.Sets))
|
||||
for k := range info.Sets {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
if keys[i].ImportPath == keys[j].ImportPath {
|
||||
return keys[i].Name < keys[j].Name
|
||||
}
|
||||
return keys[i].ImportPath < keys[j].ImportPath
|
||||
})
|
||||
// ANSI color codes.
|
||||
const (
|
||||
reset = "\x1b[0m"
|
||||
redBold = "\x1b[0;1;31m"
|
||||
blue = "\x1b[0;34m"
|
||||
green = "\x1b[0;32m"
|
||||
)
|
||||
for i, k := range keys {
|
||||
if i > 0 {
|
||||
fmt.Println()
|
||||
}
|
||||
outGroups, imports := gather(info, k)
|
||||
fmt.Printf("%s%s%s\n", redBold, k, reset)
|
||||
for _, imp := range sortSet(imports) {
|
||||
fmt.Printf("\t%s\n", imp)
|
||||
}
|
||||
for i := range outGroups {
|
||||
fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset)
|
||||
out := make(map[string]token.Pos, outGroups[i].outputs.Len())
|
||||
outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) {
|
||||
switch v := v.(type) {
|
||||
case *goose.Provider:
|
||||
out[types.TypeString(t, nil)] = v.Pos
|
||||
case goose.IfaceBinding:
|
||||
out[types.TypeString(t, nil)] = v.Pos
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
})
|
||||
for _, t := range sortSet(out) {
|
||||
fmt.Printf("\t%s%s%s\n", green, t, reset)
|
||||
fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t]))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type outGroup struct {
|
||||
name string
|
||||
inputs *typeutil.Map // values are not important
|
||||
outputs *typeutil.Map // values are either *goose.Provider or goose.IfaceBinding
|
||||
}
|
||||
|
||||
// gather flattens a provider set into outputs grouped by the inputs
|
||||
// required to create them. As it flattens the provider set, it records
|
||||
// the visited provider sets as imports.
|
||||
func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
|
||||
hash := typeutil.MakeHasher()
|
||||
// Map types to providers and bindings.
|
||||
pm := new(typeutil.Map)
|
||||
pm.SetHasher(hash)
|
||||
next := []goose.ProviderSetID{key}
|
||||
visited := make(map[goose.ProviderSetID]struct{})
|
||||
imports = make(map[string]struct{})
|
||||
for len(next) > 0 {
|
||||
curr := next[len(next)-1]
|
||||
next = next[:len(next)-1]
|
||||
if _, found := visited[curr]; found {
|
||||
continue
|
||||
}
|
||||
visited[curr] = struct{}{}
|
||||
if curr != key {
|
||||
imports[curr.String()] = struct{}{}
|
||||
}
|
||||
set := info.All[curr]
|
||||
for _, p := range set.Providers {
|
||||
pm.Set(p.Out, p)
|
||||
}
|
||||
for _, b := range set.Bindings {
|
||||
pm.Set(b.Iface, b)
|
||||
}
|
||||
for _, imp := range set.Imports {
|
||||
next = append(next, imp.ProviderSetID)
|
||||
}
|
||||
}
|
||||
|
||||
// Depth-first search to build groups.
|
||||
var groups []outGroup
|
||||
inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input.
|
||||
inputVisited.SetHasher(hash)
|
||||
pmKeys := pm.Keys()
|
||||
var stk []types.Type
|
||||
for _, k := range pmKeys {
|
||||
// Start a DFS by picking a random unvisited node.
|
||||
if inputVisited.At(k) == nil {
|
||||
stk = append(stk, k)
|
||||
}
|
||||
|
||||
// Run DFS
|
||||
dfs:
|
||||
for len(stk) > 0 {
|
||||
curr := stk[len(stk)-1]
|
||||
stk = stk[:len(stk)-1]
|
||||
if inputVisited.At(curr) != nil {
|
||||
continue
|
||||
}
|
||||
switch p := pm.At(curr).(type) {
|
||||
case nil:
|
||||
// This is an input.
|
||||
inputVisited.Set(curr, -1)
|
||||
case *goose.Provider:
|
||||
// Try to see if any args haven't been visited.
|
||||
allPresent := true
|
||||
for _, arg := range p.Args {
|
||||
if arg.Optional {
|
||||
continue
|
||||
}
|
||||
if inputVisited.At(arg.Type) == nil {
|
||||
allPresent = false
|
||||
}
|
||||
}
|
||||
if !allPresent {
|
||||
stk = append(stk, curr)
|
||||
for _, arg := range p.Args {
|
||||
if arg.Optional {
|
||||
continue
|
||||
}
|
||||
if inputVisited.At(arg.Type) == nil {
|
||||
stk = append(stk, arg.Type)
|
||||
}
|
||||
}
|
||||
continue dfs
|
||||
}
|
||||
|
||||
// Build up set of input types, match to a group.
|
||||
in := new(typeutil.Map)
|
||||
in.SetHasher(hash)
|
||||
for _, arg := range p.Args {
|
||||
if arg.Optional {
|
||||
continue
|
||||
}
|
||||
i := inputVisited.At(arg.Type).(int)
|
||||
if i == -1 {
|
||||
in.Set(arg.Type, true)
|
||||
} else {
|
||||
mergeTypeSets(in, groups[i].inputs)
|
||||
}
|
||||
}
|
||||
for i := range groups {
|
||||
if sameTypeKeys(groups[i].inputs, in) {
|
||||
groups[i].outputs.Set(p.Out, p)
|
||||
inputVisited.Set(p.Out, i)
|
||||
continue dfs
|
||||
}
|
||||
}
|
||||
out := new(typeutil.Map)
|
||||
out.SetHasher(hash)
|
||||
out.Set(p.Out, p)
|
||||
inputVisited.Set(p.Out, len(groups))
|
||||
groups = append(groups, outGroup{
|
||||
inputs: in,
|
||||
outputs: out,
|
||||
})
|
||||
case goose.IfaceBinding:
|
||||
i, ok := inputVisited.At(p.Provided).(int)
|
||||
if !ok {
|
||||
stk = append(stk, curr, p.Provided)
|
||||
continue dfs
|
||||
}
|
||||
if i != -1 {
|
||||
groups[i].outputs.Set(p.Iface, p)
|
||||
inputVisited.Set(p.Iface, i)
|
||||
continue dfs
|
||||
}
|
||||
// Binding must be provided. Find or add a group.
|
||||
for i := range groups {
|
||||
if groups[i].inputs.Len() != 1 {
|
||||
continue
|
||||
}
|
||||
if groups[i].inputs.At(p.Provided) != nil {
|
||||
groups[i].outputs.Set(p.Iface, p)
|
||||
inputVisited.Set(p.Iface, i)
|
||||
continue dfs
|
||||
}
|
||||
}
|
||||
in := new(typeutil.Map)
|
||||
in.SetHasher(hash)
|
||||
in.Set(p.Provided, true)
|
||||
out := new(typeutil.Map)
|
||||
out.SetHasher(hash)
|
||||
out.Set(p.Iface, p)
|
||||
groups = append(groups, outGroup{
|
||||
inputs: in,
|
||||
outputs: out,
|
||||
})
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Name and sort groups
|
||||
for i := range groups {
|
||||
if groups[i].inputs.Len() == 0 {
|
||||
groups[i].name = "no inputs"
|
||||
continue
|
||||
}
|
||||
instr := make([]string, 0, groups[i].inputs.Len())
|
||||
groups[i].inputs.Iterate(func(k types.Type, _ interface{}) {
|
||||
instr = append(instr, types.TypeString(k, nil))
|
||||
})
|
||||
sort.Strings(instr)
|
||||
groups[i].name = strings.Join(instr, ", ")
|
||||
}
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
if groups[i].inputs.Len() == groups[j].inputs.Len() {
|
||||
return groups[i].name < groups[j].name
|
||||
}
|
||||
return groups[i].inputs.Len() < groups[j].inputs.Len()
|
||||
})
|
||||
return groups, imports
|
||||
}
|
||||
|
||||
func mergeTypeSets(dst, src *typeutil.Map) {
|
||||
src.Iterate(func(k types.Type, _ interface{}) {
|
||||
dst.Set(k, true)
|
||||
})
|
||||
}
|
||||
|
||||
func sameTypeKeys(a, b *typeutil.Map) bool {
|
||||
if a.Len() != b.Len() {
|
||||
return false
|
||||
}
|
||||
same := true
|
||||
a.Iterate(func(k types.Type, _ interface{}) {
|
||||
if b.At(k) == nil {
|
||||
same = false
|
||||
}
|
||||
})
|
||||
return same
|
||||
}
|
||||
|
||||
func sortSet(set interface{}) []string {
|
||||
rv := reflect.ValueOf(set)
|
||||
a := make([]string, 0, rv.Len())
|
||||
keys := rv.MapKeys()
|
||||
for _, k := range keys {
|
||||
a = append(a, k.String())
|
||||
}
|
||||
sort.Strings(a)
|
||||
return a
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user