diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 35d9ea0..2256155 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -19,90 +19,105 @@ package main import ( "context" - "errors" + "flag" "fmt" "go/token" "go/types" "io/ioutil" + "log" "os" "reflect" "sort" "strconv" "strings" + "github.com/google/subcommands" "github.com/google/wire/internal/wire" "github.com/pmezard/go-difflib/difflib" "golang.org/x/tools/go/types/typeutil" ) -const usage = "usage: wire [gen|diff|show|check] [...]" - func main() { - var ( - exitCode = 0 - err error - ) - switch { - case len(os.Args) == 2 && (os.Args[1] == "help" || os.Args[1] == "-h" || os.Args[1] == "-help" || os.Args[1] == "--help"): - fmt.Fprintln(os.Stderr, usage) - os.Exit(0) - case len(os.Args) == 2 && os.Args[1] == "show": - err = show(".") - case len(os.Args) > 2 && os.Args[1] == "show": - err = show(os.Args[2:]...) - case len(os.Args) == 2 && os.Args[1] == "check": - err = check(".") - case len(os.Args) > 2 && os.Args[1] == "check": - err = check(os.Args[2:]...) - case len(os.Args) == 2 && os.Args[1] == "diff": - exitCode, err = diff(".") - case len(os.Args) > 2 && os.Args[1] == "diff": - exitCode, err = diff(os.Args[2:]...) - case len(os.Args) == 2 && os.Args[1] == "gen": - err = generate(".") - case len(os.Args) > 2 && os.Args[1] == "gen": - err = generate(os.Args[2:]...) - // No explicit command given, assume "gen". - case len(os.Args) == 1: - err = generate(".") - case len(os.Args) > 1: - err = generate(os.Args[1:]...) - default: - fmt.Fprintln(os.Stderr, usage) - exitCode = 64 + subcommands.Register(subcommands.CommandsCommand(), "") + subcommands.Register(subcommands.FlagsCommand(), "") + subcommands.Register(subcommands.HelpCommand(), "") + subcommands.Register(&checkCmd{}, "") + subcommands.Register(&diffCmd{}, "") + subcommands.Register(&genCmd{}, "") + subcommands.Register(&showCmd{}, "") + flag.Parse() + + // Initialize the default logger to log to stderr. + log.SetFlags(0) + log.SetPrefix("wire: ") + log.SetOutput(os.Stderr) + + // TODO(rvangent): Use subcommands's VisitCommands instead of hardcoded map, + // once there is a release that contains it: + // allCmds := map[string]bool{} + // subcommands.DefaultCommander.VisitCommands(func(_ *subcommands.CommandGroup, cmd subcommands.Command) { allCmds[cmd.Name()] = true }) + allCmds := map[string]bool{ + "commands": true, // builtin + "help": true, // builtin + "flags": true, // builtin + "check": true, + "diff": true, + "gen": true, + "show": true, } - if err != nil { - fmt.Fprintln(os.Stderr, "wire:", err) - // Don't override more specific error codes from above - // (e.g., diff returns 2 on error). - if exitCode == 0 { - exitCode = 1 - } + // Default to running the "gen" command. + if args := flag.Args(); len(args) == 0 || !allCmds[args[0]] { + genCmd := &genCmd{} + os.Exit(int(genCmd.Execute(context.Background(), flag.CommandLine))) } - os.Exit(exitCode) + os.Exit(int(subcommands.Execute(context.Background()))) } -// generate runs the gen subcommand. -// -// Given one or more packages, gen will create the wire_gen.go file for each. -func generate(pkgs ...string) error { +// packages returns the slice of packages to run wire over based on f. +// It defaults to ".". +func packages(f *flag.FlagSet) []string { + pkgs := f.Args() + if len(pkgs) == 0 { + pkgs = []string{"."} + } + return pkgs +} + +type genCmd struct{} + +func (*genCmd) Name() string { return "gen" } +func (*genCmd) Synopsis() string { + return "generate the wire_gen.go file for each package" +} +func (*genCmd) Usage() string { + return `gen [packages] + + Given one or more packages, gen creates the wire_gen.go file for each. + + If no packages are listed, it defaults to ".". +` +} +func (*genCmd) SetFlags(_ *flag.FlagSet) {} +func (*genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { wd, err := os.Getwd() if err != nil { - return err + log.Println("failed to get working directory: ", err) + return subcommands.ExitFailure } - outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs) + outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f)) if len(errs) > 0 { logErrors(errs) - return errors.New("generate failed") + log.Println("generate failed") + return subcommands.ExitFailure } if len(outs) == 0 { - return nil + return subcommands.ExitSuccess } success := true for _, out := range outs { if len(out.Errs) > 0 { - fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath) logErrors(out.Errs) + log.Printf("%s: generate failed\n", out.PkgPath) success = false } if len(out.Content) == 0 { @@ -110,53 +125,63 @@ func generate(pkgs ...string) error { continue } if err := out.Commit(); err == nil { - fmt.Fprintf(os.Stderr, "%s: wrote %s\n", out.PkgPath, out.OutputPath) + log.Printf("%s: wrote %s\n", out.PkgPath, out.OutputPath) } else { - fmt.Fprintf(os.Stderr, "%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) + log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) success = false } } if !success { - return errors.New("at least one generate failure") + log.Println("at least one generate failure") + return subcommands.ExitFailure } - return nil + return subcommands.ExitSuccess } -// diff runs the diff subcommand. -// -// Given one or more packages, diff will generate the content for the -// wire_gen.go file, and output the diff against the existing file. -// -// Similar to the diff command, it returns 0 if no diff, 1 if different, 2 -// plus an error if trouble. -func diff(pkgs ...string) (int, error) { - errReturn := func(err error) (int, error) { - return 2, err - } - okReturn := func(hadDiff bool) (int, error) { - if hadDiff { - return 1, nil - } - return 0, nil - } +type diffCmd struct{} + +func (*diffCmd) Name() string { return "diff" } +func (*diffCmd) Synopsis() string { + return "output a diff between existing wire_gen.go files and what gen would generate" +} +func (*diffCmd) Usage() string { + return `diff [packages] + + Given one or more packages, diff generates the content for their wire_gen.go + files and outputs the diff against the existing files. + + If no packages are listed, it defaults to ".". + + Similar to the diff command, it returns 0 if no diff, 1 if different, 2 + plus an error if trouble. +` +} +func (*diffCmd) SetFlags(_ *flag.FlagSet) {} +func (*diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + const ( + errReturn = subcommands.ExitStatus(2) + diffReturn = subcommands.ExitStatus(1) + ) wd, err := os.Getwd() if err != nil { - return errReturn(err) + log.Println("failed to get working directory: ", err) + return errReturn } - outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs) + outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f)) if len(errs) > 0 { logErrors(errs) - return errReturn(errors.New("generate failed")) + log.Println("generate failed") + return errReturn } if len(outs) == 0 { - return okReturn(false) + return subcommands.ExitSuccess } success := true hadDiff := false for _, out := range outs { if len(out.Errs) > 0 { - fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath) logErrors(out.Errs) + log.Printf("%s: generate failed\n", out.PkgPath) success = false } if len(out.Content) == 0 { @@ -170,32 +195,50 @@ func diff(pkgs ...string) (int, error) { B: difflib.SplitLines(string(out.Content)), }); err == nil { if diff != "" { - fmt.Fprintf(os.Stdout, "%s: diff from %s:\n%s", out.PkgPath, out.OutputPath, diff) + // Print the actual diff to stdout, not stderr. + fmt.Printf("%s: diff from %s:\n%s\n", out.PkgPath, out.OutputPath, diff) hadDiff = true } } else { - fmt.Fprintf(os.Stderr, "%s: failed to diff %s: %v\n", out.PkgPath, out.OutputPath, err) + log.Printf("%s: failed to diff %s: %v\n", out.PkgPath, out.OutputPath, err) success = false } } if !success { - return errReturn(errors.New("at least one generate failure")) + log.Println("at least one generate failure") + return errReturn } - return okReturn(hadDiff) + if hadDiff { + return diffReturn + } + return subcommands.ExitSuccess } -// show runs the show subcommand. -// -// Given one or more packages, show will find all the provider sets -// declared as top-level variables and print what other provider sets it -// imports and what outputs it can produce, given possible inputs. -// It also lists any injector functions defined in the package. -func show(pkgs ...string) error { +type showCmd struct{} + +func (*showCmd) Name() string { return "show" } +func (*showCmd) Synopsis() string { + return "describe all top-level provider sets" +} +func (*showCmd) Usage() string { + return `show [packages] + + Given one or more packages, show finds all the provider sets declared as + top-level variables and prints what other provider sets they import and what + outputs they can produce, given possible inputs. It also lists any injector + functions defined in the package. + + If no packages are listed, it defaults to ".". +` +} +func (*showCmd) SetFlags(_ *flag.FlagSet) {} +func (*showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { wd, err := os.Getwd() if err != nil { - return err + log.Println("failed to get working directory: ", err) + return subcommands.ExitFailure } - info, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs) + info, errs := wire.Load(ctx, wd, os.Environ(), packages(f)) if info != nil { keys := make([]wire.ProviderSetID, 0, len(info.Sets)) for k := range info.Sets { @@ -261,27 +304,41 @@ func show(pkgs ...string) error { } if len(errs) > 0 { logErrors(errs) - return errors.New("error loading packages") + log.Println("error loading packages") + return subcommands.ExitFailure } - return nil + return subcommands.ExitSuccess } -// check runs the check subcommand. -// -// Given one or more packages, check will print any type-checking or -// Wire errors found with top-level variable provider sets or injector -// functions. -func check(pkgs ...string) error { +type checkCmd struct{} + +func (*checkCmd) Name() string { return "check" } +func (*checkCmd) Synopsis() string { + return "print any Wire errors found" +} +func (*checkCmd) Usage() string { + return `check [packages] + + Given one or more packages, check prints any type-checking or Wire errors + found with top-level variable provider sets or injector functions. + + If no packages are listed, it defaults to ".". +` +} +func (*checkCmd) SetFlags(_ *flag.FlagSet) {} +func (*checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { wd, err := os.Getwd() if err != nil { - return err + log.Println("failed to get working directory: ", err) + return subcommands.ExitFailure } - _, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs) + _, errs := wire.Load(ctx, wd, os.Environ(), packages(f)) if len(errs) > 0 { logErrors(errs) - return errors.New("error loading packages") + log.Println("error loading packages") + return subcommands.ExitFailure } - return nil + return subcommands.ExitSuccess } type outGroup struct { @@ -503,6 +560,6 @@ func formatProviderSetName(importPath, varName string) string { func logErrors(errs []error) { for _, err := range errs { - fmt.Fprintln(os.Stderr, strings.Replace(err.Error(), "\n", "\n\t", -1)) + log.Println(strings.Replace(err.Error(), "\n", "\n\t", -1)) } } diff --git a/go.mod b/go.mod index 5bf5c56..5a3912f 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/google/wire require ( github.com/google/go-cmp v0.2.0 + github.com/google/subcommands v1.0.1 github.com/pmezard/go-difflib v1.0.0 golang.org/x/tools v0.0.0-20190422233926-fe54fb35175b ) diff --git a/go.sum b/go.sum index 57add19..88ea58c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/subcommands v1.0.1 h1:/eqq+otEXm5vhfBrbREPCSVQbvofip6kIz+mX5TUH7k= +github.com/google/subcommands v1.0.1/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/internal/alldeps b/internal/alldeps index fc533ad..e2ede73 100644 --- a/internal/alldeps +++ b/internal/alldeps @@ -1,3 +1,4 @@ +github.com/google/subcommands github.com/google/wire github.com/pmezard/go-difflib golang.org/x/tools