2018-03-26 07:39:00 -07:00
|
|
|
// goose is a compile-time dependency injection tool.
|
|
|
|
|
//
|
|
|
|
|
// See README.md for an overview.
|
|
|
|
|
package main
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"go/build"
|
2018-04-04 14:42:56 -07:00
|
|
|
"go/token"
|
|
|
|
|
"go/types"
|
2018-03-26 07:39:00 -07:00
|
|
|
"io/ioutil"
|
|
|
|
|
"os"
|
|
|
|
|
"path/filepath"
|
2018-04-04 14:42:56 -07:00
|
|
|
"reflect"
|
|
|
|
|
"sort"
|
|
|
|
|
"strings"
|
2018-03-26 07:39:00 -07:00
|
|
|
|
|
|
|
|
"codename/goose/internal/goose"
|
2018-04-04 14:42:56 -07:00
|
|
|
"golang.org/x/tools/go/types/typeutil"
|
2018-03-26 07:39:00 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func main() {
|
2018-04-04 14:42:56 -07:00
|
|
|
var err error
|
|
|
|
|
switch {
|
|
|
|
|
case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen":
|
|
|
|
|
err = generate(".")
|
|
|
|
|
case len(os.Args) == 2 && os.Args[1] == "show":
|
|
|
|
|
err = show(".")
|
|
|
|
|
case len(os.Args) == 2:
|
|
|
|
|
err = generate(os.Args[1])
|
|
|
|
|
case len(os.Args) > 2 && os.Args[1] == "show":
|
|
|
|
|
err = show(os.Args[2:]...)
|
|
|
|
|
case len(os.Args) == 3 && os.Args[1] == "gen":
|
|
|
|
|
err = generate(os.Args[2])
|
2018-03-26 07:39:00 -07:00
|
|
|
default:
|
2018-04-04 14:42:56 -07:00
|
|
|
fmt.Fprintln(os.Stderr, "goose: usage: goose [gen] [PKG] | goose show [...]")
|
2018-03-26 07:39:00 -07:00
|
|
|
os.Exit(64)
|
|
|
|
|
}
|
|
|
|
|
if err != nil {
|
|
|
|
|
fmt.Fprintln(os.Stderr, "goose:", err)
|
|
|
|
|
os.Exit(1)
|
|
|
|
|
}
|
2018-04-04 14:42:56 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generate runs the gen subcommand. Given a package, gen will create
|
|
|
|
|
// the goose_gen.go file.
|
|
|
|
|
func generate(pkg string) error {
|
|
|
|
|
wd, err := os.Getwd()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
2018-03-26 07:39:00 -07:00
|
|
|
pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
|
|
|
|
|
if err != nil {
|
2018-04-04 14:42:56 -07:00
|
|
|
return err
|
2018-03-26 07:39:00 -07:00
|
|
|
}
|
|
|
|
|
out, err := goose.Generate(&build.Default, wd, pkg)
|
|
|
|
|
if err != nil {
|
2018-04-04 14:42:56 -07:00
|
|
|
return err
|
2018-03-26 07:39:00 -07:00
|
|
|
}
|
|
|
|
|
if len(out) == 0 {
|
|
|
|
|
// No Goose directives, don't write anything.
|
|
|
|
|
fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
|
2018-04-04 14:42:56 -07:00
|
|
|
return nil
|
2018-03-26 07:39:00 -07:00
|
|
|
}
|
|
|
|
|
p := filepath.Join(pkgInfo.Dir, "goose_gen.go")
|
|
|
|
|
if err := ioutil.WriteFile(p, out, 0666); err != nil {
|
2018-04-04 14:42:56 -07:00
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// show runs the show subcommand.
|
|
|
|
|
//
|
|
|
|
|
// Given one or more packages, show will find all the declared provider
|
|
|
|
|
// sets and print what other provider sets it imports and what outputs
|
|
|
|
|
// it can produce, given possible inputs.
|
|
|
|
|
func show(pkgs ...string) error {
|
|
|
|
|
wd, err := os.Getwd()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
info, err := goose.Load(&build.Default, wd, pkgs)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
keys := make([]goose.ProviderSetID, 0, len(info.Sets))
|
|
|
|
|
for k := range info.Sets {
|
|
|
|
|
keys = append(keys, k)
|
|
|
|
|
}
|
|
|
|
|
sort.Slice(keys, func(i, j int) bool {
|
|
|
|
|
if keys[i].ImportPath == keys[j].ImportPath {
|
|
|
|
|
return keys[i].Name < keys[j].Name
|
|
|
|
|
}
|
|
|
|
|
return keys[i].ImportPath < keys[j].ImportPath
|
|
|
|
|
})
|
|
|
|
|
// ANSI color codes.
|
|
|
|
|
const (
|
|
|
|
|
reset = "\x1b[0m"
|
|
|
|
|
redBold = "\x1b[0;1;31m"
|
|
|
|
|
blue = "\x1b[0;34m"
|
|
|
|
|
green = "\x1b[0;32m"
|
|
|
|
|
)
|
|
|
|
|
for i, k := range keys {
|
|
|
|
|
if i > 0 {
|
|
|
|
|
fmt.Println()
|
|
|
|
|
}
|
|
|
|
|
outGroups, imports := gather(info, k)
|
|
|
|
|
fmt.Printf("%s%s%s\n", redBold, k, reset)
|
|
|
|
|
for _, imp := range sortSet(imports) {
|
|
|
|
|
fmt.Printf("\t%s\n", imp)
|
|
|
|
|
}
|
|
|
|
|
for i := range outGroups {
|
|
|
|
|
fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset)
|
|
|
|
|
out := make(map[string]token.Pos, outGroups[i].outputs.Len())
|
|
|
|
|
outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) {
|
|
|
|
|
switch v := v.(type) {
|
|
|
|
|
case *goose.Provider:
|
|
|
|
|
out[types.TypeString(t, nil)] = v.Pos
|
|
|
|
|
case goose.IfaceBinding:
|
|
|
|
|
out[types.TypeString(t, nil)] = v.Pos
|
|
|
|
|
default:
|
|
|
|
|
panic("unreachable")
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
for _, t := range sortSet(out) {
|
|
|
|
|
fmt.Printf("\t%s%s%s\n", green, t, reset)
|
|
|
|
|
fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t]))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type outGroup struct {
|
|
|
|
|
name string
|
|
|
|
|
inputs *typeutil.Map // values are not important
|
|
|
|
|
outputs *typeutil.Map // values are either *goose.Provider or goose.IfaceBinding
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// gather flattens a provider set into outputs grouped by the inputs
|
|
|
|
|
// required to create them. As it flattens the provider set, it records
|
|
|
|
|
// the visited provider sets as imports.
|
|
|
|
|
func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
|
|
|
|
|
hash := typeutil.MakeHasher()
|
|
|
|
|
// Map types to providers and bindings.
|
|
|
|
|
pm := new(typeutil.Map)
|
|
|
|
|
pm.SetHasher(hash)
|
|
|
|
|
next := []goose.ProviderSetID{key}
|
|
|
|
|
visited := make(map[goose.ProviderSetID]struct{})
|
|
|
|
|
imports = make(map[string]struct{})
|
|
|
|
|
for len(next) > 0 {
|
|
|
|
|
curr := next[len(next)-1]
|
|
|
|
|
next = next[:len(next)-1]
|
|
|
|
|
if _, found := visited[curr]; found {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
visited[curr] = struct{}{}
|
|
|
|
|
if curr != key {
|
|
|
|
|
imports[curr.String()] = struct{}{}
|
|
|
|
|
}
|
|
|
|
|
set := info.All[curr]
|
|
|
|
|
for _, p := range set.Providers {
|
|
|
|
|
pm.Set(p.Out, p)
|
|
|
|
|
}
|
|
|
|
|
for _, b := range set.Bindings {
|
|
|
|
|
pm.Set(b.Iface, b)
|
|
|
|
|
}
|
|
|
|
|
for _, imp := range set.Imports {
|
|
|
|
|
next = append(next, imp.ProviderSetID)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Depth-first search to build groups.
|
|
|
|
|
var groups []outGroup
|
|
|
|
|
inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input.
|
|
|
|
|
inputVisited.SetHasher(hash)
|
|
|
|
|
pmKeys := pm.Keys()
|
|
|
|
|
var stk []types.Type
|
|
|
|
|
for _, k := range pmKeys {
|
|
|
|
|
// Start a DFS by picking a random unvisited node.
|
|
|
|
|
if inputVisited.At(k) == nil {
|
|
|
|
|
stk = append(stk, k)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Run DFS
|
|
|
|
|
dfs:
|
|
|
|
|
for len(stk) > 0 {
|
|
|
|
|
curr := stk[len(stk)-1]
|
|
|
|
|
stk = stk[:len(stk)-1]
|
|
|
|
|
if inputVisited.At(curr) != nil {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
switch p := pm.At(curr).(type) {
|
|
|
|
|
case nil:
|
|
|
|
|
// This is an input.
|
|
|
|
|
inputVisited.Set(curr, -1)
|
|
|
|
|
case *goose.Provider:
|
|
|
|
|
// Try to see if any args haven't been visited.
|
|
|
|
|
allPresent := true
|
|
|
|
|
for _, arg := range p.Args {
|
|
|
|
|
if arg.Optional {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if inputVisited.At(arg.Type) == nil {
|
|
|
|
|
allPresent = false
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if !allPresent {
|
|
|
|
|
stk = append(stk, curr)
|
|
|
|
|
for _, arg := range p.Args {
|
|
|
|
|
if arg.Optional {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if inputVisited.At(arg.Type) == nil {
|
|
|
|
|
stk = append(stk, arg.Type)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
continue dfs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Build up set of input types, match to a group.
|
|
|
|
|
in := new(typeutil.Map)
|
|
|
|
|
in.SetHasher(hash)
|
|
|
|
|
for _, arg := range p.Args {
|
|
|
|
|
if arg.Optional {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
i := inputVisited.At(arg.Type).(int)
|
|
|
|
|
if i == -1 {
|
|
|
|
|
in.Set(arg.Type, true)
|
|
|
|
|
} else {
|
|
|
|
|
mergeTypeSets(in, groups[i].inputs)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for i := range groups {
|
|
|
|
|
if sameTypeKeys(groups[i].inputs, in) {
|
|
|
|
|
groups[i].outputs.Set(p.Out, p)
|
|
|
|
|
inputVisited.Set(p.Out, i)
|
|
|
|
|
continue dfs
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
out := new(typeutil.Map)
|
|
|
|
|
out.SetHasher(hash)
|
|
|
|
|
out.Set(p.Out, p)
|
|
|
|
|
inputVisited.Set(p.Out, len(groups))
|
|
|
|
|
groups = append(groups, outGroup{
|
|
|
|
|
inputs: in,
|
|
|
|
|
outputs: out,
|
|
|
|
|
})
|
|
|
|
|
case goose.IfaceBinding:
|
|
|
|
|
i, ok := inputVisited.At(p.Provided).(int)
|
|
|
|
|
if !ok {
|
|
|
|
|
stk = append(stk, curr, p.Provided)
|
|
|
|
|
continue dfs
|
|
|
|
|
}
|
|
|
|
|
if i != -1 {
|
|
|
|
|
groups[i].outputs.Set(p.Iface, p)
|
|
|
|
|
inputVisited.Set(p.Iface, i)
|
|
|
|
|
continue dfs
|
|
|
|
|
}
|
|
|
|
|
// Binding must be provided. Find or add a group.
|
|
|
|
|
for i := range groups {
|
|
|
|
|
if groups[i].inputs.Len() != 1 {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if groups[i].inputs.At(p.Provided) != nil {
|
|
|
|
|
groups[i].outputs.Set(p.Iface, p)
|
|
|
|
|
inputVisited.Set(p.Iface, i)
|
|
|
|
|
continue dfs
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
in := new(typeutil.Map)
|
|
|
|
|
in.SetHasher(hash)
|
|
|
|
|
in.Set(p.Provided, true)
|
|
|
|
|
out := new(typeutil.Map)
|
|
|
|
|
out.SetHasher(hash)
|
|
|
|
|
out.Set(p.Iface, p)
|
|
|
|
|
groups = append(groups, outGroup{
|
|
|
|
|
inputs: in,
|
|
|
|
|
outputs: out,
|
|
|
|
|
})
|
|
|
|
|
default:
|
|
|
|
|
panic("unreachable")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Name and sort groups
|
|
|
|
|
for i := range groups {
|
|
|
|
|
if groups[i].inputs.Len() == 0 {
|
|
|
|
|
groups[i].name = "no inputs"
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
instr := make([]string, 0, groups[i].inputs.Len())
|
|
|
|
|
groups[i].inputs.Iterate(func(k types.Type, _ interface{}) {
|
|
|
|
|
instr = append(instr, types.TypeString(k, nil))
|
|
|
|
|
})
|
|
|
|
|
sort.Strings(instr)
|
|
|
|
|
groups[i].name = strings.Join(instr, ", ")
|
|
|
|
|
}
|
|
|
|
|
sort.Slice(groups, func(i, j int) bool {
|
|
|
|
|
if groups[i].inputs.Len() == groups[j].inputs.Len() {
|
|
|
|
|
return groups[i].name < groups[j].name
|
|
|
|
|
}
|
|
|
|
|
return groups[i].inputs.Len() < groups[j].inputs.Len()
|
|
|
|
|
})
|
|
|
|
|
return groups, imports
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func mergeTypeSets(dst, src *typeutil.Map) {
|
|
|
|
|
src.Iterate(func(k types.Type, _ interface{}) {
|
|
|
|
|
dst.Set(k, true)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func sameTypeKeys(a, b *typeutil.Map) bool {
|
|
|
|
|
if a.Len() != b.Len() {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
same := true
|
|
|
|
|
a.Iterate(func(k types.Type, _ interface{}) {
|
|
|
|
|
if b.At(k) == nil {
|
|
|
|
|
same = false
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
return same
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func sortSet(set interface{}) []string {
|
|
|
|
|
rv := reflect.ValueOf(set)
|
|
|
|
|
a := make([]string, 0, rv.Len())
|
|
|
|
|
keys := rv.MapKeys()
|
|
|
|
|
for _, k := range keys {
|
|
|
|
|
a = append(a, k.String())
|
2018-03-26 07:39:00 -07:00
|
|
|
}
|
2018-04-04 14:42:56 -07:00
|
|
|
sort.Strings(a)
|
|
|
|
|
return a
|
2018-03-26 07:39:00 -07:00
|
|
|
}
|