wire: use subcommands package, improving help (#173)

This commit is contained in:
Robert van Gent
2019-05-14 12:51:16 -07:00
committed by GitHub
parent c1be6ec0d8
commit d76a979091
4 changed files with 165 additions and 104 deletions

View File

@@ -19,90 +19,105 @@ package main
import (
"context"
"errors"
"flag"
"fmt"
"go/token"
"go/types"
"io/ioutil"
"log"
"os"
"reflect"
"sort"
"strconv"
"strings"
"github.com/google/subcommands"
"github.com/google/wire/internal/wire"
"github.com/pmezard/go-difflib/difflib"
"golang.org/x/tools/go/types/typeutil"
)
const usage = "usage: wire [gen|diff|show|check] [...]"
func main() {
var (
exitCode = 0
err error
)
switch {
case len(os.Args) == 2 && (os.Args[1] == "help" || os.Args[1] == "-h" || os.Args[1] == "-help" || os.Args[1] == "--help"):
fmt.Fprintln(os.Stderr, usage)
os.Exit(0)
case len(os.Args) == 2 && os.Args[1] == "show":
err = show(".")
case len(os.Args) > 2 && os.Args[1] == "show":
err = show(os.Args[2:]...)
case len(os.Args) == 2 && os.Args[1] == "check":
err = check(".")
case len(os.Args) > 2 && os.Args[1] == "check":
err = check(os.Args[2:]...)
case len(os.Args) == 2 && os.Args[1] == "diff":
exitCode, err = diff(".")
case len(os.Args) > 2 && os.Args[1] == "diff":
exitCode, err = diff(os.Args[2:]...)
case len(os.Args) == 2 && os.Args[1] == "gen":
err = generate(".")
case len(os.Args) > 2 && os.Args[1] == "gen":
err = generate(os.Args[2:]...)
// No explicit command given, assume "gen".
case len(os.Args) == 1:
err = generate(".")
case len(os.Args) > 1:
err = generate(os.Args[1:]...)
default:
fmt.Fprintln(os.Stderr, usage)
exitCode = 64
subcommands.Register(subcommands.CommandsCommand(), "")
subcommands.Register(subcommands.FlagsCommand(), "")
subcommands.Register(subcommands.HelpCommand(), "")
subcommands.Register(&checkCmd{}, "")
subcommands.Register(&diffCmd{}, "")
subcommands.Register(&genCmd{}, "")
subcommands.Register(&showCmd{}, "")
flag.Parse()
// Initialize the default logger to log to stderr.
log.SetFlags(0)
log.SetPrefix("wire: ")
log.SetOutput(os.Stderr)
// TODO(rvangent): Use subcommands's VisitCommands instead of hardcoded map,
// once there is a release that contains it:
// allCmds := map[string]bool{}
// subcommands.DefaultCommander.VisitCommands(func(_ *subcommands.CommandGroup, cmd subcommands.Command) { allCmds[cmd.Name()] = true })
allCmds := map[string]bool{
"commands": true, // builtin
"help": true, // builtin
"flags": true, // builtin
"check": true,
"diff": true,
"gen": true,
"show": true,
}
if err != nil {
fmt.Fprintln(os.Stderr, "wire:", err)
// Don't override more specific error codes from above
// (e.g., diff returns 2 on error).
if exitCode == 0 {
exitCode = 1
// Default to running the "gen" command.
if args := flag.Args(); len(args) == 0 || !allCmds[args[0]] {
genCmd := &genCmd{}
os.Exit(int(genCmd.Execute(context.Background(), flag.CommandLine)))
}
}
os.Exit(exitCode)
os.Exit(int(subcommands.Execute(context.Background())))
}
// generate runs the gen subcommand.
//
// Given one or more packages, gen will create the wire_gen.go file for each.
func generate(pkgs ...string) error {
// packages returns the slice of packages to run wire over based on f.
// It defaults to ".".
func packages(f *flag.FlagSet) []string {
pkgs := f.Args()
if len(pkgs) == 0 {
pkgs = []string{"."}
}
return pkgs
}
type genCmd struct{}
func (*genCmd) Name() string { return "gen" }
func (*genCmd) Synopsis() string {
return "generate the wire_gen.go file for each package"
}
func (*genCmd) Usage() string {
return `gen [packages]
Given one or more packages, gen creates the wire_gen.go file for each.
If no packages are listed, it defaults to ".".
`
}
func (*genCmd) SetFlags(_ *flag.FlagSet) {}
func (*genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
wd, err := os.Getwd()
if err != nil {
return err
log.Println("failed to get working directory: ", err)
return subcommands.ExitFailure
}
outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs)
outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f))
if len(errs) > 0 {
logErrors(errs)
return errors.New("generate failed")
log.Println("generate failed")
return subcommands.ExitFailure
}
if len(outs) == 0 {
return nil
return subcommands.ExitSuccess
}
success := true
for _, out := range outs {
if len(out.Errs) > 0 {
fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath)
logErrors(out.Errs)
log.Printf("%s: generate failed\n", out.PkgPath)
success = false
}
if len(out.Content) == 0 {
@@ -110,53 +125,63 @@ func generate(pkgs ...string) error {
continue
}
if err := out.Commit(); err == nil {
fmt.Fprintf(os.Stderr, "%s: wrote %s\n", out.PkgPath, out.OutputPath)
log.Printf("%s: wrote %s\n", out.PkgPath, out.OutputPath)
} else {
fmt.Fprintf(os.Stderr, "%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err)
log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err)
success = false
}
}
if !success {
return errors.New("at least one generate failure")
log.Println("at least one generate failure")
return subcommands.ExitFailure
}
return nil
return subcommands.ExitSuccess
}
// diff runs the diff subcommand.
//
// Given one or more packages, diff will generate the content for the
// wire_gen.go file, and output the diff against the existing file.
//
// Similar to the diff command, it returns 0 if no diff, 1 if different, 2
// plus an error if trouble.
func diff(pkgs ...string) (int, error) {
errReturn := func(err error) (int, error) {
return 2, err
}
okReturn := func(hadDiff bool) (int, error) {
if hadDiff {
return 1, nil
}
return 0, nil
}
type diffCmd struct{}
func (*diffCmd) Name() string { return "diff" }
func (*diffCmd) Synopsis() string {
return "output a diff between existing wire_gen.go files and what gen would generate"
}
func (*diffCmd) Usage() string {
return `diff [packages]
Given one or more packages, diff generates the content for their wire_gen.go
files and outputs the diff against the existing files.
If no packages are listed, it defaults to ".".
Similar to the diff command, it returns 0 if no diff, 1 if different, 2
plus an error if trouble.
`
}
func (*diffCmd) SetFlags(_ *flag.FlagSet) {}
func (*diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
const (
errReturn = subcommands.ExitStatus(2)
diffReturn = subcommands.ExitStatus(1)
)
wd, err := os.Getwd()
if err != nil {
return errReturn(err)
log.Println("failed to get working directory: ", err)
return errReturn
}
outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs)
outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f))
if len(errs) > 0 {
logErrors(errs)
return errReturn(errors.New("generate failed"))
log.Println("generate failed")
return errReturn
}
if len(outs) == 0 {
return okReturn(false)
return subcommands.ExitSuccess
}
success := true
hadDiff := false
for _, out := range outs {
if len(out.Errs) > 0 {
fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath)
logErrors(out.Errs)
log.Printf("%s: generate failed\n", out.PkgPath)
success = false
}
if len(out.Content) == 0 {
@@ -170,32 +195,50 @@ func diff(pkgs ...string) (int, error) {
B: difflib.SplitLines(string(out.Content)),
}); err == nil {
if diff != "" {
fmt.Fprintf(os.Stdout, "%s: diff from %s:\n%s", out.PkgPath, out.OutputPath, diff)
// Print the actual diff to stdout, not stderr.
fmt.Printf("%s: diff from %s:\n%s\n", out.PkgPath, out.OutputPath, diff)
hadDiff = true
}
} else {
fmt.Fprintf(os.Stderr, "%s: failed to diff %s: %v\n", out.PkgPath, out.OutputPath, err)
log.Printf("%s: failed to diff %s: %v\n", out.PkgPath, out.OutputPath, err)
success = false
}
}
if !success {
return errReturn(errors.New("at least one generate failure"))
log.Println("at least one generate failure")
return errReturn
}
return okReturn(hadDiff)
if hadDiff {
return diffReturn
}
return subcommands.ExitSuccess
}
// show runs the show subcommand.
//
// Given one or more packages, show will find all the provider sets
// declared as top-level variables and print what other provider sets it
// imports and what outputs it can produce, given possible inputs.
// It also lists any injector functions defined in the package.
func show(pkgs ...string) error {
type showCmd struct{}
func (*showCmd) Name() string { return "show" }
func (*showCmd) Synopsis() string {
return "describe all top-level provider sets"
}
func (*showCmd) Usage() string {
return `show [packages]
Given one or more packages, show finds all the provider sets declared as
top-level variables and prints what other provider sets they import and what
outputs they can produce, given possible inputs. It also lists any injector
functions defined in the package.
If no packages are listed, it defaults to ".".
`
}
func (*showCmd) SetFlags(_ *flag.FlagSet) {}
func (*showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
wd, err := os.Getwd()
if err != nil {
return err
log.Println("failed to get working directory: ", err)
return subcommands.ExitFailure
}
info, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs)
info, errs := wire.Load(ctx, wd, os.Environ(), packages(f))
if info != nil {
keys := make([]wire.ProviderSetID, 0, len(info.Sets))
for k := range info.Sets {
@@ -261,27 +304,41 @@ func show(pkgs ...string) error {
}
if len(errs) > 0 {
logErrors(errs)
return errors.New("error loading packages")
log.Println("error loading packages")
return subcommands.ExitFailure
}
return nil
return subcommands.ExitSuccess
}
// check runs the check subcommand.
//
// Given one or more packages, check will print any type-checking or
// Wire errors found with top-level variable provider sets or injector
// functions.
func check(pkgs ...string) error {
type checkCmd struct{}
func (*checkCmd) Name() string { return "check" }
func (*checkCmd) Synopsis() string {
return "print any Wire errors found"
}
func (*checkCmd) Usage() string {
return `check [packages]
Given one or more packages, check prints any type-checking or Wire errors
found with top-level variable provider sets or injector functions.
If no packages are listed, it defaults to ".".
`
}
func (*checkCmd) SetFlags(_ *flag.FlagSet) {}
func (*checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
wd, err := os.Getwd()
if err != nil {
return err
log.Println("failed to get working directory: ", err)
return subcommands.ExitFailure
}
_, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs)
_, errs := wire.Load(ctx, wd, os.Environ(), packages(f))
if len(errs) > 0 {
logErrors(errs)
return errors.New("error loading packages")
log.Println("error loading packages")
return subcommands.ExitFailure
}
return nil
return subcommands.ExitSuccess
}
type outGroup struct {
@@ -503,6 +560,6 @@ func formatProviderSetName(importPath, varName string) string {
func logErrors(errs []error) {
for _, err := range errs {
fmt.Fprintln(os.Stderr, strings.Replace(err.Error(), "\n", "\n\t", -1))
log.Println(strings.Replace(err.Error(), "\n", "\n\t", -1))
}
}

1
go.mod
View File

@@ -2,6 +2,7 @@ module github.com/google/wire
require (
github.com/google/go-cmp v0.2.0
github.com/google/subcommands v1.0.1
github.com/pmezard/go-difflib v1.0.0
golang.org/x/tools v0.0.0-20190422233926-fe54fb35175b
)

2
go.sum
View File

@@ -1,5 +1,7 @@
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/subcommands v1.0.1 h1:/eqq+otEXm5vhfBrbREPCSVQbvofip6kIz+mX5TUH7k=
github.com/google/subcommands v1.0.1/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@@ -1,3 +1,4 @@
github.com/google/subcommands
github.com/google/wire
github.com/pmezard/go-difflib
golang.org/x/tools