Files
wire/main.go
Ross Light cfc6111ea5 goose: add show command
Lists provider sets in packages given on the command line, including
outputs grouped by what is needed to obtain them.

The goose package now exports the loading phase as an API.

Example output: https://paste.googleplex.com/5509965720584192

Reviewed-by: Tuo Shan <shantuo@google.com>
2018-11-12 14:09:56 -08:00

339 lines
8.0 KiB
Go

// goose is a compile-time dependency injection tool.
//
// See README.md for an overview.
package main
import (
"fmt"
"go/build"
"go/token"
"go/types"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"codename/goose/internal/goose"
"golang.org/x/tools/go/types/typeutil"
)
// main dispatches on the command line:
//
//	goose              -> gen for "."
//	goose gen          -> gen for "."
//	goose PKG          -> gen for PKG
//	goose gen PKG      -> gen for PKG
//	goose show         -> show for "."
//	goose show PKG...  -> show for each PKG
//
// Any other form prints a usage line to stderr and exits with status 64.
// A subcommand error is printed to stderr and the process exits with 1.
func main() {
	n := len(os.Args)
	sub := ""
	if n > 1 {
		sub = os.Args[1]
	}

	var err error
	switch {
	case n == 1, n == 2 && sub == "gen":
		err = generate(".")
	case n == 2 && sub == "show":
		err = show(".")
	case n == 2:
		err = generate(sub)
	case n > 2 && sub == "show":
		err = show(os.Args[2:]...)
	case n == 3 && sub == "gen":
		err = generate(os.Args[2])
	default:
		fmt.Fprintln(os.Stderr, "goose: usage: goose [gen] [PKG] | goose show [...]")
		os.Exit(64)
	}
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, "goose:", err)
	os.Exit(1)
}
// generate implements the gen subcommand. It resolves pkg relative to the
// current working directory, runs the goose generator on it, and writes
// the generated source to goose_gen.go in the package's directory.
//
// If the generator returns no output, a notice is printed to stderr and no
// file is written; this is not treated as an error.
func generate(pkg string) error {
	dir, err := os.Getwd()
	if err != nil {
		return err
	}
	found, err := build.Default.Import(pkg, dir, build.FindOnly)
	if err != nil {
		return err
	}
	src, err := goose.Generate(&build.Default, dir, pkg)
	switch {
	case err != nil:
		return err
	case len(src) == 0:
		// No Goose directives, don't write anything.
		fmt.Fprintln(os.Stderr, "goose: no injector found for", pkg)
		return nil
	}
	dst := filepath.Join(found.Dir, "goose_gen.go")
	return ioutil.WriteFile(dst, src, 0666)
}
// show runs the show subcommand.
//
// It loads the given packages and, for every declared provider set (sorted
// by import path, then by name), prints the set's name, the provider sets
// it transitively imports, and the types it can produce. Outputs are
// grouped by the inputs needed to obtain them, and each output is printed
// with the source position of its provider or interface binding. The
// output is colored with ANSI escape codes.
func show(pkgs ...string) error {
	dir, err := os.Getwd()
	if err != nil {
		return err
	}
	info, err := goose.Load(&build.Default, dir, pkgs)
	if err != nil {
		return err
	}

	keys := make([]goose.ProviderSetID, 0, len(info.Sets))
	for id := range info.Sets {
		keys = append(keys, id)
	}
	sort.Slice(keys, func(i, j int) bool {
		a, b := keys[i], keys[j]
		if a.ImportPath != b.ImportPath {
			return a.ImportPath < b.ImportPath
		}
		return a.Name < b.Name
	})

	// ANSI color codes.
	const (
		reset   = "\x1b[0m"
		redBold = "\x1b[0;1;31m"
		blue    = "\x1b[0;34m"
		green   = "\x1b[0;32m"
	)
	for idx, key := range keys {
		if idx != 0 {
			// Blank line between consecutive provider sets.
			fmt.Println()
		}
		groups, imports := gather(info, key)
		fmt.Printf("%s%s%s\n", redBold, key, reset)
		for _, imp := range sortSet(imports) {
			fmt.Printf("\t%s\n", imp)
		}
		for _, g := range groups {
			fmt.Printf("%sOutputs given %s:%s\n", blue, g.name, reset)
			positions := make(map[string]token.Pos, g.outputs.Len())
			g.outputs.Iterate(func(t types.Type, v interface{}) {
				var pos token.Pos
				switch v := v.(type) {
				case *goose.Provider:
					pos = v.Pos
				case goose.IfaceBinding:
					pos = v.Pos
				default:
					panic("unreachable")
				}
				positions[types.TypeString(t, nil)] = pos
			})
			for _, name := range sortSet(positions) {
				fmt.Printf("\t%s%s%s\n", green, name, reset)
				fmt.Printf("\t\tat %v\n", info.Fset.Position(positions[name]))
			}
		}
	}
	return nil
}
// outGroup is one group of outputs computed by gather: every output type in
// the group needs the same set of input types to be produced.
type outGroup struct {
	// name is the display label for the group, filled in by gather: the
	// sorted input type names joined with ", ", or "no inputs".
	name string

	// inputs is used as a set of types; its values are not important.
	// outputs maps each produced type to either a *goose.Provider or a
	// goose.IfaceBinding.
	inputs, outputs *typeutil.Map
}
// gather flattens a provider set into outputs grouped by the inputs
// required to create them. As it flattens the provider set, it records
// the visited provider sets as imports.
//
// key names the provider set to flatten. It and every set it imports
// (transitively) are looked up in info.All. The returned groups are sorted
// by number of inputs and then by name, where a group's name is its input
// type names, sorted and joined with ", " (or "no inputs"). The returned
// imports holds the String form of every provider set reached from key,
// excluding key itself.
//
// Provider arguments marked Optional are ignored when computing a group's
// inputs.
func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
	// All type maps share one hasher.
	hash := typeutil.MakeHasher()

	// Map types to providers and bindings.
	pm := new(typeutil.Map)
	pm.SetHasher(hash)
	// Walk key and its imports with an explicit stack. Each set is
	// processed at most once.
	next := []goose.ProviderSetID{key}
	visited := make(map[goose.ProviderSetID]struct{})
	imports = make(map[string]struct{})
	for len(next) > 0 {
		curr := next[len(next)-1]
		next = next[:len(next)-1]
		if _, found := visited[curr]; found {
			continue
		}
		visited[curr] = struct{}{}
		if curr != key {
			imports[curr.String()] = struct{}{}
		}
		set := info.All[curr]
		for _, p := range set.Providers {
			pm.Set(p.Out, p)
		}
		for _, b := range set.Bindings {
			pm.Set(b.Iface, b)
		}
		for _, imp := range set.Imports {
			next = append(next, imp.ProviderSetID)
		}
	}

	// Depth-first search to build groups.
	var groups []outGroup
	inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input.
	inputVisited.SetHasher(hash)
	pmKeys := pm.Keys()
	var stk []types.Type
	for _, k := range pmKeys {
		// Start a DFS by picking a random unvisited node.
		if inputVisited.At(k) == nil {
			stk = append(stk, k)
		}

		// Run DFS
	dfs:
		for len(stk) > 0 {
			curr := stk[len(stk)-1]
			stk = stk[:len(stk)-1]
			if inputVisited.At(curr) != nil {
				continue
			}
			switch p := pm.At(curr).(type) {
			case nil:
				// No provider or binding for this type: it is an input.
				inputVisited.Set(curr, -1)
			case *goose.Provider:
				// Try to see if any args haven't been visited.
				allPresent := true
				for _, arg := range p.Args {
					if arg.Optional {
						continue
					}
					if inputVisited.At(arg.Type) == nil {
						allPresent = false
					}
				}
				if !allPresent {
					// Re-push this provider beneath its unvisited args so it
					// is revisited once all of them have been resolved.
					stk = append(stk, curr)
					for _, arg := range p.Args {
						if arg.Optional {
							continue
						}
						if inputVisited.At(arg.Type) == nil {
							stk = append(stk, arg.Type)
						}
					}
					continue dfs
				}
				// Build up set of input types, match to a group. An arg that
				// is itself an output contributes its group's inputs.
				in := new(typeutil.Map)
				in.SetHasher(hash)
				for _, arg := range p.Args {
					if arg.Optional {
						continue
					}
					i := inputVisited.At(arg.Type).(int)
					if i == -1 {
						in.Set(arg.Type, true)
					} else {
						mergeTypeSets(in, groups[i].inputs)
					}
				}
				for i := range groups {
					if sameTypeKeys(groups[i].inputs, in) {
						groups[i].outputs.Set(p.Out, p)
						inputVisited.Set(p.Out, i)
						continue dfs
					}
				}
				out := new(typeutil.Map)
				out.SetHasher(hash)
				out.Set(p.Out, p)
				inputVisited.Set(p.Out, len(groups))
				groups = append(groups, outGroup{
					inputs:  in,
					outputs: out,
				})
			case goose.IfaceBinding:
				i, ok := inputVisited.At(p.Provided).(int)
				if !ok {
					// Resolve the concrete type first, then come back.
					stk = append(stk, curr, p.Provided)
					continue dfs
				}
				if i != -1 {
					// The concrete type is an output: the interface joins
					// the same group.
					groups[i].outputs.Set(p.Iface, p)
					inputVisited.Set(p.Iface, i)
					continue dfs
				}
				// Binding must be provided. Find or add a group.
				for i := range groups {
					if groups[i].inputs.Len() != 1 {
						continue
					}
					if groups[i].inputs.At(p.Provided) != nil {
						groups[i].outputs.Set(p.Iface, p)
						inputVisited.Set(p.Iface, i)
						continue dfs
					}
				}
				in := new(typeutil.Map)
				in.SetHasher(hash)
				in.Set(p.Provided, true)
				out := new(typeutil.Map)
				out.SetHasher(hash)
				out.Set(p.Iface, p)
				// Record the interface's group, as the provider branch does
				// for p.Out. Otherwise the interface would be re-examined
				// every time it is reached again.
				inputVisited.Set(p.Iface, len(groups))
				groups = append(groups, outGroup{
					inputs:  in,
					outputs: out,
				})
			default:
				panic("unreachable")
			}
		}
	}

	// Name and sort groups
	for i := range groups {
		if groups[i].inputs.Len() == 0 {
			groups[i].name = "no inputs"
			continue
		}
		instr := make([]string, 0, groups[i].inputs.Len())
		groups[i].inputs.Iterate(func(k types.Type, _ interface{}) {
			instr = append(instr, types.TypeString(k, nil))
		})
		sort.Strings(instr)
		groups[i].name = strings.Join(instr, ", ")
	}
	sort.Slice(groups, func(i, j int) bool {
		if groups[i].inputs.Len() == groups[j].inputs.Len() {
			return groups[i].name < groups[j].name
		}
		return groups[i].inputs.Len() < groups[j].inputs.Len()
	})
	return groups, imports
}
// mergeTypeSets adds every key of src to dst with the value true.
// The values stored in src are ignored.
func mergeTypeSets(dst, src *typeutil.Map) {
	for _, t := range src.Keys() {
		dst.Set(t, true)
	}
}
// sameTypeKeys reports whether a and b hold the same set of keys. It
// requires equal lengths and checks that every key of a has a non-nil
// value in b. Values are otherwise ignored.
func sameTypeKeys(a, b *typeutil.Map) bool {
	if a.Len() != b.Len() {
		return false
	}
	for _, t := range a.Keys() {
		if b.At(t) == nil {
			return false
		}
	}
	return true
}
// sortSet returns the keys of the map set as strings, sorted in ascending
// order. set must be a map value (a nil map yields an empty slice); any
// other value causes a panic.
//
// Keys of string kind are used verbatim. Keys of any other kind are
// formatted with fmt.Sprint. reflect.Value.String alone would render them
// as "<T Value>" placeholders instead of their values.
func sortSet(set interface{}) []string {
	keys := reflect.ValueOf(set).MapKeys()
	a := make([]string, 0, len(keys))
	for _, k := range keys {
		if k.Kind() == reflect.String {
			a = append(a, k.String())
		} else {
			a = append(a, fmt.Sprint(k.Interface()))
		}
	}
	sort.Strings(a)
	return a
}