wire: support multiple packages in Generate (google/go-cloud#729)

This commit is contained in:
Robert van Gent
2018-11-16 15:46:13 -08:00
committed by Ross Light
parent 925a11ad0d
commit 65eb134857
3 changed files with 100 additions and 53 deletions

View File

@@ -41,8 +41,6 @@ func main() {
case len(os.Args) == 2 && (os.Args[1] == "help" || os.Args[1] == "-h" || os.Args[1] == "-help" || os.Args[1] == "--help"): case len(os.Args) == 2 && (os.Args[1] == "help" || os.Args[1] == "-h" || os.Args[1] == "-help" || os.Args[1] == "--help"):
fmt.Fprintln(os.Stderr, usage) fmt.Fprintln(os.Stderr, usage)
os.Exit(0) os.Exit(0)
case len(os.Args) == 1 || len(os.Args) == 2 && os.Args[1] == "gen":
err = generate(".")
case len(os.Args) == 2 && os.Args[1] == "show": case len(os.Args) == 2 && os.Args[1] == "show":
err = show(".") err = show(".")
case len(os.Args) > 2 && os.Args[1] == "show": case len(os.Args) > 2 && os.Args[1] == "show":
@@ -51,10 +49,15 @@ func main() {
err = check(".") err = check(".")
case len(os.Args) > 2 && os.Args[1] == "check": case len(os.Args) > 2 && os.Args[1] == "check":
err = check(os.Args[2:]...) err = check(os.Args[2:]...)
case len(os.Args) == 2: case len(os.Args) == 2 && os.Args[1] == "gen":
err = generate(os.Args[1]) err = generate(".")
case len(os.Args) == 3 && os.Args[1] == "gen": case len(os.Args) > 2 && os.Args[1] == "gen":
err = generate(os.Args[2]) err = generate(os.Args[2:]...)
// No explicit command given, assume "gen".
case len(os.Args) == 1:
err = generate(".")
case len(os.Args) > 1:
err = generate(os.Args[1:]...)
default: default:
fmt.Fprintln(os.Stderr, usage) fmt.Fprintln(os.Stderr, usage)
os.Exit(64) os.Exit(64)
@@ -65,24 +68,44 @@ func main() {
} }
} }
// generate runs the gen subcommand. Given a package, gen will create // generate runs the gen subcommand.
// the wire_gen.go file. //
func generate(pkg string) error { // Given one or more packages, gen will create the wire_gen.go file for each.
func generate(pkgs ...string) error {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
return err return err
} }
out, errs := wire.Generate(context.Background(), wd, os.Environ(), pkg) outs, errs := wire.Generate(context.Background(), wd, os.Environ(), pkgs)
if len(errs) > 0 { if len(errs) > 0 {
logErrors(errs) logErrors(errs)
return errors.New("generate failed") return errors.New("generate failed")
} }
if len(out.Content) == 0 { if len(outs) == 0 {
// No Wire directives, don't write anything.
fmt.Fprintln(os.Stderr, "wire: no injector found for", pkg)
return nil return nil
} }
return out.Commit() success := true
for _, out := range outs {
if len(out.Errs) > 0 {
fmt.Fprintf(os.Stderr, "%s: generate failed\n", out.PkgPath)
logErrors(out.Errs)
success = false
}
if len(out.Content) == 0 {
// No Wire output. Maybe errors, maybe no Wire directives.
continue
}
if err := out.Commit(); err == nil {
fmt.Fprintf(os.Stderr, "%s: wrote %s\n", out.PkgPath, out.OutputPath)
} else {
fmt.Fprintf(os.Stderr, "%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err)
success = false
}
}
if !success {
return errors.New("at least one generate failure")
}
return nil
} }
// show runs the show subcommand. // show runs the show subcommand.

View File

@@ -38,24 +38,31 @@ import (
"golang.org/x/tools/go/packages" "golang.org/x/tools/go/packages"
) )
// GeneratedFile stores the content of a call to Generate and the // GenerateResult stores the result for a package from a call to Generate.
// desired on-disk location of the file. type GenerateResult struct {
type GeneratedFile struct { // PkgPath is the package's PkgPath.
Path string PkgPath string
// OutputPath is the path where the generated output should be written.
// May be empty if there were errors.
OutputPath string
// Content is the gofmt'd source code that was generated. May be nil if
// there were errors during generation.
Content []byte Content []byte
// Errs is a slice of errors identified during generation.
Errs []error
} }
// Commit writes the generated file to disk. // Commit writes the generated file to disk.
func (gen GeneratedFile) Commit() error { func (gen GenerateResult) Commit() error {
if len(gen.Content) == 0 { if len(gen.Content) == 0 {
return nil return nil
} }
return ioutil.WriteFile(gen.Path, gen.Content, 0666) return ioutil.WriteFile(gen.OutputPath, gen.Content, 0666)
} }
// Generate performs dependency injection for a single package, // Generate performs dependency injection for the packages that match the given
// returning the gofmt'd Go source code. The package pattern is defined // patterns, return a GenerateResult for each package. The package pattern is
// by the underlying build system. For the go tool, this is described at // defined by the underlying build system. For the go tool, this is described at
// https://golang.org/cmd/go/#hdr-Package_lists_and_patterns // https://golang.org/cmd/go/#hdr-Package_lists_and_patterns
// //
// wd is the working directory and env is the set of environment // wd is the working directory and env is the set of environment
@@ -63,34 +70,41 @@ func (gen GeneratedFile) Commit() error {
// env is nil or empty, it is interpreted as an empty set of variables. // env is nil or empty, it is interpreted as an empty set of variables.
// In case of duplicate environment variables, the last one in the list // In case of duplicate environment variables, the last one in the list
// takes precedence. // takes precedence.
func Generate(ctx context.Context, wd string, env []string, pkgPattern string) (GeneratedFile, []error) { //
pkgs, errs := load(ctx, wd, env, []string{pkgPattern}) // Generate may return one or more errors if it failed to load the packages.
func Generate(ctx context.Context, wd string, env []string, patterns []string) ([]GenerateResult, []error) {
pkgs, errs := load(ctx, wd, env, patterns)
if len(errs) > 0 { if len(errs) > 0 {
return GeneratedFile{}, errs return nil, errs
} }
if len(pkgs) != 1 { generated := make([]GenerateResult, len(pkgs))
// This is more of a violated precondition than anything else. for i, pkg := range pkgs {
return GeneratedFile{}, []error{fmt.Errorf("load: got %d packages", len(pkgs))} generated[i].PkgPath = pkg.PkgPath
} outDir, err := detectOutputDir(pkg.GoFiles)
outDir, err := detectOutputDir(pkgs[0].GoFiles)
if err != nil { if err != nil {
return GeneratedFile{}, []error{fmt.Errorf("load: %v", err)} generated[i].Errs = append(generated[i].Errs, err)
continue
} }
outFname := filepath.Join(outDir, "wire_gen.go") generated[i].OutputPath = filepath.Join(outDir, "wire_gen.go")
g := newGen(pkgs[0]) g := newGen(pkg)
injectorFiles, errs := generateInjectors(g, pkgs[0]) injectorFiles, errs := generateInjectors(g, pkg)
if len(errs) > 0 { if len(errs) > 0 {
return GeneratedFile{}, errs generated[i].Errs = errs
continue
} }
copyNonInjectorDecls(g, injectorFiles, pkgs[0].TypesInfo) copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo)
goSrc := g.frame() goSrc := g.frame()
fmtSrc, err := format.Source(goSrc) fmtSrc, err := format.Source(goSrc)
if err != nil { if err != nil {
// This is likely a bug from a poorly generated source file. // This is likely a bug from a poorly generated source file.
// Return an error and the unformatted source. // Add an error but also the unformatted source.
return GeneratedFile{Path: outFname, Content: goSrc}, []error{err} generated[i].Errs = append(generated[i].Errs, err)
} else {
goSrc = fmtSrc
} }
return GeneratedFile{Path: outFname, Content: fmtSrc}, nil generated[i].Content = goSrc
}
return generated, nil
} }
func detectOutputDir(paths []string) (string, error) { func detectOutputDir(paths []string) (string, error) {

View File

@@ -90,10 +90,20 @@ func TestWire(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
wd := filepath.Join(gopath, "src", "example.com") wd := filepath.Join(gopath, "src", "example.com")
gen, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), test.pkg) gens, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), []string{test.pkg})
var gen GenerateResult
if len(gens) > 1 {
t.Fatalf("got %d generated files, want 0 or 1", len(gens))
}
if len(gens) == 1 {
gen = gens[0]
if len(gen.Errs) > 0 {
errs = append(errs, gen.Errs...)
}
if len(gen.Content) > 0 { if len(gen.Content) > 0 {
defer t.Logf("wire_gen.go:\n%s", gen.Content) defer t.Logf("wire_gen.go:\n%s", gen.Content)
} }
}
if len(errs) > 0 { if len(errs) > 0 {
gotErrStrings := make([]string, len(errs)) gotErrStrings := make([]string, len(errs))
for i, e := range errs { for i, e := range errs {
@@ -119,9 +129,9 @@ func TestWire(t *testing.T) {
t.Fatal("wire succeeded; want error") t.Fatal("wire succeeded; want error")
} }
outPathSane := true outPathSane := true
if prefix := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator); !strings.HasPrefix(gen.Path, prefix) { if prefix := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator); !strings.HasPrefix(gen.OutputPath, prefix) {
outPathSane = false outPathSane = false
t.Errorf("suggested output path = %q; want to start with %q", gen.Path, prefix) t.Errorf("suggested output path = %q; want to start with %q", gen.OutputPath, prefix)
} }
if *record { if *record {