goose: add interface binding
An interface binding instructs goose that a concrete type should be used to satisfy a dependency on an interface type. goose could determine this implicitly, but having an explicit directive makes the provider author's intent clear and allows different concrete types to satisfy different smaller interfaces. Reviewed-by: Tuo Shan <shantuo@google.com>
This commit is contained in:
@@ -18,11 +18,22 @@ import (
|
||||
// providerSet.
|
||||
type providerSet struct {
|
||||
providers []*providerInfo
|
||||
bindings []ifaceBinding
|
||||
imports []providerSetImport
|
||||
}
|
||||
|
||||
// An ifaceBinding declares that a type should be used to satisfy inputs
|
||||
// of the given interface type.
|
||||
//
|
||||
// provided is always a type that is assignable to iface.
|
||||
type ifaceBinding struct {
|
||||
iface types.Type
|
||||
provided types.Type
|
||||
pos token.Pos
|
||||
}
|
||||
|
||||
type providerSetImport struct {
|
||||
providerSetRef
|
||||
symref
|
||||
pos token.Pos
|
||||
}
|
||||
|
||||
@@ -30,7 +41,7 @@ type providerSetImport struct {
|
||||
type providerInfo struct {
|
||||
importPath string
|
||||
funcName string
|
||||
pos token.Pos
|
||||
pos token.Pos // provider function definition
|
||||
args []providerInput
|
||||
out types.Type
|
||||
hasErr bool
|
||||
@@ -80,43 +91,94 @@ func processUnassociatedDirective(fctx findContext, sets map[string]*providerSet
|
||||
return fmt.Errorf("%v: only functions can be marked as providers", fctx.fset.Position(d.pos))
|
||||
case "use":
|
||||
// Ignore, picked up by injector flow.
|
||||
case "bind":
|
||||
args := d.args()
|
||||
if len(args) != 3 {
|
||||
return fmt.Errorf("%v: invalid binding: expected TARGET IFACE TYPE", fctx.fset.Position(d.pos))
|
||||
}
|
||||
ifaceRef, err := parseSymbolRef(fctx.r, args[1], scope, fctx.pkg.Path(), d.pos)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
|
||||
}
|
||||
ifaceObj, err := ifaceRef.resolveObject(fctx.pkg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
|
||||
}
|
||||
ifaceDecl, ok := ifaceObj.(*types.TypeName)
|
||||
if !ok {
|
||||
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), ifaceRef)
|
||||
}
|
||||
iface := ifaceDecl.Type()
|
||||
methodSet, ok := iface.Underlying().(*types.Interface)
|
||||
if !ok {
|
||||
return fmt.Errorf("%v: %v does not name an interface type", fctx.fset.Position(d.pos), ifaceRef)
|
||||
}
|
||||
|
||||
providedRef, err := parseSymbolRef(fctx.r, strings.TrimPrefix(args[2], "*"), scope, fctx.pkg.Path(), d.pos)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
|
||||
}
|
||||
providedObj, err := providedRef.resolveObject(fctx.pkg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
|
||||
}
|
||||
providedDecl, ok := providedObj.(*types.TypeName)
|
||||
if !ok {
|
||||
return fmt.Errorf("%v: %v does not name a type", fctx.fset.Position(d.pos), providedRef)
|
||||
}
|
||||
provided := providedDecl.Type()
|
||||
if types.Identical(provided, iface) {
|
||||
return fmt.Errorf("%v: cannot bind interface to itself", fctx.fset.Position(d.pos))
|
||||
}
|
||||
if strings.HasPrefix(args[2], "*") {
|
||||
provided = types.NewPointer(provided)
|
||||
}
|
||||
if !types.Implements(provided, methodSet) {
|
||||
return fmt.Errorf("%v: %s does not implement %s", fctx.fset.Position(d.pos), types.TypeString(provided, nil), types.TypeString(iface, nil))
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
if pset := sets[name]; pset != nil {
|
||||
pset.bindings = append(pset.bindings, ifaceBinding{
|
||||
iface: iface,
|
||||
provided: provided,
|
||||
})
|
||||
} else {
|
||||
sets[name] = &providerSet{
|
||||
bindings: []ifaceBinding{{
|
||||
iface: iface,
|
||||
provided: provided,
|
||||
}},
|
||||
}
|
||||
}
|
||||
case "import":
|
||||
args := d.args()
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("%s: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
|
||||
return fmt.Errorf("%v: invalid import: expected TARGET SETREF", fctx.fset.Position(d.pos))
|
||||
}
|
||||
name := args[0]
|
||||
for _, spec := range args[1:] {
|
||||
ref, err := parseProviderSetRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
|
||||
ref, err := parseSymbolRef(fctx.r, spec, scope, fctx.pkg.Path(), d.pos)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %v", fctx.fset.Position(d.pos), err)
|
||||
}
|
||||
if ref.importPath != fctx.pkg.Path() {
|
||||
imported := false
|
||||
for _, imp := range fctx.pkg.Imports() {
|
||||
if ref.importPath == imp.Path() {
|
||||
imported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !imported {
|
||||
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
|
||||
}
|
||||
if findImport(fctx.pkg, ref.importPath) == nil {
|
||||
return fmt.Errorf("%v: provider set %s imports %q which is not in the package's imports", fctx.fset.Position(d.pos), name, ref.importPath)
|
||||
}
|
||||
if mod := sets[name]; mod != nil {
|
||||
found := false
|
||||
for _, other := range mod.imports {
|
||||
if ref == other.providerSetRef {
|
||||
if ref == other.symref {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mod.imports = append(mod.imports, providerSetImport{providerSetRef: ref, pos: d.pos})
|
||||
mod.imports = append(mod.imports, providerSetImport{symref: ref, pos: d.pos})
|
||||
}
|
||||
} else {
|
||||
sets[name] = &providerSet{
|
||||
imports: []providerSetImport{{providerSetRef: ref, pos: d.pos}},
|
||||
imports: []providerSetImport{{symref: ref, pos: d.pos}},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,7 +295,7 @@ func newProviderSetCache(prog *loader.Program, r *importResolver) *providerSetCa
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
|
||||
func (mc *providerSetCache) get(ref symref) (*providerSet, error) {
|
||||
if mods, cached := mc.sets[ref.importPath]; cached {
|
||||
mod := mods[ref.name]
|
||||
if mod == nil {
|
||||
@@ -263,46 +325,58 @@ func (mc *providerSetCache) get(ref providerSetRef) (*providerSet, error) {
|
||||
return mod, nil
|
||||
}
|
||||
|
||||
// A providerSetRef is a parsed reference to a collection of providers.
|
||||
type providerSetRef struct {
|
||||
// A symref is a parsed reference to a symbol (either a provider set or a Go object).
|
||||
type symref struct {
|
||||
importPath string
|
||||
name string
|
||||
}
|
||||
|
||||
func parseProviderSetRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (providerSetRef, error) {
|
||||
func parseSymbolRef(r *importResolver, ref string, s *types.Scope, pkg string, pos token.Pos) (symref, error) {
|
||||
// TODO(light): verify that provider set name is an identifier before returning
|
||||
|
||||
i := strings.LastIndexByte(ref, '.')
|
||||
if i == -1 {
|
||||
return providerSetRef{importPath: pkg, name: ref}, nil
|
||||
return symref{importPath: pkg, name: ref}, nil
|
||||
}
|
||||
imp, name := ref[:i], ref[i+1:]
|
||||
if strings.HasPrefix(imp, `"`) {
|
||||
path, err := strconv.Unquote(imp)
|
||||
if err != nil {
|
||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: bad import path", ref)
|
||||
return symref{}, fmt.Errorf("parse symbol reference %q: bad import path", ref)
|
||||
}
|
||||
path, err = r.resolve(pos, path)
|
||||
if err != nil {
|
||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %v", ref, err)
|
||||
return symref{}, fmt.Errorf("parse symbol reference %q: %v", ref, err)
|
||||
}
|
||||
return providerSetRef{importPath: path, name: name}, nil
|
||||
return symref{importPath: path, name: name}, nil
|
||||
}
|
||||
_, obj := s.LookupParent(imp, pos)
|
||||
if obj == nil {
|
||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: unknown identifier %s", ref, imp)
|
||||
return symref{}, fmt.Errorf("parse symbol reference %q: unknown identifier %s", ref, imp)
|
||||
}
|
||||
pn, ok := obj.(*types.PkgName)
|
||||
if !ok {
|
||||
return providerSetRef{}, fmt.Errorf("parse provider set reference %q: %s does not name a package", ref, imp)
|
||||
return symref{}, fmt.Errorf("parse symbol reference %q: %s does not name a package", ref, imp)
|
||||
}
|
||||
return providerSetRef{importPath: pn.Imported().Path(), name: name}, nil
|
||||
return symref{importPath: pn.Imported().Path(), name: name}, nil
|
||||
}
|
||||
|
||||
func (ref providerSetRef) String() string {
|
||||
func (ref symref) String() string {
|
||||
return strconv.Quote(ref.importPath) + "." + ref.name
|
||||
}
|
||||
|
||||
func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) {
|
||||
imp := findImport(pkg, ref.importPath)
|
||||
if imp == nil {
|
||||
return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref)
|
||||
}
|
||||
obj := imp.Scope().Lookup(ref.name)
|
||||
if obj == nil {
|
||||
return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name)
|
||||
}
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
type importResolver struct {
|
||||
fset *token.FileSet
|
||||
bctx *build.Context
|
||||
@@ -333,6 +407,18 @@ func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
|
||||
return pkg.ImportPath, nil
|
||||
}
|
||||
|
||||
func findImport(pkg *types.Package, path string) *types.Package {
|
||||
if pkg.Path() == path {
|
||||
return pkg
|
||||
}
|
||||
for _, imp := range pkg.Imports() {
|
||||
if imp.Path() == path {
|
||||
return imp
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A directive is a parsed goose comment.
|
||||
type directive struct {
|
||||
pos token.Pos
|
||||
|
||||
Reference in New Issue
Block a user