diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 421f0e69c..123529f95 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -156,7 +156,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisP for i := 0; i < t.NumFields(); i++ { fname := t.Field(i).Name() ft := t.Field(i).Type() - if !containsPointers(ft) { + if !codegen.ContainsPointers(ft) { continue } if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) { @@ -165,7 +165,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisP } switch ft := ft.Underlying().(type) { case *types.Slice: - if containsPointers(ft.Elem()) { + if codegen.ContainsPointers(ft.Elem()) { n := importedName(ft.Elem()) writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("for i := range dst.%s {", fname) @@ -179,7 +179,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisP writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) } case *types.Pointer: - if named, _ := ft.Elem().(*types.Named); named != nil && containsPointers(ft.Elem()) { + if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) { writef("dst.%s = src.%s.Clone()", fname, fname) continue } @@ -187,7 +187,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisP writef("if dst.%s != nil {", fname) writef("\tdst.%s = new(%s)", fname, n) writef("\t*dst.%s = *src.%s", fname, fname) - if containsPointers(ft.Elem()) { + if codegen.ContainsPointers(ft.Elem()) { writef("\t" + `panic("TODO pointers in pointers")`) } writef("}") @@ -201,7 +201,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisP // the key is always copied. writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) writef("\t}") - } else if containsPointers(ft.Elem()) { + } else if codegen.ContainsPointers(ft.Elem()) { writef("\tfor k, v := range src.%s {", fname) writef("\t\tdst.%s[k] = v.Clone()", fname) writef("\t}") @@ -229,34 +229,3 @@ func hasBasicUnderlying(typ types.Type) bool { return false } } - -func containsPointers(typ types.Type) bool { - switch typ.String() { - case "time.Time": - // time.Time contains a pointer that does not need copying - return false - case "inet.af/netaddr.IP": - return false - } - switch ft := typ.Underlying().(type) { - case *types.Array: - return containsPointers(ft.Elem()) - case *types.Chan: - return true - case *types.Interface: - return true // a little too broad - case *types.Map: - return true - case *types.Pointer: - return true - case *types.Slice: - return true - case *types.Struct: - for i := 0; i < ft.NumFields(); i++ { - if containsPointers(ft.Field(i).Type()) { - return true - } - } - } - return false -} diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 013170735..2df6c9b25 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -109,3 +109,38 @@ func importedName(t types.Type, thisPkg *types.Package) (qualifiedName, importPk } return types.TypeString(t, qual), importPkg } + +// ContainsPointers reports whether typ contains any pointers, +// either explicitly or implicitly. +// It has special handling for some types that contain pointers +// that we know are free from memory aliasing/mutation concerns. +func ContainsPointers(typ types.Type) bool { + switch typ.String() { + case "time.Time": + // time.Time contains a pointer that does not need copying + return false + case "inet.af/netaddr.IP": + return false + } + switch ft := typ.Underlying().(type) { + case *types.Array: + return ContainsPointers(ft.Elem()) + case *types.Chan: + return true + case *types.Interface: + return true // a little too broad + case *types.Map: + return true + case *types.Pointer: + return true + case *types.Slice: + return true + case *types.Struct: + for i := 0; i < ft.NumFields(); i++ { + if ContainsPointers(ft.Field(i).Type()) { + return true + } + } + } + return false +}