2018-04-02 09:21:52 -07:00
package goose
import (
"fmt"
"go/ast"
2018-04-04 10:58:07 -07:00
"go/build"
2018-04-02 09:21:52 -07:00
"go/token"
"go/types"
2018-04-04 10:58:07 -07:00
"path/filepath"
2018-04-02 09:21:52 -07:00
"strconv"
"strings"
2018-04-02 10:57:48 -07:00
"unicode"
2018-04-02 09:21:52 -07:00
"golang.org/x/tools/go/loader"
)
2018-04-04 14:42:56 -07:00
// A ProviderSet describes a set of providers. The zero value is an empty
// ProviderSet.
type ProviderSet struct {
Providers [ ] * Provider
Bindings [ ] IfaceBinding
Imports [ ] ProviderSetImport
2018-04-02 09:21:52 -07:00
}
2018-04-04 14:42:56 -07:00
// IfaceBinding records that a concrete type should be used to satisfy
// inputs of a given interface type.
type IfaceBinding struct {
	// Iface is the interface type, i.e. the type that can be injected.
	Iface types.Type

	// Provided is the type supplied for Iface. It is always assignable
	// to Iface.
	Provided types.Type

	// Pos is the position where the binding was declared.
	Pos token.Pos
}
2018-04-04 14:42:56 -07:00
// A ProviderSetImport adds providers from one provider set into another.
type ProviderSetImport struct {
ProviderSetID
Pos token . Pos
2018-04-02 09:21:52 -07:00
}
2018-04-04 14:42:56 -07:00
// Provider records the signature of a provider. A provider is a
// single Go object, either a function or a named struct type.
type Provider struct {
// ImportPath is the package path that the Go object resides in.
ImportPath string
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// Name is the name of the Go object.
Name string
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// Pos is the source position of the func keyword or type spec
2018-04-03 21:11:53 -07:00
// defining this provider.
2018-04-04 14:42:56 -07:00
Pos token . Pos
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// Args is the list of data dependencies this provider has.
Args [ ] ProviderInput
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// IsStruct is true if this provider is a named struct type.
2018-04-03 21:11:53 -07:00
// Otherwise it's a function.
2018-04-04 14:42:56 -07:00
IsStruct bool
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// Fields lists the field names to populate. This will map 1:1 with
2018-04-03 21:11:53 -07:00
// elements in Args.
2018-04-04 14:42:56 -07:00
Fields [ ] string
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// Out is the type this provider produces.
Out types . Type
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// HasCleanup reports whether the provider function returns a cleanup
2018-04-03 21:11:53 -07:00
// function. (Always false for structs.)
2018-04-04 14:42:56 -07:00
HasCleanup bool
2018-04-03 21:11:53 -07:00
2018-04-04 14:42:56 -07:00
// HasErr reports whether the provider function can return an error.
2018-04-03 21:11:53 -07:00
// (Always false for structs.)
2018-04-04 14:42:56 -07:00
HasErr bool
}
// ProviderInput is one incoming edge in the provider graph: a single value
// that a provider depends on.
type ProviderInput struct {
	// Type is the type of the required value.
	Type types.Type

	// TODO(light): Move field name into this struct.
}
// Load loads the given packages and collects the provider sets declared in
// them, together with every provider set they transitively import.
//
// The returned Info's Sets map holds only the provider sets declared in the
// initial packages; its All map additionally holds those reached through
// imports. Any loading or directive-processing failure is returned with a
// "load: " prefix.
func Load(bctx *build.Context, wd string, pkgs []string) (*Info, error) {
	conf := newLoaderConfig(bctx, wd, false)
	for i := range pkgs {
		conf.Import(pkgs[i])
	}
	prog, err := conf.Load()
	if err != nil {
		return nil, fmt.Errorf("load: %v", err)
	}
	resolver := newImportResolver(conf, prog.Fset)

	// Seed the work stack with the initial packages and remember which
	// paths they are, so their sets can be reported separately.
	var stack []string
	isInitial := make(map[string]struct{})
	for _, initialPkg := range prog.InitialPackages() {
		p := initialPkg.Pkg.Path()
		stack = append(stack, p)
		isInitial[p] = struct{}{}
	}

	result := &Info{
		Fset: prog.Fset,
		Sets: make(map[ProviderSetID]*ProviderSet),
		All:  make(map[ProviderSetID]*ProviderSet),
	}
	seen := make(map[string]struct{})
	for len(stack) > 0 {
		// Pop the most recently pushed package path.
		last := len(stack) - 1
		importPath := stack[last]
		stack = stack[:last]
		if _, dup := seen[importPath]; dup {
			continue
		}
		seen[importPath] = struct{}{}

		pkgInfo := prog.Package(importPath)
		found, err := findProviderSets(findContext{
			fset:     prog.Fset,
			pkg:      pkgInfo.Pkg,
			typeInfo: &pkgInfo.Info,
			r:        resolver,
		}, pkgInfo.Files)
		if err != nil {
			return nil, fmt.Errorf("load: %v", err)
		}

		pkgPath := pkgInfo.Pkg.Path()
		_, fromInitial := isInitial[pkgPath]
		for name, set := range found {
			id := ProviderSetID{ImportPath: pkgPath, Name: name}
			result.All[id] = set
			if fromInitial {
				result.Sets[id] = set
			}
			// Queue the packages of imported provider sets for a visit.
			for _, imp := range set.Imports {
				stack = append(stack, imp.ImportPath)
			}
		}
	}
	return result, nil
}
// Info is the value returned by Load.
type Info struct {
	// Fset is the file set of the loaded program.
	Fset *token.FileSet

	// Sets holds every provider set declared in the initial packages.
	Sets map[ProviderSetID]*ProviderSet

	// All holds every provider set that the initial packages' provider
	// sets transitively depend on, including those in Sets.
	All map[ProviderSetID]*ProviderSet
}
// ProviderSetID names a provider set by the package that declares it and
// the set's name within that package. It is comparable and is used as a
// map key.
type ProviderSetID struct {
	// ImportPath is the import path of the declaring package.
	ImportPath string
	// Name is the provider set's name.
	Name string
}
// String returns the ID as ""path/to/pkg".Foo".
func ( id ProviderSetID ) String ( ) string {
return id . symref ( ) . String ( )
2018-04-02 09:21:52 -07:00
}
2018-04-04 14:42:56 -07:00
func ( id ProviderSetID ) symref ( ) symref {
return symref { importPath : id . ImportPath , name : id . Name }
2018-03-30 21:34:08 -07:00
}
// findContext bundles the per-package state that is threaded through the
// directive-processing functions.
type findContext struct {
	// fset is used to turn positions into file names and line numbers.
	fset *token.FileSet
	// pkg is the package whose directives are being processed.
	pkg *types.Package
	// typeInfo is the type-checking information for pkg.
	typeInfo *types.Info
	// r resolves quoted import paths found in symbol references.
	r *importResolver
}
2018-04-02 09:21:52 -07:00
// findProviderSets processes a package and extracts the provider sets declared in it.
2018-04-04 14:42:56 -07:00
func findProviderSets ( fctx findContext , files [ ] * ast . File ) ( map [ string ] * ProviderSet , error ) {
sets := make ( map [ string ] * ProviderSet )
2018-04-02 09:21:52 -07:00
for _ , f := range files {
2018-03-30 21:34:08 -07:00
fileScope := fctx . typeInfo . Scopes [ f ]
if fileScope == nil {
return nil , fmt . Errorf ( "%s: no scope found for file (likely a bug)" , fctx . fset . File ( f . Pos ( ) ) . Name ( ) )
}
for _ , dg := range parseFile ( fctx . fset , f ) {
if dg . decl != nil {
if err := processDeclDirectives ( fctx , sets , fileScope , dg ) ; err != nil {
return nil , err
}
} else {
for _ , d := range dg . dirs {
if err := processUnassociatedDirective ( fctx , sets , fileScope , d ) ; err != nil {
return nil , err
2018-04-02 09:21:52 -07:00
}
}
}
}
2018-03-30 21:34:08 -07:00
}
return sets , nil
}
// processUnassociatedDirective handles any directive that was not associated with a top-level declaration.
2018-04-04 14:42:56 -07:00
func processUnassociatedDirective ( fctx findContext , sets map [ string ] * ProviderSet , scope * types . Scope , d directive ) error {
2018-03-30 21:34:08 -07:00
switch d . kind {
2018-04-26 14:23:06 -04:00
case "provide" :
2018-03-30 21:34:08 -07:00
return fmt . Errorf ( "%v: only functions can be marked as providers" , fctx . fset . Position ( d . pos ) )
case "use" :
// Ignore, picked up by injector flow.
2018-04-02 14:08:17 -07:00
case "bind" :
args := d . args ( )
if len ( args ) != 3 {
return fmt . Errorf ( "%v: invalid binding: expected TARGET IFACE TYPE" , fctx . fset . Position ( d . pos ) )
}
ifaceRef , err := parseSymbolRef ( fctx . r , args [ 1 ] , scope , fctx . pkg . Path ( ) , d . pos )
if err != nil {
return fmt . Errorf ( "%v: %v" , fctx . fset . Position ( d . pos ) , err )
}
ifaceObj , err := ifaceRef . resolveObject ( fctx . pkg )
if err != nil {
return fmt . Errorf ( "%v: %v" , fctx . fset . Position ( d . pos ) , err )
}
ifaceDecl , ok := ifaceObj . ( * types . TypeName )
if ! ok {
return fmt . Errorf ( "%v: %v does not name a type" , fctx . fset . Position ( d . pos ) , ifaceRef )
}
iface := ifaceDecl . Type ( )
methodSet , ok := iface . Underlying ( ) . ( * types . Interface )
if ! ok {
return fmt . Errorf ( "%v: %v does not name an interface type" , fctx . fset . Position ( d . pos ) , ifaceRef )
}
providedRef , err := parseSymbolRef ( fctx . r , strings . TrimPrefix ( args [ 2 ] , "*" ) , scope , fctx . pkg . Path ( ) , d . pos )
if err != nil {
return fmt . Errorf ( "%v: %v" , fctx . fset . Position ( d . pos ) , err )
}
providedObj , err := providedRef . resolveObject ( fctx . pkg )
if err != nil {
return fmt . Errorf ( "%v: %v" , fctx . fset . Position ( d . pos ) , err )
}
providedDecl , ok := providedObj . ( * types . TypeName )
if ! ok {
return fmt . Errorf ( "%v: %v does not name a type" , fctx . fset . Position ( d . pos ) , providedRef )
}
provided := providedDecl . Type ( )
if types . Identical ( provided , iface ) {
return fmt . Errorf ( "%v: cannot bind interface to itself" , fctx . fset . Position ( d . pos ) )
}
if strings . HasPrefix ( args [ 2 ] , "*" ) {
provided = types . NewPointer ( provided )
}
if ! types . Implements ( provided , methodSet ) {
return fmt . Errorf ( "%v: %s does not implement %s" , fctx . fset . Position ( d . pos ) , types . TypeString ( provided , nil ) , types . TypeString ( iface , nil ) )
}
name := args [ 0 ]
if pset := sets [ name ] ; pset != nil {
2018-04-04 14:42:56 -07:00
pset . Bindings = append ( pset . Bindings , IfaceBinding {
Iface : iface ,
Provided : provided ,
2018-04-02 14:08:17 -07:00
} )
} else {
2018-04-04 14:42:56 -07:00
sets [ name ] = & ProviderSet {
Bindings : [ ] IfaceBinding { {
Iface : iface ,
Provided : provided ,
2018-04-02 14:08:17 -07:00
} } ,
}
}
2018-03-30 21:34:08 -07:00
case "import" :
2018-04-02 10:57:48 -07:00
args := d . args ( )
if len ( args ) < 2 {
2018-04-02 14:08:17 -07:00
return fmt . Errorf ( "%v: invalid import: expected TARGET SETREF" , fctx . fset . Position ( d . pos ) )
2018-03-30 21:34:08 -07:00
}
2018-04-02 10:57:48 -07:00
name := args [ 0 ]
for _ , spec := range args [ 1 : ] {
2018-04-02 14:08:17 -07:00
ref , err := parseSymbolRef ( fctx . r , spec , scope , fctx . pkg . Path ( ) , d . pos )
2018-04-02 10:57:48 -07:00
if err != nil {
return fmt . Errorf ( "%v: %v" , fctx . fset . Position ( d . pos ) , err )
2018-04-02 09:21:52 -07:00
}
2018-04-02 14:08:17 -07:00
if findImport ( fctx . pkg , ref . importPath ) == nil {
return fmt . Errorf ( "%v: provider set %s imports %q which is not in the package's imports" , fctx . fset . Position ( d . pos ) , name , ref . importPath )
2018-04-02 09:21:52 -07:00
}
2018-04-02 10:57:48 -07:00
if mod := sets [ name ] ; mod != nil {
found := false
2018-04-04 14:42:56 -07:00
for _ , other := range mod . Imports {
if ref == other . symref ( ) {
2018-04-02 10:57:48 -07:00
found = true
break
}
}
if ! found {
2018-04-04 14:42:56 -07:00
mod . Imports = append ( mod . Imports , ProviderSetImport {
ProviderSetID : ProviderSetID {
ImportPath : ref . importPath ,
Name : ref . name ,
} ,
Pos : d . pos ,
} )
2018-04-02 10:57:48 -07:00
}
} else {
2018-04-04 14:42:56 -07:00
sets [ name ] = & ProviderSet {
Imports : [ ] ProviderSetImport { {
ProviderSetID : ProviderSetID {
ImportPath : ref . importPath ,
Name : ref . name ,
} ,
Pos : d . pos ,
} } ,
2018-04-02 10:57:48 -07:00
}
2018-04-02 09:21:52 -07:00
}
2018-03-30 21:34:08 -07:00
}
default :
return fmt . Errorf ( "%v: unknown directive %s" , fctx . fset . Position ( d . pos ) , d . kind )
}
return nil
}
// processDeclDirectives processes the directives associated with a top-level declaration.
2018-04-04 14:42:56 -07:00
func processDeclDirectives ( fctx findContext , sets map [ string ] * ProviderSet , scope * types . Scope , dg directiveGroup ) error {
2018-03-30 21:34:08 -07:00
p , err := dg . single ( fctx . fset , "provide" )
if err != nil {
return err
}
if ! p . isValid ( ) {
return nil
}
2018-04-03 21:11:53 -07:00
var providerSetName string
if args := p . args ( ) ; len ( args ) == 1 {
2018-04-27 17:40:40 -04:00
// TODO(light): Validate identifier.
2018-04-03 21:11:53 -07:00
providerSetName = args [ 0 ]
} else if len ( args ) > 1 {
return fmt . Errorf ( "%v: goose:provide takes at most one argument" , fctx . fset . Position ( p . pos ) )
2018-03-30 21:34:08 -07:00
}
2018-04-03 21:11:53 -07:00
switch decl := dg . decl . ( type ) {
case * ast . FuncDecl :
fn := fctx . typeInfo . ObjectOf ( decl . Name ) . ( * types . Func )
2018-04-26 14:23:06 -04:00
provider , err := processFuncProvider ( fctx , fn )
2018-04-03 21:11:53 -07:00
if err != nil {
return err
}
if providerSetName == "" {
providerSetName = fn . Name ( )
}
if mod := sets [ providerSetName ] ; mod != nil {
2018-04-04 14:42:56 -07:00
for _ , other := range mod . Providers {
if types . Identical ( other . Out , provider . Out ) {
return fmt . Errorf ( "%v: provider set %s has multiple providers for %s (previous declaration at %v)" , fctx . fset . Position ( fn . Pos ( ) ) , providerSetName , types . TypeString ( provider . Out , nil ) , fctx . fset . Position ( other . Pos ) )
2018-04-03 21:11:53 -07:00
}
}
2018-04-04 14:42:56 -07:00
mod . Providers = append ( mod . Providers , provider )
2018-04-03 21:11:53 -07:00
} else {
2018-04-04 14:42:56 -07:00
sets [ providerSetName ] = & ProviderSet {
Providers : [ ] * Provider { provider } ,
2018-04-03 21:11:53 -07:00
}
}
case * ast . GenDecl :
if decl . Tok != token . TYPE {
return fmt . Errorf ( "%v: only functions and structs can be marked as providers" , fctx . fset . Position ( p . pos ) )
}
if len ( decl . Specs ) != 1 {
2018-04-27 17:40:40 -04:00
// TODO(light): Tighten directive extraction to associate with particular specs.
2018-04-03 21:11:53 -07:00
return fmt . Errorf ( "%v: only functions and structs can be marked as providers" , fctx . fset . Position ( p . pos ) )
}
typeName := fctx . typeInfo . ObjectOf ( decl . Specs [ 0 ] . ( * ast . TypeSpec ) . Name ) . ( * types . TypeName )
if _ , ok := typeName . Type ( ) . ( * types . Named ) . Underlying ( ) . ( * types . Struct ) ; ! ok {
return fmt . Errorf ( "%v: only functions and structs can be marked as providers" , fctx . fset . Position ( p . pos ) )
}
2018-04-26 14:23:06 -04:00
provider , err := processStructProvider ( fctx , typeName )
2018-04-03 21:11:53 -07:00
if err != nil {
return err
}
if providerSetName == "" {
providerSetName = typeName . Name ( )
}
2018-04-04 14:42:56 -07:00
ptrProvider := new ( Provider )
2018-04-03 21:11:53 -07:00
* ptrProvider = * provider
2018-04-04 14:42:56 -07:00
ptrProvider . Out = types . NewPointer ( provider . Out )
2018-04-03 21:11:53 -07:00
if mod := sets [ providerSetName ] ; mod != nil {
2018-04-04 14:42:56 -07:00
for _ , other := range mod . Providers {
if types . Identical ( other . Out , provider . Out ) {
return fmt . Errorf ( "%v: provider set %s has multiple providers for %s (previous declaration at %v)" , fctx . fset . Position ( typeName . Pos ( ) ) , providerSetName , types . TypeString ( provider . Out , nil ) , fctx . fset . Position ( other . Pos ) )
2018-04-03 21:11:53 -07:00
}
2018-04-04 14:42:56 -07:00
if types . Identical ( other . Out , ptrProvider . Out ) {
return fmt . Errorf ( "%v: provider set %s has multiple providers for %s (previous declaration at %v)" , fctx . fset . Position ( typeName . Pos ( ) ) , providerSetName , types . TypeString ( ptrProvider . Out , nil ) , fctx . fset . Position ( other . Pos ) )
2018-04-02 09:21:52 -07:00
}
2018-04-03 21:11:53 -07:00
}
2018-04-04 14:42:56 -07:00
mod . Providers = append ( mod . Providers , provider , ptrProvider )
2018-04-03 21:11:53 -07:00
} else {
2018-04-04 14:42:56 -07:00
sets [ providerSetName ] = & ProviderSet {
Providers : [ ] * Provider { provider , ptrProvider } ,
2018-04-02 09:21:52 -07:00
}
}
2018-04-03 21:11:53 -07:00
default :
return fmt . Errorf ( "%v: only functions and structs can be marked as providers" , fctx . fset . Position ( p . pos ) )
}
return nil
}
2018-04-26 14:23:06 -04:00
func processFuncProvider ( fctx findContext , fn * types . Func ) ( * Provider , error ) {
2018-04-03 21:11:53 -07:00
sig := fn . Type ( ) . ( * types . Signature )
2018-03-30 21:34:08 -07:00
fpos := fn . Pos ( )
r := sig . Results ( )
2018-04-03 13:13:15 -07:00
var hasCleanup , hasErr bool
2018-03-30 21:34:08 -07:00
switch r . Len ( ) {
case 1 :
2018-04-03 13:13:15 -07:00
hasCleanup , hasErr = false , false
2018-03-30 21:34:08 -07:00
case 2 :
2018-04-03 13:13:15 -07:00
switch t := r . At ( 1 ) . Type ( ) ; {
case types . Identical ( t , errorType ) :
hasCleanup , hasErr = false , true
case types . Identical ( t , cleanupType ) :
hasCleanup , hasErr = true , false
default :
2018-04-03 21:11:53 -07:00
return nil , fmt . Errorf ( "%v: wrong signature for provider %s: second return type must be error or func()" , fctx . fset . Position ( fpos ) , fn . Name ( ) )
2018-03-30 21:34:08 -07:00
}
2018-04-03 13:13:15 -07:00
case 3 :
if t := r . At ( 1 ) . Type ( ) ; ! types . Identical ( t , cleanupType ) {
2018-04-03 21:11:53 -07:00
return nil , fmt . Errorf ( "%v: wrong signature for provider %s: second return type must be func()" , fctx . fset . Position ( fpos ) , fn . Name ( ) )
2018-04-03 13:13:15 -07:00
}
if t := r . At ( 2 ) . Type ( ) ; ! types . Identical ( t , errorType ) {
2018-04-03 21:11:53 -07:00
return nil , fmt . Errorf ( "%v: wrong signature for provider %s: third return type must be error" , fctx . fset . Position ( fpos ) , fn . Name ( ) )
2018-04-03 13:13:15 -07:00
}
hasCleanup , hasErr = true , true
2018-03-30 21:34:08 -07:00
default :
2018-04-03 21:11:53 -07:00
return nil , fmt . Errorf ( "%v: wrong signature for provider %s: must have one return value and optional error" , fctx . fset . Position ( fpos ) , fn . Name ( ) )
2018-03-30 21:34:08 -07:00
}
out := r . At ( 0 ) . Type ( )
params := sig . Params ( )
2018-04-04 14:42:56 -07:00
provider := & Provider {
ImportPath : fctx . pkg . Path ( ) ,
Name : fn . Name ( ) ,
Pos : fn . Pos ( ) ,
Args : make ( [ ] ProviderInput , params . Len ( ) ) ,
Out : out ,
HasCleanup : hasCleanup ,
HasErr : hasErr ,
2018-03-30 21:34:08 -07:00
}
for i := 0 ; i < params . Len ( ) ; i ++ {
2018-04-04 14:42:56 -07:00
provider . Args [ i ] = ProviderInput {
2018-04-26 14:23:06 -04:00
Type : params . At ( i ) . Type ( ) ,
2018-03-30 21:34:08 -07:00
}
for j := 0 ; j < i ; j ++ {
2018-04-04 14:42:56 -07:00
if types . Identical ( provider . Args [ i ] . Type , provider . Args [ j ] . Type ) {
return nil , fmt . Errorf ( "%v: provider has multiple parameters of type %s" , fctx . fset . Position ( fpos ) , types . TypeString ( provider . Args [ j ] . Type , nil ) )
2018-03-30 21:34:08 -07:00
}
}
}
2018-04-03 21:11:53 -07:00
return provider , nil
}
2018-04-26 14:23:06 -04:00
func processStructProvider ( fctx findContext , typeName * types . TypeName ) ( * Provider , error ) {
2018-04-03 21:11:53 -07:00
out := typeName . Type ( )
st := out . Underlying ( ) . ( * types . Struct )
pos := typeName . Pos ( )
2018-04-04 14:42:56 -07:00
provider := & Provider {
ImportPath : fctx . pkg . Path ( ) ,
Name : typeName . Name ( ) ,
Pos : pos ,
Args : make ( [ ] ProviderInput , st . NumFields ( ) ) ,
Fields : make ( [ ] string , st . NumFields ( ) ) ,
IsStruct : true ,
Out : out ,
2018-04-03 21:11:53 -07:00
}
for i := 0 ; i < st . NumFields ( ) ; i ++ {
f := st . Field ( i )
2018-04-04 14:42:56 -07:00
provider . Args [ i ] = ProviderInput {
2018-04-26 14:23:06 -04:00
Type : f . Type ( ) ,
2018-04-03 21:11:53 -07:00
}
2018-04-04 14:42:56 -07:00
provider . Fields [ i ] = f . Name ( )
2018-04-03 21:11:53 -07:00
for j := 0 ; j < i ; j ++ {
2018-04-04 14:42:56 -07:00
if types . Identical ( provider . Args [ i ] . Type , provider . Args [ j ] . Type ) {
return nil , fmt . Errorf ( "%v: provider struct has multiple fields of type %s" , fctx . fset . Position ( pos ) , types . TypeString ( provider . Args [ j ] . Type , nil ) )
2018-04-03 21:11:53 -07:00
}
}
}
return provider , nil
2018-04-02 09:21:52 -07:00
}
// providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct {
2018-04-04 14:42:56 -07:00
sets map [ string ] map [ string ] * ProviderSet
2018-04-02 09:21:52 -07:00
fset * token . FileSet
prog * loader . Program
2018-04-04 10:58:07 -07:00
r * importResolver
2018-04-02 09:21:52 -07:00
}
2018-04-04 10:58:07 -07:00
func newProviderSetCache ( prog * loader . Program , r * importResolver ) * providerSetCache {
2018-04-02 09:21:52 -07:00
return & providerSetCache {
fset : prog . Fset ,
prog : prog ,
2018-04-04 10:58:07 -07:00
r : r ,
2018-04-02 09:21:52 -07:00
}
}
2018-04-04 14:42:56 -07:00
func ( mc * providerSetCache ) get ( ref symref ) ( * ProviderSet , error ) {
2018-04-02 09:21:52 -07:00
if mods , cached := mc . sets [ ref . importPath ] ; cached {
mod := mods [ ref . name ]
if mod == nil {
return nil , fmt . Errorf ( "no such provider set %s in package %q" , ref . name , ref . importPath )
}
return mod , nil
}
if mc . sets == nil {
2018-04-04 14:42:56 -07:00
mc . sets = make ( map [ string ] map [ string ] * ProviderSet )
2018-04-02 09:21:52 -07:00
}
pkg := mc . prog . Package ( ref . importPath )
2018-03-30 21:34:08 -07:00
mods , err := findProviderSets ( findContext {
fset : mc . fset ,
pkg : pkg . Pkg ,
typeInfo : & pkg . Info ,
r : mc . r ,
} , pkg . Files )
2018-04-02 09:21:52 -07:00
if err != nil {
mc . sets [ ref . importPath ] = nil
return nil , err
}
mc . sets [ ref . importPath ] = mods
mod := mods [ ref . name ]
if mod == nil {
return nil , fmt . Errorf ( "no such provider set %s in package %q" , ref . name , ref . importPath )
}
return mod , nil
}
2018-04-02 14:08:17 -07:00
// symref is a parsed reference to a symbol, which may be either a provider
// set or a Go object. Values are comparable with ==.
type symref struct {
	// importPath is the import path of the package holding the symbol.
	importPath string
	// name is the symbol's name within that package.
	name string
}
2018-04-02 14:08:17 -07:00
func parseSymbolRef ( r * importResolver , ref string , s * types . Scope , pkg string , pos token . Pos ) ( symref , error ) {
2018-04-27 17:40:40 -04:00
// TODO(light): Verify that provider set name is an identifier before returning.
2018-04-02 09:21:52 -07:00
i := strings . LastIndexByte ( ref , '.' )
if i == - 1 {
2018-04-02 14:08:17 -07:00
return symref { importPath : pkg , name : ref } , nil
2018-04-02 09:21:52 -07:00
}
imp , name := ref [ : i ] , ref [ i + 1 : ]
if strings . HasPrefix ( imp , ` " ` ) {
path , err := strconv . Unquote ( imp )
if err != nil {
2018-04-02 14:08:17 -07:00
return symref { } , fmt . Errorf ( "parse symbol reference %q: bad import path" , ref )
2018-04-02 09:21:52 -07:00
}
2018-04-04 10:58:07 -07:00
path , err = r . resolve ( pos , path )
if err != nil {
2018-04-02 14:08:17 -07:00
return symref { } , fmt . Errorf ( "parse symbol reference %q: %v" , ref , err )
2018-04-04 10:58:07 -07:00
}
2018-04-02 14:08:17 -07:00
return symref { importPath : path , name : name } , nil
2018-04-02 09:21:52 -07:00
}
_ , obj := s . LookupParent ( imp , pos )
if obj == nil {
2018-04-02 14:08:17 -07:00
return symref { } , fmt . Errorf ( "parse symbol reference %q: unknown identifier %s" , ref , imp )
2018-04-02 09:21:52 -07:00
}
pn , ok := obj . ( * types . PkgName )
if ! ok {
2018-04-02 14:08:17 -07:00
return symref { } , fmt . Errorf ( "parse symbol reference %q: %s does not name a package" , ref , imp )
2018-04-02 09:21:52 -07:00
}
2018-04-02 14:08:17 -07:00
return symref { importPath : pn . Imported ( ) . Path ( ) , name : name } , nil
2018-04-02 09:21:52 -07:00
}
2018-04-02 14:08:17 -07:00
func ( ref symref ) String ( ) string {
2018-04-02 09:21:52 -07:00
return strconv . Quote ( ref . importPath ) + "." + ref . name
}
2018-04-02 14:08:17 -07:00
// resolveObject looks up the Go object that ref names. The object's package
// must be pkg itself or one of pkg's direct imports; otherwise an error is
// returned. An error is also returned when that package's scope has no
// object with the referenced name.
func (ref symref) resolveObject(pkg *types.Package) (types.Object, error) {
	owner := findImport(pkg, ref.importPath)
	if owner == nil {
		return nil, fmt.Errorf("resolve Go reference %v: package not directly imported", ref)
	}
	if obj := owner.Scope().Lookup(ref.name); obj != nil {
		return obj, nil
	}
	return nil, fmt.Errorf("resolve Go reference %v: %s not found in package", ref, ref.name)
}
2018-04-04 10:58:07 -07:00
// importResolver turns an import path written in a source file into the
// import path of the package it refers to.
type importResolver struct {
	// fset maps positions to the files they occur in.
	fset *token.FileSet
	// bctx is the build context passed to findPackage.
	bctx *build.Context
	// findPackage locates a package given an import path and the
	// directory of the importing file.
	findPackage func(bctx *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error)
}
// newImportResolver returns a resolver that uses the build context and
// package finder of the loader configuration c. When c leaves either unset,
// build.Default and (*build.Context).Import are used respectively.
func newImportResolver(c *loader.Config, fset *token.FileSet) *importResolver {
	bctx := c.Build
	if bctx == nil {
		bctx = &build.Default
	}
	find := c.FindPackage
	if find == nil {
		find = (*build.Context).Import
	}
	return &importResolver{
		fset:        fset,
		bctx:        bctx,
		findPackage: find,
	}
}
// resolve returns the import path of the package that path refers to when
// written in the file containing pos. The lookup starts from that file's
// directory and only locates the package (build.FindOnly).
func (r *importResolver) resolve(pos token.Pos, path string) (string, error) {
	filename := r.fset.File(pos).Name()
	found, err := r.findPackage(r.bctx, path, filepath.Dir(filename), build.FindOnly)
	if err != nil {
		return "", err
	}
	return found.ImportPath, nil
}
2018-04-02 14:08:17 -07:00
// findImport returns the package with the given path if it is pkg itself or
// one of the packages pkg directly imports. It returns nil otherwise.
func findImport(pkg *types.Package, path string) *types.Package {
	if path == pkg.Path() {
		return pkg
	}
	imports := pkg.Imports()
	for i := range imports {
		if candidate := imports[i]; candidate.Path() == path {
			return candidate
		}
	}
	return nil
}
2018-03-30 21:34:08 -07:00
// directive is one parsed goose comment line.
type directive struct {
	// pos is the source position recorded for the directive.
	pos token.Pos
	// kind is the word that follows the "goose:" prefix.
	kind string
	// line is the remainder of the directive line after kind, with
	// surrounding whitespace removed.
	line string
}
2018-03-30 21:34:08 -07:00
// directiveGroup is the list of directives that belong to one declaration.
type directiveGroup struct {
	// decl is the declaration the directives are attached to. It is nil
	// for the group of directives that have no declaration.
	decl ast.Decl
	// dirs holds the directives in the group.
	dirs []directive
}
// parseFile extracts the directives from a file, grouped by declaration.
func parseFile ( fset * token . FileSet , f * ast . File ) [ ] directiveGroup {
cmap := ast . NewCommentMap ( fset , f , f . Comments )
// Reserve first group for directives that don't associate with a
// declaration, like import.
groups := make ( [ ] directiveGroup , 1 , len ( f . Decls ) + 1 )
// Walk declarations and add to groups.
for _ , decl := range f . Decls {
grp := directiveGroup { decl : decl }
ast . Inspect ( decl , func ( node ast . Node ) bool {
if g := cmap [ node ] ; len ( g ) > 0 {
for _ , cg := range g {
start := len ( grp . dirs )
grp . dirs = extractDirectives ( grp . dirs , cg )
// Move directives that don't associate into the unassociated group.
n := 0
for i := start ; i < len ( grp . dirs ) ; i ++ {
2018-04-26 14:23:06 -04:00
if k := grp . dirs [ i ] . kind ; k == "provide" || k == "use" {
2018-03-30 21:34:08 -07:00
grp . dirs [ start + n ] = grp . dirs [ i ]
n ++
} else {
groups [ 0 ] . dirs = append ( groups [ 0 ] . dirs , grp . dirs [ i ] )
}
}
grp . dirs = grp . dirs [ : start + n ]
}
delete ( cmap , node )
}
return true
} )
if len ( grp . dirs ) > 0 {
groups = append ( groups , grp )
}
}
// Place remaining directives into the unassociated group.
unassoc := & groups [ 0 ]
for _ , g := range cmap {
for _ , cg := range g {
unassoc . dirs = extractDirectives ( unassoc . dirs , cg )
}
}
if len ( unassoc . dirs ) == 0 {
return groups [ 1 : ]
}
return groups
}
2018-04-02 09:21:52 -07:00
func extractDirectives ( d [ ] directive , cg * ast . CommentGroup ) [ ] directive {
const prefix = "goose:"
text := cg . Text ( )
for len ( text ) > 0 {
text = strings . TrimLeft ( text , " \t\r\n" )
if ! strings . HasPrefix ( text , prefix ) {
break
}
line := text [ len ( prefix ) : ]
2018-04-02 10:57:48 -07:00
// Text() is always newline terminated.
i := strings . IndexByte ( line , '\n' )
line , text = line [ : i ] , line [ i + 1 : ]
2018-04-02 09:21:52 -07:00
if i := strings . IndexByte ( line , ' ' ) ; i != - 1 {
d = append ( d , directive {
kind : line [ : i ] ,
line : strings . TrimSpace ( line [ i + 1 : ] ) ,
2018-04-27 17:40:40 -04:00
pos : cg . Pos ( ) , // TODO(light): More precise position.
2018-04-02 09:21:52 -07:00
} )
} else {
d = append ( d , directive {
kind : line ,
2018-04-27 17:40:40 -04:00
pos : cg . Pos ( ) , // TODO(light): More precise position.
2018-04-02 09:21:52 -07:00
} )
}
}
return d
}
2018-03-30 21:34:08 -07:00
// single returns the one directive of the given kind in the group. If the
// group has no such directive, the zero directive is returned, which can be
// detected with isValid. If it has more than one, an error is returned that
// is positioned at the second occurrence and names the declaration when
// that is a function or a single-spec type declaration.
func (dg directiveGroup) single(fset *token.FileSet, kind string) (directive, error) {
	var match directive
	seen := false
	for _, d := range dg.dirs {
		if d.kind != kind {
			continue
		}
		if !seen {
			match, seen = d, true
			continue
		}

		// Second occurrence: work out whether the declaration has a
		// name worth mentioning in the error.
		declName, named := "", false
		switch decl := dg.decl.(type) {
		case *ast.FuncDecl:
			declName, named = decl.Name.Name, true
		case *ast.GenDecl:
			if decl.Tok == token.TYPE && len(decl.Specs) == 1 {
				declName, named = decl.Specs[0].(*ast.TypeSpec).Name.Name, true
			}
		}
		if named {
			return directive{}, fmt.Errorf("%v: multiple %s directives for %s", fset.Position(d.pos), kind, declName)
		}
		return directive{}, fmt.Errorf("%v: multiple %s directives", fset.Position(d.pos), kind)
	}
	return match, nil
}
// isValid reports whether d has a kind, i.e. is not the zero directive.
func (d directive) isValid() bool {
	return len(d.kind) > 0
}
2018-04-02 10:57:48 -07:00
// args splits the directive line into tokens.
func ( d directive ) args ( ) [ ] string {
var args [ ] string
start := - 1
state := 0 // 0 = boundary, 1 = in token, 2 = in quote, 3 = quote backslash
for i , r := range d . line {
switch state {
case 0 :
2018-04-27 17:40:40 -04:00
// Argument boundary.
2018-04-02 10:57:48 -07:00
switch {
case r == '"' :
start = i
state = 2
case ! unicode . IsSpace ( r ) :
start = i
state = 1
}
case 1 :
2018-04-27 17:40:40 -04:00
// In token.
2018-04-02 10:57:48 -07:00
switch {
case unicode . IsSpace ( r ) :
args = append ( args , d . line [ start : i ] )
start = - 1
state = 0
case r == '"' :
state = 2
}
case 2 :
2018-04-27 17:40:40 -04:00
// In quotes.
2018-04-02 10:57:48 -07:00
switch {
case r == '"' :
state = 1
case r == '\\' :
state = 3
}
case 3 :
// Quote backslash. Consumes one character and jumps back into "in quote" state.
state = 2
default :
panic ( "unreachable" )
}
}
if start != - 1 {
args = append ( args , d . line [ start : ] )
}
return args
}
2018-04-02 09:21:52 -07:00
// isInjectFile reports whether a given file is an injection template.
//
// A file counts as a template when one of its comments is a build
// constraint mentioning "gooseinject". Both the "+build" form and the
// "//go:build" form are recognised. Any mention of the tag inside the
// constraint matches, including a negated one.
func isInjectFile(f *ast.File) bool {
	// TODO(light): Better determination.
	for _, cg := range f.Comments {
		text := cg.Text()
		if strings.HasPrefix(text, "+build") && strings.Contains(text, "gooseinject") {
			return true
		}
		// CommentGroup.Text drops directive-style lines such as
		// "//go:build", so check the raw comment text for that form.
		for _, c := range cg.List {
			if strings.HasPrefix(c.Text, "//go:build ") && strings.Contains(c.Text, "gooseinject") {
				return true
			}
		}
	}
	return false
}
2018-03-30 21:34:08 -07:00
// paramIndex searches params for a parameter called name and returns its
// position in the tuple. It returns -1 when there is no such parameter.
func paramIndex(params *types.Tuple, name string) int {
	count := params.Len()
	for idx := 0; idx != count; idx++ {
		if v := params.At(idx); v.Name() == name {
			return idx
		}
	}
	return -1
}