Files
wire/cmd/gowire/main.go
Ross Light b12449f9e3 wire: build provider map incrementally (google/go-cloud#96)
One small breaking change: a provider set can no longer include an
interface binding to a concrete type that is not being provided
(directly or indirectly) by the provider set. I can't imagine a
reasonable use case for the previous behavior, so this will likely
catch more errors.

In terms of operation, binding conflict error messages will now give
much more specific line numbers, since they will be reported closer to
where the problem occurred.

Now that provider sets gather this information, it can be exposed in
the package API. gowire now uses this information instead of
trying to build it itself.

Fixes google/go-cloud#29
2018-11-13 13:16:45 -08:00

331 lines
8.3 KiB
Go

// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// gowire is a compile-time dependency injection tool.
//
// See README.md for an overview.
package main
import (
"fmt"
"go/build"
"go/token"
"go/types"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"sort"
"strconv"
"strings"
"github.com/google/go-cloud/wire/internal/wire"
"golang.org/x/tools/go/types/typeutil"
)
// main dispatches to a subcommand based on the command-line arguments:
//
//	gowire [gen]          generate wire_gen.go for the current directory
//	gowire [gen] PKG      generate wire_gen.go for PKG
//	gowire show [PKG...]  describe provider sets (defaulting to ".")
//
// Any other argument shape prints a usage line and exits with status 64.
// An error returned by the subcommand is printed and exits with status 1.
func main() {
	argc := len(os.Args)
	var err error
	switch {
	case argc == 1, argc == 2 && os.Args[1] == "gen":
		err = generate(".")
	case argc == 2 && os.Args[1] == "show":
		err = show(".")
	case argc == 2:
		err = generate(os.Args[1])
	case argc > 2 && os.Args[1] == "show":
		err = show(os.Args[2:]...)
	case argc == 3 && os.Args[1] == "gen":
		err = generate(os.Args[2])
	default:
		fmt.Fprintln(os.Stderr, "gowire: usage: gowire [gen] [PKG] | gowire show [...]")
		os.Exit(64)
	}
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, "gowire:", err)
	os.Exit(1)
}
// generate runs the gen subcommand. Given a package, gen will create
// the wire_gen.go file in that package's directory.
//
// If wire.Generate produces no output, nothing is written and a notice is
// printed to stderr instead; this is not treated as an error.
func generate(pkg string) error {
	wd, err := os.Getwd()
	if err != nil {
		return err
	}
	// FindOnly: only the package directory is needed, to know where to
	// write wire_gen.go.
	pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly)
	if err != nil {
		return err
	}
	out, err := wire.Generate(&build.Default, wd, pkg)
	if err != nil {
		return err
	}
	if len(out) == 0 {
		// No Wire directives, don't write anything.
		fmt.Fprintln(os.Stderr, "gowire: no injector found for", pkg)
		return nil
	}
	dst := filepath.Join(pkgInfo.Dir, "wire_gen.go")
	return ioutil.WriteFile(dst, out, 0666)
}
// show runs the show subcommand.
//
// Given one or more packages, show will find all the provider sets
// declared as top-level variables and print what other provider sets it
// imports and what outputs it can produce, given possible inputs.
// Provider sets are printed in order of import path, then variable name,
// separated by blank lines.
func show(pkgs ...string) error {
	wd, err := os.Getwd()
	if err != nil {
		return err
	}
	info, err := wire.Load(&build.Default, wd, pkgs)
	if err != nil {
		return err
	}

	// Visit provider sets in a stable order: by import path, then by
	// variable name.
	ids := make([]wire.ProviderSetID, 0, len(info.Sets))
	for id := range info.Sets {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(a, b int) bool {
		if ids[a].ImportPath != ids[b].ImportPath {
			return ids[a].ImportPath < ids[b].ImportPath
		}
		return ids[a].VarName < ids[b].VarName
	})

	// ANSI color codes.
	// TODO(light): Possibly use github.com/fatih/color?
	const (
		reset   = "\x1b[0m"
		redBold = "\x1b[0;1;31m"
		blue    = "\x1b[0;34m"
		green   = "\x1b[0;32m"
	)
	for n, id := range ids {
		if n != 0 {
			fmt.Println()
		}
		groups, imports := gather(info, id)
		fmt.Printf("%s%s%s\n", redBold, id, reset)
		for _, name := range sortSet(imports) {
			fmt.Printf("\t%s\n", name)
		}
		for _, g := range groups {
			fmt.Printf("%sOutputs given %s:%s\n", blue, g.name, reset)
			// Key each output by its printed type name so the outputs can
			// be listed alphabetically along with their source position.
			positions := make(map[string]token.Pos, g.outputs.Len())
			g.outputs.Iterate(func(t types.Type, v interface{}) {
				var pos token.Pos
				switch v := v.(type) {
				case *wire.Provider:
					pos = v.Pos
				case *wire.Value:
					pos = v.Pos
				default:
					panic("unreachable")
				}
				positions[types.TypeString(t, nil)] = pos
			})
			for _, name := range sortSet(positions) {
				fmt.Printf("\t%s%s%s\n", green, name, reset)
				fmt.Printf("\t\tat %v\n", info.Fset.Position(positions[name]))
			}
		}
	}
	return nil
}
// outGroup collects outputs of a provider set that all require the same
// set of input types.
type outGroup struct {
	// name labels the group when printed: the sorted, comma-separated
	// input type names, or "no inputs".
	name string

	// inputs is the set of input types; only its keys are meaningful.
	inputs *typeutil.Map

	// outputs maps each output type to the *wire.Provider or *wire.Value
	// that produces it.
	outputs *typeutil.Map
}
// gather flattens a provider set into outputs grouped by the inputs
// required to create them. As it flattens the provider set, it records
// the visited named provider sets as imports.
//
// key is expected to be a key of info.Sets. The returned groups are named
// and sorted by nameAndSortGroups. imports holds the formatted names (see
// formatProviderSetName) of every named provider set reachable through
// Imports, excluding the set identified by key itself.
func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[string]struct{}) {
	set := info.Sets[key]
	imports = gatherImports(set, key)
	groups := groupOutputs(set)
	nameAndSortGroups(groups)
	return groups, imports
}

// gatherImports walks set's Imports transitively and returns the formatted
// names of every named provider set it reaches, other than the one
// identified by key.
func gatherImports(set *wire.ProviderSet, key wire.ProviderSetID) map[string]struct{} {
	imports := make(map[string]struct{})
	visited := make(map[*wire.ProviderSet]struct{})
	next := []*wire.ProviderSet{set}
	for len(next) > 0 {
		curr := next[len(next)-1]
		next = next[:len(next)-1]
		if _, found := visited[curr]; found {
			continue
		}
		visited[curr] = struct{}{}
		// Anonymous sets and the set being described are not imports.
		if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) {
			imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{}
		}
		next = append(next, curr.Imports...)
	}
	return imports
}

// groupOutputs partitions the outputs of set by the input types needed to
// build them, using an iterative depth-first search that groups each
// provider only after all of its arguments have been grouped. The returned
// groups are unnamed and in discovery order.
func groupOutputs(set *wire.ProviderSet) []outGroup {
	hash := typeutil.MakeHasher()
	var groups []outGroup
	// visited maps each visited type to its index in groups, or to -1 if
	// the type is an input.
	visited := newTypeMap(hash)

	// addToGroup records t, produced by producer (a *wire.Provider or
	// *wire.Value), in the first group whose inputs equal in, creating a
	// new group if none matches.
	addToGroup := func(t types.Type, producer interface{}, in *typeutil.Map) {
		for i := range groups {
			if sameTypeKeys(groups[i].inputs, in) {
				groups[i].outputs.Set(t, producer)
				visited.Set(t, i)
				return
			}
		}
		out := newTypeMap(hash)
		out.Set(t, producer)
		visited.Set(t, len(groups))
		groups = append(groups, outGroup{inputs: in, outputs: out})
	}

	var stk []types.Type
	for _, k := range set.Outputs() {
		// Start a DFS from each output that hasn't been visited yet.
		if visited.At(k) == nil {
			stk = append(stk, k)
		}
		for len(stk) > 0 {
			curr := stk[len(stk)-1]
			stk = stk[:len(stk)-1]
			if visited.At(curr) != nil {
				continue
			}
			switch pv := set.For(curr); {
			case pv.IsNil():
				// This is an input.
				visited.Set(curr, -1)
			case pv.IsProvider():
				p := pv.Provider()
				allVisited := true
				for _, arg := range p.Args {
					if visited.At(arg.Type) == nil {
						allVisited = false
						break
					}
				}
				if !allVisited {
					// Revisit curr after its unvisited arguments are grouped.
					stk = append(stk, curr)
					for _, arg := range p.Args {
						if visited.At(arg.Type) == nil {
							stk = append(stk, arg.Type)
						}
					}
					continue
				}
				// curr needs each argument that is an input, plus the
				// inputs of the groups its other arguments belong to.
				in := newTypeMap(hash)
				for _, arg := range p.Args {
					if i := visited.At(arg.Type).(int); i == -1 {
						in.Set(arg.Type, true)
					} else {
						mergeTypeSets(in, groups[i].inputs)
					}
				}
				addToGroup(curr, p, in)
			case pv.IsValue():
				// Values are placed in the group with no inputs.
				addToGroup(curr, pv.Value(), newTypeMap(hash))
			default:
				panic("unreachable")
			}
		}
	}
	return groups
}

// nameAndSortGroups sets each group's name to its input type names, sorted
// and comma-separated ("no inputs" if it has none), then orders the groups
// by number of inputs and, for equal counts, by name.
func nameAndSortGroups(groups []outGroup) {
	for i := range groups {
		if groups[i].inputs.Len() == 0 {
			groups[i].name = "no inputs"
			continue
		}
		names := make([]string, 0, groups[i].inputs.Len())
		groups[i].inputs.Iterate(func(t types.Type, _ interface{}) {
			names = append(names, types.TypeString(t, nil))
		})
		sort.Strings(names)
		groups[i].name = strings.Join(names, ", ")
	}
	sort.Slice(groups, func(i, j int) bool {
		if li, lj := groups[i].inputs.Len(), groups[j].inputs.Len(); li != lj {
			return li < lj
		}
		return groups[i].name < groups[j].name
	})
}

// newTypeMap returns an empty typeutil.Map that uses hash, so that all maps
// built during one grouping pass share a single hasher.
func newTypeMap(hash typeutil.Hasher) *typeutil.Map {
	m := new(typeutil.Map)
	m.SetHasher(hash)
	return m
}
// mergeTypeSets adds every key of src to dst, storing true as the value.
// Keys already present in dst are overwritten with true.
func mergeTypeSets(dst, src *typeutil.Map) {
	add := func(t types.Type, _ interface{}) {
		dst.Set(t, true)
	}
	src.Iterate(add)
}
// sameTypeKeys reports whether a and b have the same number of entries and
// every key of a maps to a non-nil value in b. Values are otherwise ignored.
func sameTypeKeys(a, b *typeutil.Map) bool {
	if a.Len() != b.Len() {
		return false
	}
	for _, t := range a.Keys() {
		if b.At(t) == nil {
			return false
		}
	}
	return true
}
// sortSet returns the keys of the map set in ascending order. It is
// intended for maps with string keys; each key is converted with
// reflect.Value.String.
func sortSet(set interface{}) []string {
	mapKeys := reflect.ValueOf(set).MapKeys()
	sorted := make([]string, len(mapKeys))
	for i, k := range mapKeys {
		sorted[i] = k.String()
	}
	sort.Strings(sorted)
	return sorted
}
// formatProviderSetName renders a provider set as "importPath".varName,
// with the import path in Go-quoted form.
func formatProviderSetName(importPath, varName string) string {
	quotedPath := strconv.Quote(importPath)
	// varName is an identifier, so there is nothing to gain by quoting it.
	return quotedPath + "." + varName
}