diff --git a/cmd/gowire/main.go b/cmd/gowire/main.go index a2a5590..6201a37 100644 --- a/cmd/gowire/main.go +++ b/cmd/gowire/main.go @@ -18,6 +18,7 @@ package main import ( + "errors" "fmt" "go/build" "go/token" @@ -68,9 +69,10 @@ func generate(pkg string) error { if err != nil { return err } - out, err := wire.Generate(&build.Default, wd, pkg) - if err != nil { - return err + out, errs := wire.Generate(&build.Default, wd, pkg) + if len(errs) > 0 { + logErrors(errs) + return errors.New("generate failed") } if len(out) == 0 { // No Wire directives, don't write anything. @@ -94,56 +96,59 @@ func show(pkgs ...string) error { if err != nil { return err } - info, err := wire.Load(&build.Default, wd, pkgs) - if err != nil { - return err - } - keys := make([]wire.ProviderSetID, 0, len(info.Sets)) - for k := range info.Sets { - keys = append(keys, k) - } - sort.Slice(keys, func(i, j int) bool { - if keys[i].ImportPath == keys[j].ImportPath { - return keys[i].VarName < keys[j].VarName + info, errs := wire.Load(&build.Default, wd, pkgs) + if info != nil { + keys := make([]wire.ProviderSetID, 0, len(info.Sets)) + for k := range info.Sets { + keys = append(keys, k) } - return keys[i].ImportPath < keys[j].ImportPath - }) - // ANSI color codes. - // TODO(light): Possibly use github.com/fatih/color? - const ( - reset = "\x1b[0m" - redBold = "\x1b[0;1;31m" - blue = "\x1b[0;34m" - green = "\x1b[0;32m" - ) - for i, k := range keys { - if i > 0 { - fmt.Println() - } - outGroups, imports := gather(info, k) - fmt.Printf("%s%s%s\n", redBold, k, reset) - for _, imp := range sortSet(imports) { - fmt.Printf("\t%s\n", imp) - } - for i := range outGroups { - fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset) - out := make(map[string]token.Pos, outGroups[i].outputs.Len()) - outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) { - switch v := v.(type) { - case *wire.Provider: - out[types.TypeString(t, nil)] = v.Pos - case *wire.Value: - out[types.TypeString(t, nil)] = v.Pos - default: - panic("unreachable") + sort.Slice(keys, func(i, j int) bool { + if keys[i].ImportPath == keys[j].ImportPath { + return keys[i].VarName < keys[j].VarName + } + return keys[i].ImportPath < keys[j].ImportPath + }) + // ANSI color codes. + // TODO(light): Possibly use github.com/fatih/color? + const ( + reset = "\x1b[0m" + redBold = "\x1b[0;1;31m" + blue = "\x1b[0;34m" + green = "\x1b[0;32m" + ) + for i, k := range keys { + if i > 0 { + fmt.Println() + } + outGroups, imports := gather(info, k) + fmt.Printf("%s%s%s\n", redBold, k, reset) + for _, imp := range sortSet(imports) { + fmt.Printf("\t%s\n", imp) + } + for i := range outGroups { + fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset) + out := make(map[string]token.Pos, outGroups[i].outputs.Len()) + outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) { + switch v := v.(type) { + case *wire.Provider: + out[types.TypeString(t, nil)] = v.Pos + case *wire.Value: + out[types.TypeString(t, nil)] = v.Pos + default: + panic("unreachable") + } + }) + for _, t := range sortSet(out) { + fmt.Printf("\t%s%s%s\n", green, t, reset) + fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t])) } - }) - for _, t := range sortSet(out) { - fmt.Printf("\t%s%s%s\n", green, t, reset) - fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t])) } } } + if len(errs) > 0 { + logErrors(errs) + return errors.New("error loading packages") + } return nil } @@ -328,3 +333,9 @@ func formatProviderSetName(importPath, varName string) string { // Since varName is an identifier, it doesn't make sense to quote. return strconv.Quote(importPath) + "." + varName } + +func logErrors(errs []error) { + for _, err := range errs { + fmt.Fprintln(os.Stderr, strings.Replace(err.Error(), "\n", "\n\t", -1)) + } +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index ea408a7..e77ba3c 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -144,8 +144,9 @@ type Value struct { } // Load finds all the provider sets in the given packages, as well as -// the provider sets' transitive dependencies. -func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) { +// the provider sets' transitive dependencies. It may return both an error +// and Info. +func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { // TODO(light): Stop errors from printing to stderr. conf := &loader.Config{ Build: bctx, @@ -157,7 +158,7 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) { } prog, err := conf.Load() if err != nil { - return nil, fmt.Errorf("load: %v", err) + return nil, []error{fmt.Errorf("load: %v", err)} } info := &Info{ Fset: prog.Fset, diff --git a/internal/wire/wire.go b/internal/wire/wire.go index f98a9b9..11cf805 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -38,10 +38,10 @@ import ( // Generate performs dependency injection for a single package, // returning the gofmt'd Go source code. -func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { +func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { mainPkg, err := bctx.Import(pkg, wd, build.FindOnly) if err != nil { - return nil, fmt.Errorf("load: %v", err) + return nil, []error{fmt.Errorf("load: %v", err)} } // TODO(light): Stop errors from printing to stderr. conf := &loader.Config{ @@ -59,17 +59,17 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { prog, err := conf.Load() if err != nil { - return nil, fmt.Errorf("load: %v", err) + return nil, []error{fmt.Errorf("load: %v", err)} } if len(prog.InitialPackages()) != 1 { // This is more of a violated precondition than anything else. - return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages())) + return nil, []error{fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))} } pkgInfo := prog.InitialPackages()[0] g := newGen(prog, pkgInfo.Pkg.Path()) injectorFiles, err := generateInjectors(g, pkgInfo) if err != nil { - return nil, err + return nil, []error{err} } copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info) goSrc := g.frame() @@ -77,7 +77,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { if err != nil { // This is likely a bug from a poorly generated source file. // Return an error and the unformatted source. - return goSrc, err + return goSrc, []error{err} } return fmtSrc, nil } diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 2283d00..d89da7d 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -69,17 +69,17 @@ func TestWire(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { bctx := test.buildContext() - gen, err := Generate(bctx, wd, test.pkg) + gen, errs := Generate(bctx, wd, test.pkg) if len(gen) > 0 { defer t.Logf("wire_gen.go:\n%s", gen) } - if err != nil { + if len(errs) > 0 { if !test.wantError { - t.Fatalf("wirego: %v", err) + t.Fatalf("wirego: %v", errs) } return } - if err == nil && test.wantError { + if len(errs) == 0 && test.wantError { t.Fatal("wirego succeeded; want error") } @@ -133,15 +133,15 @@ func TestWire(t *testing.T) { } t.Run(test.name, func(t *testing.T) { bctx := test.buildContext() - gold, err := Generate(bctx, wd, test.pkg) - if err != nil { - t.Fatal("wirego:", err) + gold, errs := Generate(bctx, wd, test.pkg) + if len(errs) > 0 { + t.Fatal("wirego:", errs) } goldstr := string(gold) for i := 0; i < runs-1; i++ { - out, err := Generate(bctx, wd, test.pkg) - if err != nil { - t.Fatal("wirego (on subsequent run):", err) + out, errs := Generate(bctx, wd, test.pkg) + if len(errs) > 0 { + t.Fatal("wirego (on subsequent run):", errs) } if !bytes.Equal(gold, out) { t.Fatalf("wirego output differs when run repeatedly on same input:\n%s", diff(goldstr, string(out)))