diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 90ff9d014..1bc4dfbe1 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -17,8 +17,6 @@ import ( "bytes" "flag" "fmt" - "go/ast" - "go/token" "go/types" "log" "os" @@ -62,33 +60,13 @@ func main() { pkg := pkgs[0] buf := new(bytes.Buffer) imports := make(map[string]struct{}) + namedTypes := codegen.NamedTypes(pkg) for _, typeName := range typeNames { - found := false - for _, file := range pkg.Syntax { - for _, d := range file.Decls { - decl, ok := d.(*ast.GenDecl) - if !ok || decl.Tok != token.TYPE { - continue - } - for _, s := range decl.Specs { - spec, ok := s.(*ast.TypeSpec) - if !ok || spec.Name.Name != typeName { - continue - } - typeNameObj := pkg.TypesInfo.Defs[spec.Name] - typ, ok := typeNameObj.Type().(*types.Named) - if !ok { - continue - } - pkg := typeNameObj.Pkg() - gen(buf, imports, typ, pkg) - found = true - } - } - } - if !found { + typ, ok := namedTypes[typeName] + if !ok { log.Fatalf("could not find type %s", typeName) } + gen(buf, imports, typ, pkg.Types) } w := func(format string, args ...interface{}) { diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 95302440b..013170735 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -8,9 +8,13 @@ package codegen import ( "bytes" "fmt" + "go/ast" "go/format" + "go/token" "go/types" "os" + + "golang.org/x/tools/go/packages" ) // WriteFormatted writes code to path. @@ -41,6 +45,32 @@ func WriteFormatted(code []byte, path string) error { return nil } +// NamedTypes returns all named types in pkg, keyed by their type name. +func NamedTypes(pkg *packages.Package) map[string]*types.Named { + nt := make(map[string]*types.Named) + for _, file := range pkg.Syntax { + for _, d := range file.Decls { + decl, ok := d.(*ast.GenDecl) + if !ok || decl.Tok != token.TYPE { + continue + } + for _, s := range decl.Specs { + spec, ok := s.(*ast.TypeSpec) + if !ok { + continue + } + typeNameObj := pkg.TypesInfo.Defs[spec.Name] + typ, ok := typeNameObj.Type().(*types.Named) + if !ok { + continue + } + nt[spec.Name.Name] = typ + } + } + } + return nt +} + // AssertStructUnchanged generates code that asserts at compile time that type t is unchanged. // tname is the named type corresponding to t. // ctx is a single-word context for this assertion, such as "Clone".