diff --git a/internal/goose/analyze.go b/internal/goose/analyze.go index 3b066c5..7f1d3bc 100644 --- a/internal/goose/analyze.go +++ b/internal/goose/analyze.go @@ -60,8 +60,8 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr index := new(typeutil.Map) for i, g := range given { if p := providers.At(g); p != nil { - pp := p.(*providerInfo) - return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.name, mc.fset.Position(pp.pos)) + pp := p.(*Provider) + return nil, fmt.Errorf("input of %s conflicts with provider %s at %s", types.TypeString(g, nil), pp.Name, mc.fset.Position(pp.Pos)) } index.Set(g, i) } @@ -70,49 +70,49 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr // using a depth-first search. The graph may contain cycles, which // should trigger an error. var calls []call - var visit func(trail []providerInput) error - visit = func(trail []providerInput) error { - typ := trail[len(trail)-1].typ + var visit func(trail []ProviderInput) error + visit = func(trail []ProviderInput) error { + typ := trail[len(trail)-1].Type if index.At(typ) != nil { return nil } for _, in := range trail[:len(trail)-1] { - if types.Identical(typ, in.typ) { + if types.Identical(typ, in.Type) { // TODO(light): describe cycle return fmt.Errorf("cycle for %s", types.TypeString(typ, nil)) } } - p, _ := providers.At(typ).(*providerInfo) + p, _ := providers.At(typ).(*Provider) if p == nil { - if trail[len(trail)-1].optional { + if trail[len(trail)-1].Optional { return nil } if len(trail) == 1 { return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) } // TODO(light): give name of provider - return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].typ, nil)) + return fmt.Errorf("no provider found for %s (required by provider of %s)", types.TypeString(typ, nil), types.TypeString(trail[len(trail)-2].Type, nil)) } - if !types.Identical(p.out, typ) { + if !types.Identical(p.Out, typ) { // Interface binding. Don't create a call ourselves. - if err := visit(append(trail, providerInput{typ: p.out})); err != nil { + if err := visit(append(trail, ProviderInput{Type: p.Out})); err != nil { return err } - index.Set(typ, index.At(p.out)) + index.Set(typ, index.At(p.Out)) return nil } - for _, a := range p.args { + for _, a := range p.Args { // TODO(light): this will discard grown trail arrays. if err := visit(append(trail, a)); err != nil { return err } } - args := make([]int, len(p.args)) - ins := make([]types.Type, len(p.args)) - for i := range p.args { - ins[i] = p.args[i].typ - if x := index.At(p.args[i].typ); x != nil { + args := make([]int, len(p.Args)) + ins := make([]types.Type, len(p.Args)) + for i := range p.Args { + ins[i] = p.Args[i].Type + if x := index.At(p.Args[i].Type); x != nil { args[i] = x.(int) } else { args[i] = -1 @@ -120,19 +120,19 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr } index.Set(typ, len(given)+len(calls)) calls = append(calls, call{ - importPath: p.importPath, - name: p.name, + importPath: p.ImportPath, + name: p.Name, args: args, - isStruct: p.isStruct, - fieldNames: p.fields, + isStruct: p.IsStruct, + fieldNames: p.Fields, ins: ins, out: typ, - hasCleanup: p.hasCleanup, - hasErr: p.hasErr, + hasCleanup: p.HasCleanup, + hasErr: p.HasErr, }) return nil } - if err := visit([]providerInput{{typ: out}}); err != nil { + if err := visit([]ProviderInput{{Type: out}}); err != nil { return nil, err } return calls, nil @@ -146,7 +146,7 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error pos token.Pos } type binding struct { - ifaceBinding + IfaceBinding pset symref from symref } @@ -173,53 +173,53 @@ func buildProviderMap(mc *providerSetCache, sets []symref) (*typeutil.Map, error } return nil, fmt.Errorf("%v: %v", mc.fset.Position(curr.pos), err) } - for _, p := range pset.providers { - if prev := pm.At(p.out); prev != nil { - pos := mc.fset.Position(p.pos) - typ := types.TypeString(p.out, nil) - prevPos := mc.fset.Position(prev.(*providerInfo).pos) + for _, p := range pset.Providers { + if prev := pm.At(p.Out); prev != nil { + pos := mc.fset.Position(p.Pos) + typ := types.TypeString(p.Out, nil) + prevPos := mc.fset.Position(prev.(*Provider).Pos) if curr.from.importPath == "" { // Provider set is imported directly by injector. return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) } return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, curr.from, prevPos) } - pm.Set(p.out, p) + pm.Set(p.Out, p) } - for _, b := range pset.bindings { + for _, b := range pset.Bindings { bindings = append(bindings, binding{ - ifaceBinding: b, + IfaceBinding: b, pset: curr.to, from: curr.from, }) } - for _, imp := range pset.imports { - next = append(next, nextEnt{to: imp.symref, from: curr.to, pos: imp.pos}) + for _, imp := range pset.Imports { + next = append(next, nextEnt{to: imp.symref(), from: curr.to, pos: imp.Pos}) } } for _, b := range bindings { - if prev := pm.At(b.iface); prev != nil { - pos := mc.fset.Position(b.pos) - typ := types.TypeString(b.iface, nil) + if prev := pm.At(b.Iface); prev != nil { + pos := mc.fset.Position(b.Pos) + typ := types.TypeString(b.Iface, nil) // TODO(light): error message for conflicting with another interface binding will point at provider instead of binding. - prevPos := mc.fset.Position(prev.(*providerInfo).pos) + prevPos := mc.fset.Position(prev.(*Provider).Pos) if b.from.importPath == "" { // Provider set is imported directly by injector. return nil, fmt.Errorf("%v: multiple bindings for %s (added by injector, previous binding at %v)", pos, typ, prevPos) } return nil, fmt.Errorf("%v: multiple bindings for %s (imported by %v, previous binding at %v)", pos, typ, b.from, prevPos) } - concrete := pm.At(b.provided) + concrete := pm.At(b.Provided) if concrete == nil { - pos := mc.fset.Position(b.pos) - typ := types.TypeString(b.provided, nil) + pos := mc.fset.Position(b.Pos) + typ := types.TypeString(b.Provided, nil) if b.from.importPath == "" { // Concrete provider is imported directly by injector. return nil, fmt.Errorf("%v: no binding for %s", pos, typ) } return nil, fmt.Errorf("%v: no binding for %s (imported by %v)", pos, typ, b.from) } - pm.Set(b.iface, concrete) + pm.Set(b.Iface, concrete) } return pm, nil } diff --git a/internal/goose/goose.go b/internal/goose/goose.go index 366ae33..43041f1 100644 --- a/internal/goose/goose.go +++ b/internal/goose/goose.go @@ -22,17 +22,7 @@ import ( // Generate performs dependency injection for a single package, // returning the gofmt'd Go source code. func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { - // TODO(light): allow errors - // TODO(light): stop errors from printing to stderr - conf := &loader.Config{ - Build: new(build.Context), - ParserMode: parser.ParseComments, - Cwd: wd, - TypeCheckFuncBodies: func(string) bool { return false }, - } - *conf.Build = *bctx - n := len(conf.Build.BuildTags) - conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject") + conf := newLoaderConfig(bctx, wd, true) conf.Import(pkg) prog, err := conf.Load() if err != nil { @@ -99,6 +89,24 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { return fmtSrc, nil } +func newLoaderConfig(bctx *build.Context, wd string, inject bool) *loader.Config { + // TODO(light): allow errors + // TODO(light): stop errors from printing to stderr + conf := &loader.Config{ + Build: bctx, + ParserMode: parser.ParseComments, + Cwd: wd, + TypeCheckFuncBodies: func(string) bool { return false }, + } + if inject { + conf.Build = new(build.Context) + *conf.Build = *bctx + n := len(conf.Build.BuildTags) + conf.Build.BuildTags = append(conf.Build.BuildTags[:n:n], "gooseinject") + } + return conf +} + // gen is the generator state. type gen struct { currPackage string diff --git a/internal/goose/parse.go b/internal/goose/parse.go index 18f8cf3..019dd28 100644 --- a/internal/goose/parse.go +++ b/internal/goose/parse.go @@ -14,72 +14,158 @@ import ( "golang.org/x/tools/go/loader" ) -// A providerSet describes a set of providers. The zero value is an empty -// providerSet. -type providerSet struct { - providers []*providerInfo - bindings []ifaceBinding - imports []providerSetImport +// A ProviderSet describes a set of providers. The zero value is an empty +// ProviderSet. +type ProviderSet struct { + Providers []*Provider + Bindings []IfaceBinding + Imports []ProviderSetImport } -// An ifaceBinding declares that a type should be used to satisfy inputs +// An IfaceBinding declares that a type should be used to satisfy inputs // of the given interface type. -// -// provided is always a type that is assignable to iface. -type ifaceBinding struct { - // iface is the interface type, which is what can be injected. - iface types.Type +type IfaceBinding struct { + // Iface is the interface type, which is what can be injected. + Iface types.Type - // provided is always a type that is assignable to Iface. - provided types.Type + // Provided is always a type that is assignable to Iface. + Provided types.Type - // pos is the position where the binding was declared. - pos token.Pos + // Pos is the position where the binding was declared. + Pos token.Pos } -type providerSetImport struct { - symref - pos token.Pos +// A ProviderSetImport adds providers from one provider set into another. +type ProviderSetImport struct { + ProviderSetID + Pos token.Pos } -// providerInfo records the signature of a provider. -type providerInfo struct { - // importPath is the package path that the Go object resides in. - importPath string +// Provider records the signature of a provider. A provider is a +// single Go object, either a function or a named struct type. +type Provider struct { + // ImportPath is the package path that the Go object resides in. + ImportPath string - // name is the name of the Go object. - name string + // Name is the name of the Go object. + Name string - // pos is the source position of the func keyword or type spec + // Pos is the source position of the func keyword or type spec // defining this provider. - pos token.Pos + Pos token.Pos - // args is the list of data dependencies this provider has. - args []providerInput + // Args is the list of data dependencies this provider has. + Args []ProviderInput - // isStruct is true if this provider is a named struct type. + // IsStruct is true if this provider is a named struct type. // Otherwise it's a function. - isStruct bool + IsStruct bool - // fields lists the field names to populate. This will map 1:1 with + // Fields lists the field names to populate. This will map 1:1 with // elements in Args. - fields []string + Fields []string - // out is the type this provider produces. - out types.Type + // Out is the type this provider produces. + Out types.Type - // hasCleanup reports whether the provider function returns a cleanup + // HasCleanup reports whether the provider function returns a cleanup // function. (Always false for structs.) - hasCleanup bool + HasCleanup bool - // hasErr reports whether the provider function can return an error. + // HasErr reports whether the provider function can return an error. // (Always false for structs.) - hasErr bool + HasErr bool } -type providerInput struct { - typ types.Type - optional bool +// ProviderInput describes an incoming edge in the provider graph. +type ProviderInput struct { + Type types.Type + Optional bool +} + +// Load finds all the provider sets in the given packages, as well as +// the provider sets' transitive dependencies. +func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) { + conf := newLoaderConfig(bctx, wd, false) + for _, p := range pkgs { + conf.Import(p) + } + prog, err := conf.Load() + if err != nil { + return nil, fmt.Errorf("load: %v", err) + } + r := newImportResolver(conf, prog.Fset) + var next []string + initial := make(map[string]struct{}) + for _, pkgInfo := range prog.InitialPackages() { + path := pkgInfo.Pkg.Path() + next = append(next, path) + initial[path] = struct{}{} + } + visited := make(map[string]struct{}) + info := &Info{ + Fset: prog.Fset, + Sets: make(map[ProviderSetID]*ProviderSet), + All: make(map[ProviderSetID]*ProviderSet), + } + for len(next) > 0 { + curr := next[len(next)-1] + next = next[:len(next)-1] + if _, ok := visited[curr]; ok { + continue + } + visited[curr] = struct{}{} + pkgInfo := prog.Package(curr) + sets, err := findProviderSets(findContext{ + fset: prog.Fset, + pkg: pkgInfo.Pkg, + typeInfo: &pkgInfo.Info, + r: r, + }, pkgInfo.Files) + if err != nil { + return nil, fmt.Errorf("load: %v", err) + } + path := pkgInfo.Pkg.Path() + for name, set := range sets { + info.All[ProviderSetID{path, name}] = set + for _, imp := range set.Imports { + next = append(next, imp.ImportPath) + } + } + if _, ok := initial[path]; ok { + for name, set := range sets { + info.Sets[ProviderSetID{path, name}] = set + } + } + } + return info, nil +} + +// Info holds the result of Load. +type Info struct { + Fset *token.FileSet + + // Sets contains all the provider sets in the initial packages. + Sets map[ProviderSetID]*ProviderSet + + // All contains all the provider sets transitively depended on by the + // initial packages' provider sets. + All map[ProviderSetID]*ProviderSet +} + +// A ProviderSetID identifies a provider set. +type ProviderSetID struct { + ImportPath string + Name string +} + +// String returns the ID as ""path/to/pkg".Foo". +func (id ProviderSetID) String() string { + return id.symref().String() +} + +func (id ProviderSetID) symref() symref { + return symref{importPath: id.ImportPath, name: id.Name} } type findContext struct { @@ -90,8 +176,8 @@ type findContext struct { } // findProviderSets processes a package and extracts the provider sets declared in it. -func findProviderSets(fctx findContext, files []*ast.File) (map[string]*providerSet, error) { - sets := make(map[string]*providerSet) +func findProviderSets(fctx findContext, files []*ast.File) (map[string]*ProviderSet, error) { + sets := make(map[string]*ProviderSet) for _, f := range files { fileScope := fctx.typeInfo.Scopes[f] if fileScope == nil { @@ -115,7 +201,7 @@ func findProviderSets(fctx findContext, files []*ast.File) (map[string]*provider } // processUnassociatedDirective handles any directive that was not associated with a top-level declaration. -func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet, scope *types.Scope, d directive) error { +func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error { switch d.kind { case "provide", "optional": return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) @@ -169,15 +255,15 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet name := args[0] if pset := sets[name]; pset != nil { - pset.bindings = append(pset.bindings, ifaceBinding{ - iface: iface, - provided: provided, + pset.Bindings = append(pset.Bindings, IfaceBinding{ + Iface: iface, + Provided: provided, }) } else { - sets[name] = &providerSet{ - bindings: []ifaceBinding{{ - iface: iface, - provided: provided, + sets[name] = &ProviderSet{ + Bindings: []IfaceBinding{{ + Iface: iface, + Provided: provided, }}, } } @@ -197,18 +283,30 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet } if mod := sets[name]; mod != nil { found := false - for _, other := range mod.imports { - if ref == other.symref { + for _, other := range mod.Imports { + if ref == other.symref() { found = true break } } if !found { - mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos}) + mod.Imports = append(mod.Imports, ProviderSetImport{ + ProviderSetID: ProviderSetID{ + ImportPath: ref.importPath, + Name: ref.name, + }, + Pos: d.pos, + }) } } else { - sets[name] = &providerSet{ - imports: []providerSetImport{{symref: ref, pos: d.pos}}, + sets[name] = &ProviderSet{ + Imports: []ProviderSetImport{{ + ProviderSetID: ProviderSetID{ + ImportPath: ref.importPath, + Name: ref.name, + }, + Pos: d.pos, + }}, } } } @@ -219,7 +317,7 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet } // processDeclDirectives processes the directives associated with a top-level declaration. -func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope *types.Scope, dg directiveGroup) error { +func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, dg directiveGroup) error { p, err := dg.single(fctx.fset, "provide") if err != nil { return err @@ -258,15 +356,15 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope providerSetName = fn.Name() } if mod := sets[providerSetName]; mod != nil { - for _, other := range mod.providers { - if types.Identical(other.out, provider.out) { - return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) + for _, other := range mod.Providers { + if types.Identical(other.Out, provider.Out) { + return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(fn.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos)) } } - mod.providers = append(mod.providers, provider) + mod.Providers = append(mod.Providers, provider) } else { - sets[providerSetName] = &providerSet{ - providers: []*providerInfo{provider}, + sets[providerSetName] = &ProviderSet{ + Providers: []*Provider{provider}, } } case *ast.GenDecl: @@ -288,22 +386,22 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope if providerSetName == "" { providerSetName = typeName.Name() } - ptrProvider := new(providerInfo) + ptrProvider := new(Provider) *ptrProvider = *provider - ptrProvider.out = types.NewPointer(provider.out) + ptrProvider.Out = types.NewPointer(provider.Out) if mod := sets[providerSetName]; mod != nil { - for _, other := range mod.providers { - if types.Identical(other.out, provider.out) { - return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.out, nil), fctx.fset.Position(other.pos)) + for _, other := range mod.Providers { + if types.Identical(other.Out, provider.Out) { + return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(provider.Out, nil), fctx.fset.Position(other.Pos)) } - if types.Identical(other.out, ptrProvider.out) { - return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.out, nil), fctx.fset.Position(other.pos)) + if types.Identical(other.Out, ptrProvider.Out) { + return fmt.Errorf("%v: provider set %s has multiple providers for %s (previous declaration at %v)", fctx.fset.Position(typeName.Pos()), providerSetName, types.TypeString(ptrProvider.Out, nil), fctx.fset.Position(other.Pos)) } } - mod.providers = append(mod.providers, provider, ptrProvider) + mod.Providers = append(mod.Providers, provider, ptrProvider) } else { - sets[providerSetName] = &providerSet{ - providers: []*providerInfo{provider, ptrProvider}, + sets[providerSetName] = &ProviderSet{ + Providers: []*Provider{provider, ptrProvider}, } } default: @@ -312,7 +410,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*providerSet, scope return nil } -func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*providerInfo, error) { +func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*Provider, error) { sig := fn.Type().(*types.Signature) optionals := make([]bool, sig.Params().Len()) @@ -352,30 +450,30 @@ func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[stri } out := r.At(0).Type() params := sig.Params() - provider := &providerInfo{ - importPath: fctx.pkg.Path(), - name: fn.Name(), - pos: fn.Pos(), - args: make([]providerInput, params.Len()), - out: out, - hasCleanup: hasCleanup, - hasErr: hasErr, + provider := &Provider{ + ImportPath: fctx.pkg.Path(), + Name: fn.Name(), + Pos: fn.Pos(), + Args: make([]ProviderInput, params.Len()), + Out: out, + HasCleanup: hasCleanup, + HasErr: hasErr, } for i := 0; i < params.Len(); i++ { - provider.args[i] = providerInput{ - typ: params.At(i).Type(), - optional: optionals[i], + provider.Args[i] = ProviderInput{ + Type: params.At(i).Type(), + Optional: optionals[i], } for j := 0; j < i; j++ { - if types.Identical(provider.args[i].typ, provider.args[j].typ) { - return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.args[j].typ, nil)) + if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { + return nil, fmt.Errorf("%v: provider has multiple parameters of type %s", fctx.fset.Position(fpos), types.TypeString(provider.Args[j].Type, nil)) } } } return provider, nil } -func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*providerInfo, error) { +func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*Provider, error) { out := typeName.Type() st := out.Underlying().(*types.Struct) for arg, dpos := range optionals { @@ -392,26 +490,26 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals } pos := typeName.Pos() - provider := &providerInfo{ - importPath: fctx.pkg.Path(), - name: typeName.Name(), - pos: pos, - args: make([]providerInput, st.NumFields()), - fields: make([]string, st.NumFields()), - isStruct: true, - out: out, + provider := &Provider{ + ImportPath: fctx.pkg.Path(), + Name: typeName.Name(), + Pos: pos, + Args: make([]ProviderInput, st.NumFields()), + Fields: make([]string, st.NumFields()), + IsStruct: true, + Out: out, } for i := 0; i < st.NumFields(); i++ { f := st.Field(i) _, optional := optionals[f.Name()] - provider.args[i] = providerInput{ - typ: f.Type(), - optional: optional, + provider.Args[i] = ProviderInput{ + Type: f.Type(), + Optional: optional, } - provider.fields[i] = f.Name() + provider.Fields[i] = f.Name() for j := 0; j < i; j++ { - if types.Identical(provider.args[i].typ, provider.args[j].typ) { - return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.args[j].typ, nil)) + if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { + return nil, fmt.Errorf("%v: provider struct has multiple fields of type %s", fctx.fset.Position(pos), types.TypeString(provider.Args[j].Type, nil)) } } } @@ -420,7 +518,7 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals // providerSetCache is a lazily evaluated index of provider sets. type providerSetCache struct { - sets map[string]map[string]*providerSet + sets map[string]map[string]*ProviderSet fset *token.FileSet prog *loader.Program r *importResolver @@ -434,7 +532,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa } } -func (mc *providerSetCache) get(ref symref) (*providerSet, error) { +func (mc *providerSetCache) get(ref symref) (*ProviderSet, error) { if mods, cached := mc.sets[ref.importPath]; cached { mod := mods[ref.name] if mod == nil { @@ -443,7 +541,7 @@ func (mc *providerSetCache) get(ref symref) (*providerSet, error) { return mod, nil } if mc.sets == nil { - mc.sets = make(map[string]map[string]*providerSet) + mc.sets = make(map[string]map[string]*ProviderSet) } pkg := mc.prog.Package(ref.importPath) mods, err := findProviderSets(findContext{ diff --git a/main.go b/main.go index 30db70b..062b0c7 100644 --- a/main.go +++ b/main.go @@ -6,47 +6,333 @@ package main import ( "fmt" "go/build" + "go/token" + "go/types" "io/ioutil" "os" "path/filepath" + "reflect" + "sort" + "strings" "codename/goose/internal/goose" + "golang.org/x/tools/go/types/typeutil" ) func main() { - var pkg string - switch len(os.Args) { - case 1: - pkg = "." - case 2: - pkg = os.Args[1] + var err error + switch { + case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen": + err = generate(".") + case len(os.Args) == 2 && os.Args[1] == "show": + err = show(".") + case len(os.Args) == 2: + err = generate(os.Args[1]) + case len(os.Args) > 2 && os.Args[1] == "show": + err = show(os.Args[2:]...) + case len(os.Args) == 3 && os.Args[1] == "gen": + err = generate(os.Args[2]) default: - fmt.Fprintln(os.Stderr, "goose: usage: goose [PKG]") + fmt.Fprintln(os.Stderr, "goose: usage: goose [gen] [PKG] | goose show [...]") os.Exit(64) } - wd, err := os.Getwd() if err != nil { fmt.Fprintln(os.Stderr, "goose:", err) os.Exit(1) } - pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly) - if err != nil { - fmt.Fprintln(os.Stderr, "goose:", err) - os.Exit(1) - } - out, err := goose.Generate(&build.Default, wd, pkg) - if err != nil { - fmt.Fprintln(os.Stderr, "goose:", err) - os.Exit(1) - } - if len(out) == 0 { - // No Goose directives, don't write anything. - fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg) - return - } - p := filepath.Join(pkgInfo.Dir, "goose_gen.go") - if err := ioutil.WriteFile(p, out, 0666); err != nil { - fmt.Fprintln(os.Stderr, "goose:", err) - os.Exit(1) - } +} + +// generate runs the gen subcommand. Given a package, gen will create +// the goose_gen.go file. +func generate(pkg string) error { + wd, err := os.Getwd() + if err != nil { + return err + } + pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly) + if err != nil { + return err + } + out, err := goose.Generate(&build.Default, wd, pkg) + if err != nil { + return err + } + if len(out) == 0 { + // No Goose directives, don't write anything. + fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg) + return nil + } + p := filepath.Join(pkgInfo.Dir, "goose_gen.go") + if err := ioutil.WriteFile(p, out, 0666); err != nil { + return err + } + return nil +} + +// show runs the show subcommand. +// +// Given one or more packages, show will find all the declared provider +// sets and print what other provider sets it imports and what outputs +// it can produce, given possible inputs. +func show(pkgs ...string) error { + wd, err := os.Getwd() + if err != nil { + return err + } + info, err := goose.Load(&build.Default, wd, pkgs) + if err != nil { + return err + } + keys := make([]goose.ProviderSetID, 0, len(info.Sets)) + for k := range info.Sets { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + if keys[i].ImportPath == keys[j].ImportPath { + return keys[i].Name < keys[j].Name + } + return keys[i].ImportPath < keys[j].ImportPath + }) + // ANSI color codes. + const ( + reset = "\x1b[0m" + redBold = "\x1b[0;1;31m" + blue = "\x1b[0;34m" + green = "\x1b[0;32m" + ) + for i, k := range keys { + if i > 0 { + fmt.Println() + } + outGroups, imports := gather(info, k) + fmt.Printf("%s%s%s\n", redBold, k, reset) + for _, imp := range sortSet(imports) { + fmt.Printf("\t%s\n", imp) + } + for i := range outGroups { + fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset) + out := make(map[string]token.Pos, outGroups[i].outputs.Len()) + outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) { + switch v := v.(type) { + case *goose.Provider: + out[types.TypeString(t, nil)] = v.Pos + case goose.IfaceBinding: + out[types.TypeString(t, nil)] = v.Pos + default: + panic("unreachable") + } + }) + for _, t := range sortSet(out) { + fmt.Printf("\t%s%s%s\n", green, t, reset) + fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t])) + } + } + } + return nil +} + +type outGroup struct { + name string + inputs *typeutil.Map // values are not important + outputs *typeutil.Map // values are either *goose.Provider or goose.IfaceBinding +} + +// gather flattens a provider set into outputs grouped by the inputs +// required to create them. As it flattens the provider set, it records +// the visited provider sets as imports. +func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) { + hash := typeutil.MakeHasher() + // Map types to providers and bindings. + pm := new(typeutil.Map) + pm.SetHasher(hash) + next := []goose.ProviderSetID{key} + visited := make(map[goose.ProviderSetID]struct{}) + imports = make(map[string]struct{}) + for len(next) > 0 { + curr := next[len(next)-1] + next = next[:len(next)-1] + if _, found := visited[curr]; found { + continue + } + visited[curr] = struct{}{} + if curr != key { + imports[curr.String()] = struct{}{} + } + set := info.All[curr] + for _, p := range set.Providers { + pm.Set(p.Out, p) + } + for _, b := range set.Bindings { + pm.Set(b.Iface, b) + } + for _, imp := range set.Imports { + next = append(next, imp.ProviderSetID) + } + } + + // Depth-first search to build groups. + var groups []outGroup + inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input. + inputVisited.SetHasher(hash) + pmKeys := pm.Keys() + var stk []types.Type + for _, k := range pmKeys { + // Start a DFS by picking a random unvisited node. + if inputVisited.At(k) == nil { + stk = append(stk, k) + } + + // Run DFS + dfs: + for len(stk) > 0 { + curr := stk[len(stk)-1] + stk = stk[:len(stk)-1] + if inputVisited.At(curr) != nil { + continue + } + switch p := pm.At(curr).(type) { + case nil: + // This is an input. + inputVisited.Set(curr, -1) + case *goose.Provider: + // Try to see if any args haven't been visited. + allPresent := true + for _, arg := range p.Args { + if arg.Optional { + continue + } + if inputVisited.At(arg.Type) == nil { + allPresent = false + } + } + if !allPresent { + stk = append(stk, curr) + for _, arg := range p.Args { + if arg.Optional { + continue + } + if inputVisited.At(arg.Type) == nil { + stk = append(stk, arg.Type) + } + } + continue dfs + } + + // Build up set of input types, match to a group. + in := new(typeutil.Map) + in.SetHasher(hash) + for _, arg := range p.Args { + if arg.Optional { + continue + } + i := inputVisited.At(arg.Type).(int) + if i == -1 { + in.Set(arg.Type, true) + } else { + mergeTypeSets(in, groups[i].inputs) + } + } + for i := range groups { + if sameTypeKeys(groups[i].inputs, in) { + groups[i].outputs.Set(p.Out, p) + inputVisited.Set(p.Out, i) + continue dfs + } + } + out := new(typeutil.Map) + out.SetHasher(hash) + out.Set(p.Out, p) + inputVisited.Set(p.Out, len(groups)) + groups = append(groups, outGroup{ + inputs: in, + outputs: out, + }) + case goose.IfaceBinding: + i, ok := inputVisited.At(p.Provided).(int) + if !ok { + stk = append(stk, curr, p.Provided) + continue dfs + } + if i != -1 { + groups[i].outputs.Set(p.Iface, p) + inputVisited.Set(p.Iface, i) + continue dfs + } + // Binding must be provided. Find or add a group. + for i := range groups { + if groups[i].inputs.Len() != 1 { + continue + } + if groups[i].inputs.At(p.Provided) != nil { + groups[i].outputs.Set(p.Iface, p) + inputVisited.Set(p.Iface, i) + continue dfs + } + } + in := new(typeutil.Map) + in.SetHasher(hash) + in.Set(p.Provided, true) + out := new(typeutil.Map) + out.SetHasher(hash) + out.Set(p.Iface, p) + groups = append(groups, outGroup{ + inputs: in, + outputs: out, + }) + default: + panic("unreachable") + } + } + } + + // Name and sort groups + for i := range groups { + if groups[i].inputs.Len() == 0 { + groups[i].name = "no inputs" + continue + } + instr := make([]string, 0, groups[i].inputs.Len()) + groups[i].inputs.Iterate(func(k types.Type, _ interface{}) { + instr = append(instr, types.TypeString(k, nil)) + }) + sort.Strings(instr) + groups[i].name = strings.Join(instr, ", ") + } + sort.Slice(groups, func(i, j int) bool { + if groups[i].inputs.Len() == groups[j].inputs.Len() { + return groups[i].name < groups[j].name + } + return groups[i].inputs.Len() < groups[j].inputs.Len() + }) + return groups, imports +} + +func mergeTypeSets(dst, src *typeutil.Map) { + src.Iterate(func(k types.Type, _ interface{}) { + dst.Set(k, true) + }) +} + +func sameTypeKeys(a, b *typeutil.Map) bool { + if a.Len() != b.Len() { + return false + } + same := true + a.Iterate(func(k types.Type, _ interface{}) { + if b.At(k) == nil { + same = false + } + }) + return same +} + +func sortSet(set interface{}) []string { + rv := reflect.ValueOf(set) + a := make([]string, 0, rv.Len()) + keys := rv.MapKeys() + for _, k := range keys { + a = append(a, k.String()) + } + sort.Strings(a) + return a }