// Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // gowire is a compile-time dependency injection tool. // // See README.md for an overview. package main import ( "fmt" "go/build" "go/token" "go/types" "io/ioutil" "os" "path/filepath" "reflect" "sort" "strconv" "strings" "github.com/google/go-x-cloud/wire/internal/wire" "golang.org/x/tools/go/types/typeutil" ) func main() { var err error switch { case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen": err = generate(".") case len(os.Args) == 2 && os.Args[1] == "show": err = show(".") case len(os.Args) == 2: err = generate(os.Args[1]) case len(os.Args) > 2 && os.Args[1] == "show": err = show(os.Args[2:]...) case len(os.Args) == 3 && os.Args[1] == "gen": err = generate(os.Args[2]) default: fmt.Fprintln(os.Stderr, "gowire: usage: gowire [gen] [PKG] | gowire show [...]") os.Exit(64) } if err != nil { fmt.Fprintln(os.Stderr, "gowire:", err) os.Exit(1) } } // generate runs the gen subcommand. Given a package, gen will create // the wire_gen.go file. func generate(pkg string) error { wd, err := os.Getwd() if err != nil { return err } pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly) if err != nil { return err } out, err := wire.Generate(&build.Default, wd, pkg) if err != nil { return err } if len(out) == 0 { // No Wire directives, don't write anything. fmt.Fprintln(os.Stderr, "gowire: no injector found for", pkg) return nil } p := filepath.Join(pkgInfo.Dir, "wire_gen.go") if err := ioutil.WriteFile(p, out, 0666); err != nil { return err } return nil } // show runs the show subcommand. // // Given one or more packages, show will find all the provider sets // declared as top-level variables and print what other provider sets it // imports and what outputs it can produce, given possible inputs. func show(pkgs ...string) error { wd, err := os.Getwd() if err != nil { return err } info, err := wire.Load(&build.Default, wd, pkgs) if err != nil { return err } keys := make([]wire.ProviderSetID, 0, len(info.Sets)) for k := range info.Sets { keys = append(keys, k) } sort.Slice(keys, func(i, j int) bool { if keys[i].ImportPath == keys[j].ImportPath { return keys[i].VarName < keys[j].VarName } return keys[i].ImportPath < keys[j].ImportPath }) // ANSI color codes. // TODO(light): Possibly use github.com/fatih/color? const ( reset = "\x1b[0m" redBold = "\x1b[0;1;31m" blue = "\x1b[0;34m" green = "\x1b[0;32m" ) for i, k := range keys { if i > 0 { fmt.Println() } outGroups, imports := gather(info, k) fmt.Printf("%s%s%s\n", redBold, k, reset) for _, imp := range sortSet(imports) { fmt.Printf("\t%s\n", imp) } for i := range outGroups { fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset) out := make(map[string]token.Pos, outGroups[i].outputs.Len()) outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) { switch v := v.(type) { case *wire.Provider: out[types.TypeString(t, nil)] = v.Pos case *wire.Value: out[types.TypeString(t, nil)] = v.Pos default: panic("unreachable") } }) for _, t := range sortSet(out) { fmt.Printf("\t%s%s%s\n", green, t, reset) fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t])) } } } return nil } type outGroup struct { name string inputs *typeutil.Map // values are not important outputs *typeutil.Map // values are *wire.Provider or *wire.Value } // gather flattens a provider set into outputs grouped by the inputs // required to create them. As it flattens the provider set, it records // the visited named provider sets as imports. func gather(info *wire.Info, key wire.ProviderSetID) (_ []outGroup, imports map[string]struct{}) { set := info.Sets[key] hash := typeutil.MakeHasher() // Find imports. next := []*wire.ProviderSet{info.Sets[key]} visited := make(map[*wire.ProviderSet]struct{}) imports = make(map[string]struct{}) for len(next) > 0 { curr := next[len(next)-1] next = next[:len(next)-1] if _, found := visited[curr]; found { continue } visited[curr] = struct{}{} if curr.Name != "" && !(curr.PkgPath == key.ImportPath && curr.Name == key.VarName) { imports[formatProviderSetName(curr.PkgPath, curr.Name)] = struct{}{} } for _, imp := range curr.Imports { next = append(next, imp) } } // Depth-first search to build groups. var groups []outGroup inputVisited := new(typeutil.Map) // values are int, indices into groups or -1 for input. inputVisited.SetHasher(hash) var stk []types.Type for _, k := range set.Outputs() { // Start a DFS by picking a random unvisited node. if inputVisited.At(k) == nil { stk = append(stk, k) } // Run DFS dfs: for len(stk) > 0 { curr := stk[len(stk)-1] stk = stk[:len(stk)-1] if inputVisited.At(curr) != nil { continue } switch pv := set.For(curr); { case pv.IsNil(): // This is an input. inputVisited.Set(curr, -1) case pv.IsProvider(): // Try to see if any args haven't been visited. p := pv.Provider() allPresent := true for _, arg := range p.Args { if inputVisited.At(arg.Type) == nil { allPresent = false } } if !allPresent { stk = append(stk, curr) for _, arg := range p.Args { if inputVisited.At(arg.Type) == nil { stk = append(stk, arg.Type) } } continue dfs } // Build up set of input types, match to a group. in := new(typeutil.Map) in.SetHasher(hash) for _, arg := range p.Args { i := inputVisited.At(arg.Type).(int) if i == -1 { in.Set(arg.Type, true) } else { mergeTypeSets(in, groups[i].inputs) } } for i := range groups { if sameTypeKeys(groups[i].inputs, in) { groups[i].outputs.Set(curr, p) inputVisited.Set(curr, i) continue dfs } } out := new(typeutil.Map) out.SetHasher(hash) out.Set(curr, p) inputVisited.Set(curr, len(groups)) groups = append(groups, outGroup{ inputs: in, outputs: out, }) case pv.IsValue(): v := pv.Value() for i := range groups { if groups[i].inputs.Len() == 0 { groups[i].outputs.Set(curr, v) inputVisited.Set(curr, i) continue dfs } } in := new(typeutil.Map) in.SetHasher(hash) out := new(typeutil.Map) out.SetHasher(hash) out.Set(curr, v) inputVisited.Set(curr, len(groups)) groups = append(groups, outGroup{ inputs: in, outputs: out, }) default: panic("unreachable") } } } // Name and sort groups. for i := range groups { if groups[i].inputs.Len() == 0 { groups[i].name = "no inputs" continue } instr := make([]string, 0, groups[i].inputs.Len()) groups[i].inputs.Iterate(func(k types.Type, _ interface{}) { instr = append(instr, types.TypeString(k, nil)) }) sort.Strings(instr) groups[i].name = strings.Join(instr, ", ") } sort.Slice(groups, func(i, j int) bool { if groups[i].inputs.Len() == groups[j].inputs.Len() { return groups[i].name < groups[j].name } return groups[i].inputs.Len() < groups[j].inputs.Len() }) return groups, imports } func mergeTypeSets(dst, src *typeutil.Map) { src.Iterate(func(k types.Type, _ interface{}) { dst.Set(k, true) }) } func sameTypeKeys(a, b *typeutil.Map) bool { if a.Len() != b.Len() { return false } same := true a.Iterate(func(k types.Type, _ interface{}) { if b.At(k) == nil { same = false } }) return same } func sortSet(set interface{}) []string { rv := reflect.ValueOf(set) a := make([]string, 0, rv.Len()) keys := rv.MapKeys() for _, k := range keys { a = append(a, k.String()) } sort.Strings(a) return a } func formatProviderSetName(importPath, varName string) string { // Since varName is an identifier, it doesn't make sense to quote. return strconv.Quote(importPath) + "." + varName }