diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 3fc29f8..0494c07 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -41,8 +41,6 @@ func main() { case len(os.Args) == 2 && (os.Args[1] == "help" || os.Args[1] == "-h" || os.Args[1] == "-help" || os.Args[1] == "--help"): fmt.Fprintln(os.Stderr, usage) os.Exit(0) - case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen": - err = generate(".") case len(os.Args) == 2 && os.Args[1] == "show": err = show(".") case len(os.Args) > 2 && os.Args[1] == "show": @@ -51,10 +49,15 @@ func main() { err = check(".") case len(os.Args) > 2 && os.Args[1] == "check": err = check(os.Args[2:]...) - case len(os.Args) == 2: - err = generate(os.Args[1]) - case len(os.Args) == 3 && os.Args[1] == "gen": - err = generate(os.Args[2]) + case len(os.Args) == 2 && os.Args[1] == "gen": + err = generate(".") + case len(os.Args) > 2 && os.Args[1] == "gen": + err = generate(os.Args[2:]...) + // No explicit command given, assume "gen". + case len(os.Args) == 1: + err = generate(".") + case len(os.Args) > 1: + err = generate(os.Args[1:]...) default: fmt.Fprintln(os.Stderr, usage) os.Exit(64) @@ -65,24 +68,44 @@ func main() { } } -// generate runs the gen subcommand. Given a package, gen will create -// the wire_gen.go file. -func generate(pkg string) error { +// generate runs the gen subcommand. +// +// Given one or more packages, gen will create the wire_gen.go file for each. +func generate(pkgs ...string) error { wd, err := os.Getwd() if err != nil { return err } - out, errs := wire.Generate(context.Background(), wd, os.Environ(), pkg) + outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs) if len(errs) > 0 { logErrors(errs) return errors.New("generate failed") } - if len(out.Content) == 0 { - // No Wire directives, don't write anything. - fmt.Fprintln(os.Stderr, "wire: no injector found for", pkg) + if len(outs) == 0 { return nil } - return out.Commit() + success := true + for _, out := range outs { + if len(out.Errs) > 0 { + fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath) + logErrors(out.Errs) + success = false + } + if len(out.Content) == 0 { + // No Wire output. Maybe errors, maybe no Wire directives. + continue + } + if err := out.Commit(); err == nil { + fmt.Fprintf(os.Stderr, "%s: wrote %s\n", out.PkgPath, out.OutputPath) + } else { + fmt.Fprintf(os.Stderr, "%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) + success = false + } + } + if !success { + return errors.New("at least one generate failure") + } + return nil } // show runs the show subcommand. diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 64db9e8..ef9ce68 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -38,24 +38,31 @@ import ( "golang.org/x/tools/go/packages" ) -// GeneratedFile stores the content of a call to Generate and the -// desired on-disk location of the file. -type GeneratedFile struct { - Path string +// GenerateResult stores the result for a package from a call to Generate. +type GenerateResult struct { + // PkgPath is the package's PkgPath. + PkgPath string + // OutputPath is the path where the generated output should be written. + // May be empty if there were errors. + OutputPath string + // Content is the gofmt'd source code that was generated. May be nil if + // there were errors during generation. Content []byte + // Errs is a slice of errors identified during generation. + Errs []error } // Commit writes the generated file to disk. -func (gen GeneratedFile) Commit() error { +func (gen GenerateResult) Commit() error { if len(gen.Content) == 0 { return nil } - return ioutil.WriteFile(gen.Path, gen.Content, 0666) + return ioutil.WriteFile(gen.OutputPath, gen.Content, 0666) } -// Generate performs dependency injection for a single package, -// returning the gofmt'd Go source code. The package pattern is defined -// by the underlying build system. For the go tool, this is described at +// Generate performs dependency injection for the packages that match the given +// patterns, return a GenerateResult for each package. The package pattern is +// defined by the underlying build system. For the go tool, this is described at // https://golang.org/cmd/go/#hdr-Package_lists_and_patterns // // wd is the working directory and env is the set of environment @@ -63,34 +70,41 @@ func (gen GeneratedFile) Commit() error { // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. -func Generate(ctx context.Context, wd string, env []string, pkgPattern string) (GeneratedFile, []error) { - pkgs, errs := load(ctx, wd, env, []string{pkgPattern}) +// +// Generate may return one or more errors if it failed to load the packages. +func Generate(ctx context.Context, wd string, env []string, patterns []string) ([]GenerateResult, []error) { + pkgs, errs := load(ctx, wd, env, patterns) if len(errs) > 0 { - return GeneratedFile{}, errs + return nil, errs } - if len(pkgs) != 1 { - // This is more of a violated precondition than anything else. - return GeneratedFile{}, []error{fmt.Errorf("load: got %d packages", len(pkgs))} + generated := make([]GenerateResult, len(pkgs)) + for i, pkg := range pkgs { + generated[i].PkgPath = pkg.PkgPath + outDir, err := detectOutputDir(pkg.GoFiles) + if err != nil { + generated[i].Errs = append(generated[i].Errs, err) + continue + } + generated[i].OutputPath = filepath.Join(outDir, "wire_gen.go") + g := newGen(pkg) + injectorFiles, errs := generateInjectors(g, pkg) + if len(errs) > 0 { + generated[i].Errs = errs + continue + } + copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) + goSrc := g.frame() + fmtSrc, err := format.Source(goSrc) + if err != nil { + // This is likely a bug from a poorly generated source file. + // Add an error but also the unformatted source. + generated[i].Errs = append(generated[i].Errs, err) + } else { + goSrc = fmtSrc + } + generated[i].Content = goSrc } - outDir, err := detectOutputDir(pkgs[0].GoFiles) - if err != nil { - return GeneratedFile{}, []error{fmt.Errorf("load: %v", err)} - } - outFname := filepath.Join(outDir, "wire_gen.go") - g := newGen(pkgs[0]) - injectorFiles, errs := generateInjectors(g, pkgs[0]) - if len(errs) > 0 { - return GeneratedFile{}, errs - } - copyNonInjectorDecls(g, injectorFiles, pkgs[0].TypesInfo) - goSrc := g.frame() - fmtSrc, err := format.Source(goSrc) - if err != nil { - // This is likely a bug from a poorly generated source file. - // Return an error and the unformatted source. - return GeneratedFile{Path: outFname, Content: goSrc}, []error{err} - } - return GeneratedFile{Path: outFname, Content: fmtSrc}, nil + return generated, nil } func detectOutputDir(paths []string) (string, error) { diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index b01b7c4..805e03a 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -90,9 +90,19 @@ func TestWire(t *testing.T) { t.Fatal(err) } wd := filepath.Join(gopath, "src", "example.com") - gen, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), test.pkg) - if len(gen.Content) > 0 { - defer t.Logf("wire_gen.go:\n%s", gen.Content) + gens, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), []string{test.pkg}) + var gen GenerateResult + if len(gens) > 1 { + t.Fatalf("got %d generated files, want 0 or 1", len(gens)) + } + if len(gens) == 1 { + gen = gens[0] + if len(gen.Errs) > 0 { + errs = append(errs, gen.Errs...) + } + if len(gen.Content) > 0 { + defer t.Logf("wire_gen.go:\n%s", gen.Content) + } } if len(errs) > 0 { gotErrStrings := make([]string, len(errs)) @@ -119,9 +129,9 @@ func TestWire(t *testing.T) { t.Fatal("wire succeeded; want error") } outPathSane := true - if prefix := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator); !strings.HasPrefix(gen.Path, prefix) { + if prefix := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator); !strings.HasPrefix(gen.OutputPath, prefix) { outPathSane = false - t.Errorf("suggested output path = %q; want to start with %q", gen.Path, prefix) + t.Errorf("suggested output path = %q; want to start with %q", gen.OutputPath, prefix) } if *record {