goose: remove optional directive

This introduces some short-term pain in practice, but I aim to fix that
with the goose.Value directive.

Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
Ross Light
2018-04-26 14:23:06 -04:00
parent cfc6111ea5
commit 3345599aaf
20 changed files with 11 additions and 197 deletions

View File

@@ -245,22 +245,6 @@ set that provides the concrete type.
[type identity]: https://golang.org/ref/spec#Type_identity [type identity]: https://golang.org/ref/spec#Type_identity
[return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces [return concrete types]: https://github.com/golang/go/wiki/CodeReviewComments#interfaces
### Optional Inputs
A provider input can be marked optional using `goose:optional`:
```go
//goose:provide Bar
//goose:optional foo
func provideBar(foo Foo) Bar {
// ...
}
```
If used as part of an injector that does not bring in the `Foo` dependency, then
the injector will pass the provider the zero value as the `foo` argument.
### Struct Providers ### Struct Providers
Structs can also be marked as providers. Instead of calling a function, an Structs can also be marked as providers. Instead of calling a function, an
@@ -308,9 +292,6 @@ func injectFooBar() FooBar {
And similarly if the injector needed a `*FooBar`. And similarly if the injector needed a `*FooBar`.
Like function providers, you can mark dependencies of a struct provider optional
by using the `goose:optional` directive with the field names.
### Cleanup functions ### Cleanup functions
If a provider creates a value that needs to be cleaned up (e.g. closing a file), If a provider creates a value that needs to be cleaned up (e.g. closing a file),

View File

@@ -85,9 +85,6 @@ func solve(mc *providerSetCache, out types.Type, given []types.Type, sets []symr
p, _ := providers.At(typ).(*Provider) p, _ := providers.At(typ).(*Provider)
if p == nil { if p == nil {
if trail[len(trail)-1].Optional {
return nil
}
if len(trail) == 1 { if len(trail) == 1 {
return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil)) return fmt.Errorf("no provider found for %s (output of injector)", types.TypeString(typ, nil))
} }

View File

@@ -79,8 +79,9 @@ type Provider struct {
// ProviderInput describes an incoming edge in the provider graph. // ProviderInput describes an incoming edge in the provider graph.
type ProviderInput struct { type ProviderInput struct {
Type types.Type Type types.Type
Optional bool
// TODO(light): Move field name into this struct.
} }
// Load finds all the provider sets in the given packages, as well as // Load finds all the provider sets in the given packages, as well as
@@ -203,7 +204,7 @@ func findProviderSets(fctx findContext, files []*ast.File) (map[string]*Provider
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration. // processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error { func processUnassociatedDirective(fctx findContext, sets map[string]*ProviderSet, scope *types.Scope, d directive) error {
switch d.kind { switch d.kind {
case "provide", "optional": case "provide":
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos)) return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
case "use": case "use":
// Ignore, picked up by injector flow. // Ignore, picked up by injector flow.
@@ -323,11 +324,6 @@ func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope
return err return err
} }
if !p.isValid() { if !p.isValid() {
for _, d := range dg.dirs {
if d.kind == "optional" {
return fmt.Errorf("%v: cannot use goose:%s directive on non-provider", fctx.fset.Position(d.pos), d.kind)
}
}
return nil return nil
} }
var providerSetName string var providerSetName string
@@ -337,18 +333,10 @@ func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope
} else if len(args) > 1 { } else if len(args) > 1 {
return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(p.pos)) return fmt.Errorf("%v: goose:provide takes at most one argument", fctx.fset.Position(p.pos))
} }
optionals := make(map[string]token.Pos)
for _, d := range dg.dirs {
if d.kind == "optional" {
for _, arg := range d.args() {
optionals[arg] = d.pos
}
}
}
switch decl := dg.decl.(type) { switch decl := dg.decl.(type) {
case *ast.FuncDecl: case *ast.FuncDecl:
fn := fctx.typeInfo.ObjectOf(decl.Name).(*types.Func) fn := fctx.typeInfo.ObjectOf(decl.Name).(*types.Func)
provider, err := processFuncProvider(fctx, fn, optionals) provider, err := processFuncProvider(fctx, fn)
if err != nil { if err != nil {
return err return err
} }
@@ -379,7 +367,7 @@ func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope
if _, ok := typeName.Type().(*types.Named).Underlying().(*types.Struct); !ok { if _, ok := typeName.Type().(*types.Named).Underlying().(*types.Struct); !ok {
return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos)) return fmt.Errorf("%v: only functions and structs can be marked as providers", fctx.fset.Position(p.pos))
} }
provider, err := processStructProvider(fctx, typeName, optionals) provider, err := processStructProvider(fctx, typeName)
if err != nil { if err != nil {
return err return err
} }
@@ -410,18 +398,9 @@ func processDeclDirectives(fctx findContext, sets map[string]*ProviderSet, scope
return nil return nil
} }
func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[string]token.Pos) (*Provider, error) { func processFuncProvider(fctx findContext, fn *types.Func) (*Provider, error) {
sig := fn.Type().(*types.Signature) sig := fn.Type().(*types.Signature)
optionals := make([]bool, sig.Params().Len())
for arg, dpos := range optionalArgs {
pi := paramIndex(sig.Params(), arg)
if pi == -1 {
return nil, fmt.Errorf("%v: %s is not a parameter of func %s", fctx.fset.Position(dpos), arg, fn.Name())
}
optionals[pi] = true
}
fpos := fn.Pos() fpos := fn.Pos()
r := sig.Results() r := sig.Results()
var hasCleanup, hasErr bool var hasCleanup, hasErr bool
@@ -461,8 +440,7 @@ func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[stri
} }
for i := 0; i < params.Len(); i++ { for i := 0; i < params.Len(); i++ {
provider.Args[i] = ProviderInput{ provider.Args[i] = ProviderInput{
Type: params.At(i).Type(), Type: params.At(i).Type(),
Optional: optionals[i],
} }
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) {
@@ -473,21 +451,9 @@ func processFuncProvider(fctx findContext, fn *types.Func, optionalArgs map[stri
return provider, nil return provider, nil
} }
func processStructProvider(fctx findContext, typeName *types.TypeName, optionals map[string]token.Pos) (*Provider, error) { func processStructProvider(fctx findContext, typeName *types.TypeName) (*Provider, error) {
out := typeName.Type() out := typeName.Type()
st := out.Underlying().(*types.Struct) st := out.Underlying().(*types.Struct)
for arg, dpos := range optionals {
found := false
for i := 0; i < st.NumFields(); i++ {
if st.Field(i).Name() == arg {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("%v: %s is not a field of struct %s", fctx.fset.Position(dpos), arg, types.TypeString(st, nil))
}
}
pos := typeName.Pos() pos := typeName.Pos()
provider := &Provider{ provider := &Provider{
@@ -501,10 +467,8 @@ func processStructProvider(fctx findContext, typeName *types.TypeName, optionals
} }
for i := 0; i < st.NumFields(); i++ { for i := 0; i < st.NumFields(); i++ {
f := st.Field(i) f := st.Field(i)
_, optional := optionals[f.Name()]
provider.Args[i] = ProviderInput{ provider.Args[i] = ProviderInput{
Type: f.Type(), Type: f.Type(),
Optional: optional,
} }
provider.Fields[i] = f.Name() provider.Fields[i] = f.Name()
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
@@ -688,7 +652,7 @@ func parseFile(fset *token.FileSet, f *ast.File) []directiveGroup {
// Move directives that don't associate into the unassociated group. // Move directives that don't associate into the unassociated group.
n := 0 n := 0
for i := start; i < len(grp.dirs); i++ { for i := start; i < len(grp.dirs); i++ {
if k := grp.dirs[i].kind; k == "provide" || k == "optional" || k == "use" { if k := grp.dirs[i].kind; k == "provide" || k == "use" {
grp.dirs[start+n] = grp.dirs[i] grp.dirs[start+n] = grp.dirs[i]
n++ n++
} else { } else {

View File

@@ -1,16 +0,0 @@
package main
import "fmt"
func main() {
fmt.Println(injectBar())
}
type foo int
type bar int
//goose:provide
//goose:optional f
func provideBar(f foo) bar {
return bar(f)
}

View File

@@ -1,7 +0,0 @@
//+build gooseinject
package main
//goose:use provideBar
func injectBar() bar

View File

@@ -1 +0,0 @@
0

View File

@@ -1 +0,0 @@
foo

View File

@@ -1,16 +0,0 @@
package main
import "fmt"
func main() {
fmt.Println(injectBar(42))
}
type foo int
type bar int
//goose:provide
//goose:optional f
func provideBar(f foo) bar {
return bar(f)
}

View File

@@ -1,7 +0,0 @@
//+build gooseinject
package main
//goose:use provideBar
func injectBar(foo) bar

View File

@@ -1 +0,0 @@
42

View File

@@ -1 +0,0 @@
foo

View File

@@ -1,23 +0,0 @@
package main
import "fmt"
func main() {
fb := injectFooBar()
fmt.Println(fb.Foo, fb.OptionalBar)
}
type Foo int
type Bar int
//goose:provide Set
//goose:optional OptionalBar
type FooBar struct {
Foo Foo
OptionalBar Bar
}
//goose:provide Set
func provideFoo() Foo {
return 42
}

View File

@@ -1,7 +0,0 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -1 +0,0 @@
42 0

View File

@@ -1 +0,0 @@
foo

View File

@@ -1,28 +0,0 @@
package main
import "fmt"
func main() {
fb := injectFooBar()
fmt.Println(fb.Foo, fb.Bar)
}
type Foo int
type Bar int
//goose:provide Set
//goose:optional Bar
type FooBar struct {
Foo Foo
Bar Bar
}
//goose:provide Set
func provideFoo() Foo {
return 41
}
//goose:provide Set
func provideBar() Bar {
return 1
}

View File

@@ -1,7 +0,0 @@
//+build gooseinject
package main
//goose:use Set
func injectFooBar() FooBar

View File

@@ -1 +0,0 @@
41 1

View File

@@ -1 +0,0 @@
foo

View File

@@ -198,9 +198,6 @@ func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports ma
// Try to see if any args haven't been visited. // Try to see if any args haven't been visited.
allPresent := true allPresent := true
for _, arg := range p.Args { for _, arg := range p.Args {
if arg.Optional {
continue
}
if inputVisited.At(arg.Type) == nil { if inputVisited.At(arg.Type) == nil {
allPresent = false allPresent = false
} }
@@ -208,9 +205,6 @@ func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports ma
if !allPresent { if !allPresent {
stk = append(stk, curr) stk = append(stk, curr)
for _, arg := range p.Args { for _, arg := range p.Args {
if arg.Optional {
continue
}
if inputVisited.At(arg.Type) == nil { if inputVisited.At(arg.Type) == nil {
stk = append(stk, arg.Type) stk = append(stk, arg.Type)
} }
@@ -222,9 +216,6 @@ func gather(info *goose.Info, key goose.ProviderSetID) (_ []outGroup, imports ma
in := new(typeutil.Map) in := new(typeutil.Map)
in.SetHasher(hash) in.SetHasher(hash)
for _, arg := range p.Args { for _, arg := range p.Args {
if arg.Optional {
continue
}
i := inputVisited.At(arg.Type).(int) i := inputVisited.At(arg.Type).(int)
if i == -1 { if i == -1 {
in.Set(arg.Type, true) in.Set(arg.Type, true)