2018-04-02 09:21:52 -07:00
package goose
import (
"fmt"
"go/ast"
2018-04-04 10:58:07 -07:00
"go/build"
2018-04-02 09:21:52 -07:00
"go/token"
"go/types"
2018-04-04 10:58:07 -07:00
"path/filepath"
2018-04-02 09:21:52 -07:00
"strconv"
"strings"
"golang.org/x/tools/go/loader"
)
// A providerSet describes a set of providers. The zero value is an empty
// providerSet.
type providerSet struct {
providers [ ] * providerInfo
imports [ ] providerSetImport
}
type providerSetImport struct {
providerSetRef
pos token . Pos
}
// providerInfo records the signature of a provider function.
type providerInfo struct {
importPath string
funcName string
pos token . Pos
args [ ] types . Type
out types . Type
hasErr bool
}
// findProviderSets processes a package and extracts the provider sets declared in it.
2018-04-04 10:58:07 -07:00
func findProviderSets ( fset * token . FileSet , pkg * types . Package , r * importResolver , typeInfo * types . Info , files [ ] * ast . File ) ( map [ string ] * providerSet , error ) {
2018-04-02 09:21:52 -07:00
sets := make ( map [ string ] * providerSet )
var directives [ ] directive
for _ , f := range files {
fileScope := typeInfo . Scopes [ f ]
for _ , c := range f . Comments {
directives = extractDirectives ( directives [ : 0 ] , c )
for _ , d := range directives {
switch d . kind {
case "provide" , "use" :
// handled later
case "import" :
if fileScope == nil {
return nil , fmt . Errorf ( "%s: no scope found for file (likely a bug)" , fset . File ( f . Pos ( ) ) . Name ( ) )
}
i := strings . IndexByte ( d . line , ' ' )
// TODO(light): allow multiple imports in one line
if i == - 1 {
return nil , fmt . Errorf ( "%s: invalid import: expected TARGET SETREF" , fset . Position ( d . pos ) )
}
name , spec := d . line [ : i ] , d . line [ i + 1 : ]
2018-04-04 10:58:07 -07:00
ref , err := parseProviderSetRef ( r , spec , fileScope , pkg . Path ( ) , d . pos )
2018-04-02 09:21:52 -07:00
if err != nil {
return nil , fmt . Errorf ( "%v: %v" , fset . Position ( d . pos ) , err )
}
if ref . importPath != pkg . Path ( ) {
imported := false
for _ , imp := range pkg . Imports ( ) {
if ref . importPath == imp . Path ( ) {
imported = true
break
}
}
if ! imported {
return nil , fmt . Errorf ( "%v: provider set %s imports %q which is not in the package's imports" , fset . Position ( d . pos ) , name , ref . importPath )
}
}
if mod := sets [ name ] ; mod != nil {
found := false
for _ , other := range mod . imports {
if ref == other . providerSetRef {
found = true
break
}
}
if ! found {
mod . imports = append ( mod . imports , providerSetImport { providerSetRef : ref , pos : d . pos } )
}
} else {
sets [ name ] = & providerSet {
imports : [ ] providerSetImport { { providerSetRef : ref , pos : d . pos } } ,
}
}
default :
return nil , fmt . Errorf ( "%v: unknown directive %s" , fset . Position ( d . pos ) , d . kind )
}
}
}
cmap := ast . NewCommentMap ( fset , f , f . Comments )
for _ , decl := range f . Decls {
directives = directives [ : 0 ]
for _ , cg := range cmap [ decl ] {
directives = extractDirectives ( directives , cg )
}
fn , isFunction := decl . ( * ast . FuncDecl )
var providerSetName string
for _ , d := range directives {
if d . kind != "provide" {
continue
}
if providerSetName != "" {
return nil , fmt . Errorf ( "%v: multiple provide directives for %s" , fset . Position ( d . pos ) , fn . Name . Name )
}
if ! isFunction {
return nil , fmt . Errorf ( "%v: only functions can be marked as providers" , fset . Position ( d . pos ) )
}
providerSetName = fn . Name . Name
if d . line != "" {
// TODO(light): validate identifier
providerSetName = d . line
}
}
if providerSetName == "" {
continue
}
fpos := fn . Pos ( )
sig := typeInfo . ObjectOf ( fn . Name ) . Type ( ) . ( * types . Signature )
r := sig . Results ( )
var hasErr bool
switch r . Len ( ) {
case 1 :
hasErr = false
case 2 :
if t := r . At ( 1 ) . Type ( ) ; ! types . Identical ( t , errorType ) {
return nil , fmt . Errorf ( "%v: wrong signature for provider %s: second return type must be error" , fset . Position ( fpos ) , fn . Name . Name )
}
hasErr = true
default :
return nil , fmt . Errorf ( "%v: wrong signature for provider %s: must have one return value and optional error" , fset . Position ( fpos ) , fn . Name . Name )
}
out := r . At ( 0 ) . Type ( )
p := sig . Params ( )
provider := & providerInfo {
importPath : pkg . Path ( ) ,
funcName : fn . Name . Name ,
pos : fn . Pos ( ) ,
args : make ( [ ] types . Type , p . Len ( ) ) ,
out : out ,
hasErr : hasErr ,
}
for i := 0 ; i < p . Len ( ) ; i ++ {
provider . args [ i ] = p . At ( i ) . Type ( )
for j := 0 ; j < i ; j ++ {
if types . Identical ( provider . args [ i ] , provider . args [ j ] ) {
return nil , fmt . Errorf ( "%v: provider has multiple parameters of type %s" , fset . Position ( fpos ) , types . TypeString ( provider . args [ j ] , nil ) )
}
}
}
if mod := sets [ providerSetName ] ; mod != nil {
for _ , other := range mod . providers {
if types . Identical ( other . out , provider . out ) {
return nil , fmt . Errorf ( "%v: provider set %s has multiple providers for %s (previous declaration at %v)" , fset . Position ( fpos ) , providerSetName , types . TypeString ( provider . out , nil ) , fset . Position ( other . pos ) )
}
}
mod . providers = append ( mod . providers , provider )
} else {
sets [ providerSetName ] = & providerSet {
providers : [ ] * providerInfo { provider } ,
}
}
}
}
return sets , nil
}
// providerSetCache is a lazily evaluated index of provider sets.
type providerSetCache struct {
sets map [ string ] map [ string ] * providerSet
fset * token . FileSet
prog * loader . Program
2018-04-04 10:58:07 -07:00
r * importResolver
2018-04-02 09:21:52 -07:00
}
2018-04-04 10:58:07 -07:00
func newProviderSetCache ( prog * loader . Program , r * importResolver ) * providerSetCache {
2018-04-02 09:21:52 -07:00
return & providerSetCache {
fset : prog . Fset ,
prog : prog ,
2018-04-04 10:58:07 -07:00
r : r ,
2018-04-02 09:21:52 -07:00
}
}
func ( mc * providerSetCache ) get ( ref providerSetRef ) ( * providerSet , error ) {
if mods , cached := mc . sets [ ref . importPath ] ; cached {
mod := mods [ ref . name ]
if mod == nil {
return nil , fmt . Errorf ( "no such provider set %s in package %q" , ref . name , ref . importPath )
}
return mod , nil
}
if mc . sets == nil {
mc . sets = make ( map [ string ] map [ string ] * providerSet )
}
pkg := mc . prog . Package ( ref . importPath )
2018-04-04 10:58:07 -07:00
mods , err := findProviderSets ( mc . fset , pkg . Pkg , mc . r , & pkg . Info , pkg . Files )
2018-04-02 09:21:52 -07:00
if err != nil {
mc . sets [ ref . importPath ] = nil
return nil , err
}
mc . sets [ ref . importPath ] = mods
mod := mods [ ref . name ]
if mod == nil {
return nil , fmt . Errorf ( "no such provider set %s in package %q" , ref . name , ref . importPath )
}
return mod , nil
}
// A providerSetRef is a parsed reference to a collection of providers.
type providerSetRef struct {
importPath string
name string
}
2018-04-04 10:58:07 -07:00
func parseProviderSetRef ( r * importResolver , ref string , s * types . Scope , pkg string , pos token . Pos ) ( providerSetRef , error ) {
2018-04-02 09:21:52 -07:00
// TODO(light): verify that provider set name is an identifier before returning
i := strings . LastIndexByte ( ref , '.' )
if i == - 1 {
return providerSetRef { importPath : pkg , name : ref } , nil
}
imp , name := ref [ : i ] , ref [ i + 1 : ]
if strings . HasPrefix ( imp , ` " ` ) {
path , err := strconv . Unquote ( imp )
if err != nil {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: bad import path" , ref )
}
2018-04-04 10:58:07 -07:00
path , err = r . resolve ( pos , path )
if err != nil {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: %v" , ref , err )
}
2018-04-02 09:21:52 -07:00
return providerSetRef { importPath : path , name : name } , nil
}
_ , obj := s . LookupParent ( imp , pos )
if obj == nil {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: unknown identifier %s" , ref , imp )
}
pn , ok := obj . ( * types . PkgName )
if ! ok {
return providerSetRef { } , fmt . Errorf ( "parse provider set reference %q: %s does not name a package" , ref , imp )
}
return providerSetRef { importPath : pn . Imported ( ) . Path ( ) , name : name } , nil
}
func ( ref providerSetRef ) String ( ) string {
return strconv . Quote ( ref . importPath ) + "." + ref . name
}
2018-04-04 10:58:07 -07:00
type importResolver struct {
fset * token . FileSet
bctx * build . Context
findPackage func ( bctx * build . Context , importPath , fromDir string , mode build . ImportMode ) ( * build . Package , error )
}
func newImportResolver ( c * loader . Config , fset * token . FileSet ) * importResolver {
r := & importResolver {
fset : fset ,
bctx : c . Build ,
findPackage : c . FindPackage ,
}
if r . bctx == nil {
r . bctx = & build . Default
}
if r . findPackage == nil {
r . findPackage = ( * build . Context ) . Import
}
return r
}
func ( r * importResolver ) resolve ( pos token . Pos , path string ) ( string , error ) {
dir := filepath . Dir ( r . fset . File ( pos ) . Name ( ) )
pkg , err := r . findPackage ( r . bctx , path , dir , build . FindOnly )
if err != nil {
return "" , err
}
return pkg . ImportPath , nil
}
2018-04-02 09:21:52 -07:00
type directive struct {
pos token . Pos
kind string
line string
}
func extractDirectives ( d [ ] directive , cg * ast . CommentGroup ) [ ] directive {
const prefix = "goose:"
text := cg . Text ( )
for len ( text ) > 0 {
text = strings . TrimLeft ( text , " \t\r\n" )
if ! strings . HasPrefix ( text , prefix ) {
break
}
line := text [ len ( prefix ) : ]
if i := strings . IndexByte ( line , '\n' ) ; i != - 1 {
line , text = line [ : i ] , line [ i + 1 : ]
} else {
text = ""
}
if i := strings . IndexByte ( line , ' ' ) ; i != - 1 {
d = append ( d , directive {
kind : line [ : i ] ,
line : strings . TrimSpace ( line [ i + 1 : ] ) ,
pos : cg . Pos ( ) , // TODO(light): more precise position
} )
} else {
d = append ( d , directive {
kind : line ,
pos : cg . Pos ( ) , // TODO(light): more precise position
} )
}
}
return d
}
// isInjectFile reports whether a given file is an injection template.
func isInjectFile ( f * ast . File ) bool {
// TODO(light): better determination
for _ , cg := range f . Comments {
text := cg . Text ( )
if strings . HasPrefix ( text , "+build" ) && strings . Contains ( text , "gooseinject" ) {
return true
}
}
return false
}