wire: use subcommands package, improving help (#173)
This commit is contained in:
265
cmd/wire/main.go
265
cmd/wire/main.go
@@ -19,90 +19,105 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/subcommands"
|
||||
"github.com/google/wire/internal/wire"
|
||||
"github.com/pmezard/go-difflib/difflib"
|
||||
"golang.org/x/tools/go/types/typeutil"
|
||||
)
|
||||
|
||||
const usage = "usage: wire [gen|diff|show|check] [...]"
|
||||
|
||||
func main() {
|
||||
var (
|
||||
exitCode = 0
|
||||
err error
|
||||
)
|
||||
switch {
|
||||
case len(os.Args) == 2 && (os.Args[1] == "help" || os.Args[1] == "-h" || os.Args[1] == "-help" || os.Args[1] == "--help"):
|
||||
fmt.Fprintln(os.Stderr, usage)
|
||||
os.Exit(0)
|
||||
case len(os.Args) == 2 && os.Args[1] == "show":
|
||||
err = show(".")
|
||||
case len(os.Args) > 2 && os.Args[1] == "show":
|
||||
err = show(os.Args[2:]...)
|
||||
case len(os.Args) == 2 && os.Args[1] == "check":
|
||||
err = check(".")
|
||||
case len(os.Args) > 2 && os.Args[1] == "check":
|
||||
err = check(os.Args[2:]...)
|
||||
case len(os.Args) == 2 && os.Args[1] == "diff":
|
||||
exitCode, err = diff(".")
|
||||
case len(os.Args) > 2 && os.Args[1] == "diff":
|
||||
exitCode, err = diff(os.Args[2:]...)
|
||||
case len(os.Args) == 2 && os.Args[1] == "gen":
|
||||
err = generate(".")
|
||||
case len(os.Args) > 2 && os.Args[1] == "gen":
|
||||
err = generate(os.Args[2:]...)
|
||||
// No explicit command given, assume "gen".
|
||||
case len(os.Args) == 1:
|
||||
err = generate(".")
|
||||
case len(os.Args) > 1:
|
||||
err = generate(os.Args[1:]...)
|
||||
default:
|
||||
fmt.Fprintln(os.Stderr, usage)
|
||||
exitCode = 64
|
||||
subcommands.Register(subcommands.CommandsCommand(), "")
|
||||
subcommands.Register(subcommands.FlagsCommand(), "")
|
||||
subcommands.Register(subcommands.HelpCommand(), "")
|
||||
subcommands.Register(&checkCmd{}, "")
|
||||
subcommands.Register(&diffCmd{}, "")
|
||||
subcommands.Register(&genCmd{}, "")
|
||||
subcommands.Register(&showCmd{}, "")
|
||||
flag.Parse()
|
||||
|
||||
// Initialize the default logger to log to stderr.
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("wire: ")
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
// TODO(rvangent): Use subcommands's VisitCommands instead of hardcoded map,
|
||||
// once there is a release that contains it:
|
||||
// allCmds := map[string]bool{}
|
||||
// subcommands.DefaultCommander.VisitCommands(func(_ *subcommands.CommandGroup, cmd subcommands.Command) { allCmds[cmd.Name()] = true })
|
||||
allCmds := map[string]bool{
|
||||
"commands": true, // builtin
|
||||
"help": true, // builtin
|
||||
"flags": true, // builtin
|
||||
"check": true,
|
||||
"diff": true,
|
||||
"gen": true,
|
||||
"show": true,
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "wire:", err)
|
||||
// Don't override more specific error codes from above
|
||||
// (e.g., diff returns 2 on error).
|
||||
if exitCode == 0 {
|
||||
exitCode = 1
|
||||
}
|
||||
// Default to running the "gen" command.
|
||||
if args := flag.Args(); len(args) == 0 || !allCmds[args[0]] {
|
||||
genCmd := &genCmd{}
|
||||
os.Exit(int(genCmd.Execute(context.Background(), flag.CommandLine)))
|
||||
}
|
||||
os.Exit(exitCode)
|
||||
os.Exit(int(subcommands.Execute(context.Background())))
|
||||
}
|
||||
|
||||
// generate runs the gen subcommand.
|
||||
//
|
||||
// Given one or more packages, gen will create the wire_gen.go file for each.
|
||||
func generate(pkgs ...string) error {
|
||||
// packages returns the slice of packages to run wire over based on f.
|
||||
// It defaults to ".".
|
||||
func packages(f *flag.FlagSet) []string {
|
||||
pkgs := f.Args()
|
||||
if len(pkgs) == 0 {
|
||||
pkgs = []string{"."}
|
||||
}
|
||||
return pkgs
|
||||
}
|
||||
|
||||
type genCmd struct{}
|
||||
|
||||
func (*genCmd) Name() string { return "gen" }
|
||||
func (*genCmd) Synopsis() string {
|
||||
return "generate the wire_gen.go file for each package"
|
||||
}
|
||||
func (*genCmd) Usage() string {
|
||||
return `gen [packages]
|
||||
|
||||
Given one or more packages, gen creates the wire_gen.go file for each.
|
||||
|
||||
If no packages are listed, it defaults to ".".
|
||||
`
|
||||
}
|
||||
func (*genCmd) SetFlags(_ *flag.FlagSet) {}
|
||||
func (*genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
log.Println("failed to get working directory: ", err)
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs)
|
||||
outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f))
|
||||
if len(errs) > 0 {
|
||||
logErrors(errs)
|
||||
return errors.New("generate failed")
|
||||
log.Println("generate failed")
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
if len(outs) == 0 {
|
||||
return nil
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
success := true
|
||||
for _, out := range outs {
|
||||
if len(out.Errs) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath)
|
||||
logErrors(out.Errs)
|
||||
log.Printf("%s: generate failed\n", out.PkgPath)
|
||||
success = false
|
||||
}
|
||||
if len(out.Content) == 0 {
|
||||
@@ -110,53 +125,63 @@ func generate(pkgs ...string) error {
|
||||
continue
|
||||
}
|
||||
if err := out.Commit(); err == nil {
|
||||
fmt.Fprintf(os.Stderr, "%s: wrote %s\n", out.PkgPath, out.OutputPath)
|
||||
log.Printf("%s: wrote %s\n", out.PkgPath, out.OutputPath)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err)
|
||||
log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err)
|
||||
success = false
|
||||
}
|
||||
}
|
||||
if !success {
|
||||
return errors.New("at least one generate failure")
|
||||
log.Println("at least one generate failure")
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
return nil
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
|
||||
// diff runs the diff subcommand.
|
||||
//
|
||||
// Given one or more packages, diff will generate the content for the
|
||||
// wire_gen.go file, and output the diff against the existing file.
|
||||
//
|
||||
// Similar to the diff command, it returns 0 if no diff, 1 if different, 2
|
||||
// plus an error if trouble.
|
||||
func diff(pkgs ...string) (int, error) {
|
||||
errReturn := func(err error) (int, error) {
|
||||
return 2, err
|
||||
}
|
||||
okReturn := func(hadDiff bool) (int, error) {
|
||||
if hadDiff {
|
||||
return 1, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
type diffCmd struct{}
|
||||
|
||||
func (*diffCmd) Name() string { return "diff" }
|
||||
func (*diffCmd) Synopsis() string {
|
||||
return "output a diff between existing wire_gen.go files and what gen would generate"
|
||||
}
|
||||
func (*diffCmd) Usage() string {
|
||||
return `diff [packages]
|
||||
|
||||
Given one or more packages, diff generates the content for their wire_gen.go
|
||||
files and outputs the diff against the existing files.
|
||||
|
||||
If no packages are listed, it defaults to ".".
|
||||
|
||||
Similar to the diff command, it returns 0 if no diff, 1 if different, 2
|
||||
plus an error if trouble.
|
||||
`
|
||||
}
|
||||
func (*diffCmd) SetFlags(_ *flag.FlagSet) {}
|
||||
func (*diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
|
||||
const (
|
||||
errReturn = subcommands.ExitStatus(2)
|
||||
diffReturn = subcommands.ExitStatus(1)
|
||||
)
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return errReturn(err)
|
||||
log.Println("failed to get working directory: ", err)
|
||||
return errReturn
|
||||
}
|
||||
outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs)
|
||||
outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f))
|
||||
if len(errs) > 0 {
|
||||
logErrors(errs)
|
||||
return errReturn(errors.New("generate failed"))
|
||||
log.Println("generate failed")
|
||||
return errReturn
|
||||
}
|
||||
if len(outs) == 0 {
|
||||
return okReturn(false)
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
success := true
|
||||
hadDiff := false
|
||||
for _, out := range outs {
|
||||
if len(out.Errs) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath)
|
||||
logErrors(out.Errs)
|
||||
log.Printf("%s: generate failed\n", out.PkgPath)
|
||||
success = false
|
||||
}
|
||||
if len(out.Content) == 0 {
|
||||
@@ -170,32 +195,50 @@ func diff(pkgs ...string) (int, error) {
|
||||
B: difflib.SplitLines(string(out.Content)),
|
||||
}); err == nil {
|
||||
if diff != "" {
|
||||
fmt.Fprintf(os.Stdout, "%s: diff from %s:\n%s", out.PkgPath, out.OutputPath, diff)
|
||||
// Print the actual diff to stdout, not stderr.
|
||||
fmt.Printf("%s: diff from %s:\n%s\n", out.PkgPath, out.OutputPath, diff)
|
||||
hadDiff = true
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s: failed to diff %s: %v\n", out.PkgPath, out.OutputPath, err)
|
||||
log.Printf("%s: failed to diff %s: %v\n", out.PkgPath, out.OutputPath, err)
|
||||
success = false
|
||||
}
|
||||
}
|
||||
if !success {
|
||||
return errReturn(errors.New("at least one generate failure"))
|
||||
log.Println("at least one generate failure")
|
||||
return errReturn
|
||||
}
|
||||
return okReturn(hadDiff)
|
||||
if hadDiff {
|
||||
return diffReturn
|
||||
}
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
|
||||
// show runs the show subcommand.
|
||||
//
|
||||
// Given one or more packages, show will find all the provider sets
|
||||
// declared as top-level variables and print what other provider sets it
|
||||
// imports and what outputs it can produce, given possible inputs.
|
||||
// It also lists any injector functions defined in the package.
|
||||
func show(pkgs ...string) error {
|
||||
type showCmd struct{}
|
||||
|
||||
func (*showCmd) Name() string { return "show" }
|
||||
func (*showCmd) Synopsis() string {
|
||||
return "describe all top-level provider sets"
|
||||
}
|
||||
func (*showCmd) Usage() string {
|
||||
return `show [packages]
|
||||
|
||||
Given one or more packages, show finds all the provider sets declared as
|
||||
top-level variables and prints what other provider sets they import and what
|
||||
outputs they can produce, given possible inputs. It also lists any injector
|
||||
functions defined in the package.
|
||||
|
||||
If no packages are listed, it defaults to ".".
|
||||
`
|
||||
}
|
||||
func (*showCmd) SetFlags(_ *flag.FlagSet) {}
|
||||
func (*showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
log.Println("failed to get working directory: ", err)
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
info, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs)
|
||||
info, errs := wire.Load(ctx, wd, os.Environ(), packages(f))
|
||||
if info != nil {
|
||||
keys := make([]wire.ProviderSetID, 0, len(info.Sets))
|
||||
for k := range info.Sets {
|
||||
@@ -261,27 +304,41 @@ func show(pkgs ...string) error {
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
logErrors(errs)
|
||||
return errors.New("error loading packages")
|
||||
log.Println("error loading packages")
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
return nil
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
|
||||
// check runs the check subcommand.
|
||||
//
|
||||
// Given one or more packages, check will print any type-checking or
|
||||
// Wire errors found with top-level variable provider sets or injector
|
||||
// functions.
|
||||
func check(pkgs ...string) error {
|
||||
type checkCmd struct{}
|
||||
|
||||
func (*checkCmd) Name() string { return "check" }
|
||||
func (*checkCmd) Synopsis() string {
|
||||
return "print any Wire errors found"
|
||||
}
|
||||
func (*checkCmd) Usage() string {
|
||||
return `check [packages]
|
||||
|
||||
Given one or more packages, check prints any type-checking or Wire errors
|
||||
found with top-level variable provider sets or injector functions.
|
||||
|
||||
If no packages are listed, it defaults to ".".
|
||||
`
|
||||
}
|
||||
func (*checkCmd) SetFlags(_ *flag.FlagSet) {}
|
||||
func (*checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
log.Println("failed to get working directory: ", err)
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
_, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs)
|
||||
_, errs := wire.Load(ctx, wd, os.Environ(), packages(f))
|
||||
if len(errs) > 0 {
|
||||
logErrors(errs)
|
||||
return errors.New("error loading packages")
|
||||
log.Println("error loading packages")
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
return nil
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
|
||||
type outGroup struct {
|
||||
@@ -503,6 +560,6 @@ func formatProviderSetName(importPath, varName string) string {
|
||||
|
||||
func logErrors(errs []error) {
|
||||
for _, err := range errs {
|
||||
fmt.Fprintln(os.Stderr, strings.Replace(err.Error(), "\n", "\n\t", -1))
|
||||
log.Println(strings.Replace(err.Error(), "\n", "\n\t", -1))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user