wire: change call sites to allow multiple errors (google/go-cloud#118)

Updated the call sites to allow multiple errors to be returned from
the package. Load is now permitted to return partial success.

Updates google/go-cloud#5
This commit is contained in:
Ross Light
2018-06-25 10:30:34 -07:00
parent 45a535a0bd
commit f7658c8a13
4 changed files with 79 additions and 67 deletions

View File

@@ -18,6 +18,7 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"go/build" "go/build"
"go/token" "go/token"
@@ -68,9 +69,10 @@ func generate(pkg string) error {
if err != nil { if err != nil {
return err return err
} }
out, err := wire.Generate(&build.Default, wd, pkg) out, errs := wire.Generate(&build.Default, wd, pkg)
if err != nil { if len(errs) > 0 {
return err logErrors(errs)
return errors.New("generate failed")
} }
if len(out) == 0 { if len(out) == 0 {
// No Wire directives, don't write anything. // No Wire directives, don't write anything.
@@ -94,56 +96,59 @@ func show(pkgs ...string) error {
if err != nil { if err != nil {
return err return err
} }
info, err := wire.Load(&build.Default, wd, pkgs) info, errs := wire.Load(&build.Default, wd, pkgs)
if err != nil { if info != nil {
return err keys := make([]wire.ProviderSetID, 0, len(info.Sets))
} for k := range info.Sets {
keys := make([]wire.ProviderSetID, 0, len(info.Sets)) keys = append(keys, k)
for k := range info.Sets {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
if keys[i].ImportPath == keys[j].ImportPath {
return keys[i].VarName < keys[j].VarName
} }
return keys[i].ImportPath < keys[j].ImportPath sort.Slice(keys, func(i, j int) bool {
}) if keys[i].ImportPath == keys[j].ImportPath {
// ANSI color codes. return keys[i].VarName < keys[j].VarName
// TODO(light): Possibly use github.com/fatih/color? }
const ( return keys[i].ImportPath < keys[j].ImportPath
reset = "\x1b[0m" })
redBold = "\x1b[0;1;31m" // ANSI color codes.
blue = "\x1b[0;34m" // TODO(light): Possibly use github.com/fatih/color?
green = "\x1b[0;32m" const (
) reset = "\x1b[0m"
for i, k := range keys { redBold = "\x1b[0;1;31m"
if i > 0 { blue = "\x1b[0;34m"
fmt.Println() green = "\x1b[0;32m"
} )
outGroups, imports := gather(info, k) for i, k := range keys {
fmt.Printf("%s%s%s\n", redBold, k, reset) if i > 0 {
for _, imp := range sortSet(imports) { fmt.Println()
fmt.Printf("\t%s\n", imp) }
} outGroups, imports := gather(info, k)
for i := range outGroups { fmt.Printf("%s%s%s\n", redBold, k, reset)
fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset) for _, imp := range sortSet(imports) {
out := make(map[string]token.Pos, outGroups[i].outputs.Len()) fmt.Printf("\t%s\n", imp)
outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) { }
switch v := v.(type) { for i := range outGroups {
case *wire.Provider: fmt.Printf("%sOutputs given %s:%s\n", blue, outGroups[i].name, reset)
out[types.TypeString(t, nil)] = v.Pos out := make(map[string]token.Pos, outGroups[i].outputs.Len())
case *wire.Value: outGroups[i].outputs.Iterate(func(t types.Type, v interface{}) {
out[types.TypeString(t, nil)] = v.Pos switch v := v.(type) {
default: case *wire.Provider:
panic("unreachable") out[types.TypeString(t, nil)] = v.Pos
case *wire.Value:
out[types.TypeString(t, nil)] = v.Pos
default:
panic("unreachable")
}
})
for _, t := range sortSet(out) {
fmt.Printf("\t%s%s%s\n", green, t, reset)
fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t]))
} }
})
for _, t := range sortSet(out) {
fmt.Printf("\t%s%s%s\n", green, t, reset)
fmt.Printf("\t\tat %v\n", info.Fset.Position(out[t]))
} }
} }
} }
if len(errs) > 0 {
logErrors(errs)
return errors.New("error loading packages")
}
return nil return nil
} }
@@ -328,3 +333,9 @@ func formatProviderSetName(importPath, varName string) string {
// Since varName is an identifier, it doesn't make sense to quote. // Since varName is an identifier, it doesn't make sense to quote.
return strconv.Quote(importPath) + "." + varName return strconv.Quote(importPath) + "." + varName
} }
func logErrors(errs []error) {
for _, err := range errs {
fmt.Fprintln(os.Stderr, strings.Replace(err.Error(), "\n", "\n\t", -1))
}
}

View File

@@ -144,8 +144,9 @@ type Value struct {
} }
// Load finds all the provider sets in the given packages, as well as // Load finds all the provider sets in the given packages, as well as
// the provider sets' transitive dependencies. // the provider sets' transitive dependencies. It may return both an error
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) { // and Info.
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) {
// TODO(light): Stop errors from printing to stderr. // TODO(light): Stop errors from printing to stderr.
conf := &loader.Config{ conf := &loader.Config{
Build: bctx, Build: bctx,
@@ -157,7 +158,7 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
} }
prog, err := conf.Load() prog, err := conf.Load()
if err != nil { if err != nil {
return nil, fmt.Errorf("load: %v", err) return nil, []error{fmt.Errorf("load: %v", err)}
} }
info := &Info{ info := &Info{
Fset: prog.Fset, Fset: prog.Fset,

View File

@@ -38,10 +38,10 @@ import (
// Generate performs dependency injection for a single package, // Generate performs dependency injection for a single package,
// returning the gofmt'd Go source code. // returning the gofmt'd Go source code.
func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) { func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) {
mainPkg, err := bctx.Import(pkg, wd, build.FindOnly) mainPkg, err := bctx.Import(pkg, wd, build.FindOnly)
if err != nil { if err != nil {
return nil, fmt.Errorf("load: %v", err) return nil, []error{fmt.Errorf("load: %v", err)}
} }
// TODO(light): Stop errors from printing to stderr. // TODO(light): Stop errors from printing to stderr.
conf := &loader.Config{ conf := &loader.Config{
@@ -59,17 +59,17 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
prog, err := conf.Load() prog, err := conf.Load()
if err != nil { if err != nil {
return nil, fmt.Errorf("load: %v", err) return nil, []error{fmt.Errorf("load: %v", err)}
} }
if len(prog.InitialPackages()) != 1 { if len(prog.InitialPackages()) != 1 {
// This is more of a violated precondition than anything else. // This is more of a violated precondition than anything else.
return nil, fmt.Errorf("load: got %d packages", len(prog.InitialPackages())) return nil, []error{fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))}
} }
pkgInfo := prog.InitialPackages()[0] pkgInfo := prog.InitialPackages()[0]
g := newGen(prog, pkgInfo.Pkg.Path()) g := newGen(prog, pkgInfo.Pkg.Path())
injectorFiles, err := generateInjectors(g, pkgInfo) injectorFiles, err := generateInjectors(g, pkgInfo)
if err != nil { if err != nil {
return nil, err return nil, []error{err}
} }
copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info) copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info)
goSrc := g.frame() goSrc := g.frame()
@@ -77,7 +77,7 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, error) {
if err != nil { if err != nil {
// This is likely a bug from a poorly generated source file. // This is likely a bug from a poorly generated source file.
// Return an error and the unformatted source. // Return an error and the unformatted source.
return goSrc, err return goSrc, []error{err}
} }
return fmtSrc, nil return fmtSrc, nil
} }

View File

@@ -69,17 +69,17 @@ func TestWire(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
bctx := test.buildContext() bctx := test.buildContext()
gen, err := Generate(bctx, wd, test.pkg) gen, errs := Generate(bctx, wd, test.pkg)
if len(gen) > 0 { if len(gen) > 0 {
defer t.Logf("wire_gen.go:\n%s", gen) defer t.Logf("wire_gen.go:\n%s", gen)
} }
if err != nil { if len(errs) > 0 {
if !test.wantError { if !test.wantError {
t.Fatalf("wirego: %v", err) t.Fatalf("wirego: %v", errs)
} }
return return
} }
if err == nil && test.wantError { if len(errs) == 0 && test.wantError {
t.Fatal("wirego succeeded; want error") t.Fatal("wirego succeeded; want error")
} }
@@ -133,15 +133,15 @@ func TestWire(t *testing.T) {
} }
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
bctx := test.buildContext() bctx := test.buildContext()
gold, err := Generate(bctx, wd, test.pkg) gold, errs := Generate(bctx, wd, test.pkg)
if err != nil { if len(errs) > 0 {
t.Fatal("wirego:", err) t.Fatal("wirego:", errs)
} }
goldstr := string(gold) goldstr := string(gold)
for i := 0; i < runs-1; i++ { for i := 0; i < runs-1; i++ {
out, err := Generate(bctx, wd, test.pkg) out, errs := Generate(bctx, wd, test.pkg)
if err != nil { if len(errs) > 0 {
t.Fatal("wirego (on subsequent run):", err) t.Fatal("wirego (on subsequent run):", errs)
} }
if !bytes.Equal(gold, out) { if !bytes.Equal(gold, out) {
t.Fatalf("wirego output differs when run repeatedly on same input:\n%s", diff(goldstr, string(out))) t.Fatalf("wirego output differs when run repeatedly on same input:\n%s", diff(goldstr, string(out)))