From 961b9c8abf36dec9d9320a22d38fd56d1ffe650e Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Fri, 24 Jul 2020 18:00:02 +1000 Subject: [PATCH] cmd/cloner: tool to generate Clone methods Signed-off-by: David Crawshaw --- cmd/cloner/cloner.go | 264 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 264 insertions(+) create mode 100644 cmd/cloner/cloner.go diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go new file mode 100644 index 000000000..a0bd04e6a --- /dev/null +++ b/cmd/cloner/cloner.go @@ -0,0 +1,264 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Cloner is a tool to automate the creation of a Clone method. +// +// The result of the Clone method aliases no memory that can be edited +// with the original. +// +// This tool makes lots of implicit assumptions about the types you feed it. +// In particular, it can only write relatively "shallow" Clone methods. +// That is, if a type contains another named struct type, cloner assumes that +// named type will also have a Clone method. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/format" + "go/token" + "go/types" + "io/ioutil" + "log" + "os" + "strings" + + "golang.org/x/tools/go/packages" +) + +var ( + flagTypes = flag.String("type", "", "comma-separated list of types; required") + flagOutput = flag.String("output", "", "output file; required") + flagBuildTags = flag.String("tags", "", "compiler build tags to apply") +) + +func main() { + log.SetFlags(0) + log.SetPrefix("cloner: ") + flag.Parse() + if len(*flagTypes) == 0 { + flag.Usage() + os.Exit(2) + } + typeNames := strings.Split(*flagTypes, ",") + + cfg := &packages.Config{ + Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, + Tests: false, + } + if *flagBuildTags != "" { + cfg.BuildFlags = []string{"-tags=" + *flagBuildTags} + } + pkgs, err := packages.Load(cfg, ".") + if err != nil { + log.Fatal(err) + } + if len(pkgs) != 1 { + log.Fatalf("wrong number of packages: %d", len(pkgs)) + } + pkg := pkgs[0] + buf := new(bytes.Buffer) + imports := make(map[string]struct{}) + for _, typeName := range typeNames { + found := false + for _, file := range pkg.Syntax { + //var fbuf bytes.Buffer + //ast.Fprint(&fbuf, pkg.Fset, file, nil) + //fmt.Println(fbuf.String()) + + for _, d := range file.Decls { + decl, ok := d.(*ast.GenDecl) + if !ok || decl.Tok != token.TYPE { + continue + } + for _, s := range decl.Specs { + spec, ok := s.(*ast.TypeSpec) + if !ok || spec.Name.Name != typeName { + continue + } + typeNameObj := pkg.TypesInfo.Defs[spec.Name] + typ, ok := typeNameObj.Type().(*types.Named) + if !ok { + continue + } + pkg := typeNameObj.Pkg() + gen(buf, imports, typeName, typ, pkg) + } + found = true + } + } + if !found { + log.Fatalf("could not find type %s", typeName) + } + } + + contents := new(bytes.Buffer) + fmt.Fprintf(contents, header, *flagTypes, pkg.Name) + fmt.Fprintf(contents, "import (\n") + for s := range imports { + fmt.Fprintf(contents, "\t%q\n", s) + } + fmt.Fprintf(contents, ")\n\n") + contents.Write(buf.Bytes()) + + out, err := format.Source(contents.Bytes()) + if err != nil { + log.Fatalf("%s, in source:\n%s", err, contents.Bytes()) + } + + output := *flagOutput + if output == "" { + flag.Usage() + os.Exit(2) + } + if err := ioutil.WriteFile(output, out, 0666); err != nil { + log.Fatal(err) + } +} + +const header = `// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// AUTO-GENERATED by: tailscale.com/cmd/cloner -type %s + +package %s + +` + +func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types.Named, thisPkg *types.Package) { + pkgQual := func(pkg *types.Package) string { + if thisPkg == pkg { + return "" + } + imports[pkg.Path()] = struct{}{} + return pkg.Name() + } + importedName := func(t types.Type) string { + return types.TypeString(t, pkgQual) + } + + switch t := typ.Underlying().(type) { + case *types.Struct: + _ = t + name := typ.Obj().Name() + fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name) + fmt.Fprintf(buf, "// The result aliases no memory with the original.\n") + fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name) + writef := func(format string, args ...interface{}) { + fmt.Fprintf(buf, "\t"+format+"\n", args...) + } + writef("if src == nil {") + writef("\treturn nil") + writef("}") + writef("dst := new(%s)", name) + writef("*dst = *src") + for i := 0; i < t.NumFields(); i++ { + fname := t.Field(i).Name() + ft := t.Field(i).Type() + if !containsPointers(ft) { + continue + } + if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) { + writef("dst.%s = *src.%s.Clone()", fname, fname) + continue + } + switch ft := ft.Underlying().(type) { + case *types.Slice: + n := importedName(ft.Elem()) + if containsPointers(ft.Elem()) { + writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) + writef("for i := range dst.%s {", fname) + if _, isPtr := ft.Elem().(*types.Pointer); isPtr { + writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) + } else { + writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) + } + writef("}") + } else { + writef("dst.%s = append([]%s(nil), src.%s...)", fname, n, fname) + } + case *types.Pointer: + if named, _ := ft.Elem().(*types.Named); named != nil && containsPointers(ft.Elem()) { + writef("dst.%s = src.%s.Clone()", fname, fname) + continue + } + n := importedName(ft.Elem()) + writef("if dst.%s != nil {", fname) + writef("\tdst.%s = new(%s)", fname, n) + writef("\t*dst.%s = *src.%s", fname, fname) + if containsPointers(ft.Elem()) { + writef("\t" + `panic("TODO pointers in pointers")`) + } + writef("}") + case *types.Map: + writef("if dst.%s != nil {", fname) + writef("\tdst.%s = map[%s]%s{}", fname, importedName(ft.Key()), importedName(ft.Elem())) + if sliceType, isSlice := ft.Elem().(*types.Slice); isSlice { + n := importedName(sliceType.Elem()) + writef("\tfor k := range src.%s {", fname) + // use zero-length slice instead of nil to ensure + // the key is always copied. + writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) + writef("\t}") + } else if containsPointers(ft.Elem()) { + writef("\t\t" + `panic("TODO map value pointers")`) + } else { + writef("\tfor k, v := range src.%s {", fname) + writef("\t\tdst.%s[k] = v", fname) + writef("\t}") + } + writef("}") + case *types.Struct: + writef(`panic("TODO struct %s")`, fname) + default: + writef(`panic(fmt.Sprintf("TODO: %T", ft))`) + } + } + writef("return dst") + fmt.Fprintf(buf, "}\n\n") + } +} + +func hasBasicUnderlying(typ types.Type) bool { + switch typ.Underlying().(type) { + case *types.Slice, *types.Map: + return true + default: + return false + } +} + +func containsPointers(typ types.Type) bool { + switch typ.String() { + case "time.Time": + // time.Time contains a pointer that does not need copying + return false + case "inet.af/netaddr.IP": + return false + } + switch ft := typ.Underlying().(type) { + case *types.Array: + return containsPointers(ft.Elem()) + case *types.Chan: + return true + case *types.Interface: + return true // a little too broad + case *types.Map: + return true + case *types.Pointer: + return true + case *types.Slice: + return true + case *types.Struct: + for i := 0; i < ft.NumFields(); i++ { + if containsPointers(ft.Field(i).Type()) { + return true + } + } + } + return false +}