From ee90cd02fdd4e4125ec9d12eef1195ed36ef4b2e Mon Sep 17 00:00:00 2001 From: James Tucker Date: Fri, 29 Sep 2023 17:29:17 -0700 Subject: [PATCH] cmd/cloner,*: optimize nillable slice cloner A wild @josharian appears with a good suggestion for a refactor, thanks Josh! Updates #9410 Signed-off-by: James Tucker --- cmd/cloner/cloner.go | 4 +--- cmd/viewer/tests/tests_clone.go | 32 ++++++++++---------------- tailcfg/tailcfg_clone.go | 40 +++++++++++++-------------------- wgengine/filter/filter_clone.go | 8 +++---- wgengine/wgcfg/wgcfg_clone.go | 8 +++---- 5 files changed, 34 insertions(+), 58 deletions(-) diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 5a94fa97d..25d796dbf 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -122,8 +122,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { case *types.Slice: if codegen.ContainsPointers(ft.Elem()) { n := it.QualifiedName(ft.Elem()) - writef("if src.%s != nil {", fname) - writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) + writef("dst.%s = append([]%s(nil), make([]%s, len(src.%s))...)", fname, n, n, fname) writef("for i := range dst.%s {", fname) if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr { if _, isBasic := ptr.Elem().Underlying().(*types.Basic); isBasic { @@ -138,7 +137,6 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) } writef("}") - writef("}") } else { writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) } diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 2b41639fd..35721430e 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -136,29 +136,21 @@ func (src *StructWithSlices) Clone() *StructWithSlices { dst := new(StructWithSlices) *dst = *src dst.Values = append(src.Values[:0:0], src.Values...) - if src.ValuePointers != nil { - dst.ValuePointers = make([]*StructWithoutPtrs, len(src.ValuePointers)) - for i := range dst.ValuePointers { - dst.ValuePointers[i] = src.ValuePointers[i].Clone() - } + dst.ValuePointers = append([]*StructWithoutPtrs(nil), make([]*StructWithoutPtrs, len(src.ValuePointers))...) + for i := range dst.ValuePointers { + dst.ValuePointers[i] = src.ValuePointers[i].Clone() } - if src.StructPointers != nil { - dst.StructPointers = make([]*StructWithPtrs, len(src.StructPointers)) - for i := range dst.StructPointers { - dst.StructPointers[i] = src.StructPointers[i].Clone() - } + dst.StructPointers = append([]*StructWithPtrs(nil), make([]*StructWithPtrs, len(src.StructPointers))...) + for i := range dst.StructPointers { + dst.StructPointers[i] = src.StructPointers[i].Clone() } - if src.Structs != nil { - dst.Structs = make([]StructWithPtrs, len(src.Structs)) - for i := range dst.Structs { - dst.Structs[i] = *src.Structs[i].Clone() - } + dst.Structs = append([]StructWithPtrs(nil), make([]StructWithPtrs, len(src.Structs))...) + for i := range dst.Structs { + dst.Structs[i] = *src.Structs[i].Clone() } - if src.Ints != nil { - dst.Ints = make([]*int, len(src.Ints)) - for i := range dst.Ints { - dst.Ints[i] = ptr.To(*src.Ints[i]) - } + dst.Ints = append([]*int(nil), make([]*int, len(src.Ints))...) + for i := range dst.Ints { + dst.Ints[i] = ptr.To(*src.Ints[i]) } dst.Slice = append(src.Slice[:0:0], src.Slice...) dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...) diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 6a2292149..7d3ed78ff 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -74,11 +74,9 @@ func (src *Node) Clone() *Node { if dst.SelfNodeV6MasqAddrForThisPeer != nil { dst.SelfNodeV6MasqAddrForThisPeer = ptr.To(*src.SelfNodeV6MasqAddrForThisPeer) } - if src.ExitNodeDNSResolvers != nil { - dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers)) - for i := range dst.ExitNodeDNSResolvers { - dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() - } + dst.ExitNodeDNSResolvers = append([]*dnstype.Resolver(nil), make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers))...) + for i := range dst.ExitNodeDNSResolvers { + dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() } return dst } @@ -237,11 +235,9 @@ func (src *DNSConfig) Clone() *DNSConfig { } dst := new(DNSConfig) *dst = *src - if src.Resolvers != nil { - dst.Resolvers = make([]*dnstype.Resolver, len(src.Resolvers)) - for i := range dst.Resolvers { - dst.Resolvers[i] = src.Resolvers[i].Clone() - } + dst.Resolvers = append([]*dnstype.Resolver(nil), make([]*dnstype.Resolver, len(src.Resolvers))...) + for i := range dst.Resolvers { + dst.Resolvers[i] = src.Resolvers[i].Clone() } if dst.Routes != nil { dst.Routes = map[string][]*dnstype.Resolver{} @@ -249,11 +245,9 @@ func (src *DNSConfig) Clone() *DNSConfig { dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...) } } - if src.FallbackResolvers != nil { - dst.FallbackResolvers = make([]*dnstype.Resolver, len(src.FallbackResolvers)) - for i := range dst.FallbackResolvers { - dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() - } + dst.FallbackResolvers = append([]*dnstype.Resolver(nil), make([]*dnstype.Resolver, len(src.FallbackResolvers))...) + for i := range dst.FallbackResolvers { + dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() } dst.Domains = append(src.Domains[:0:0], src.Domains...) dst.Nameservers = append(src.Nameservers[:0:0], src.Nameservers...) @@ -387,11 +381,9 @@ func (src *DERPRegion) Clone() *DERPRegion { } dst := new(DERPRegion) *dst = *src - if src.Nodes != nil { - dst.Nodes = make([]*DERPNode, len(src.Nodes)) - for i := range dst.Nodes { - dst.Nodes[i] = src.Nodes[i].Clone() - } + dst.Nodes = append([]*DERPNode(nil), make([]*DERPNode, len(src.Nodes))...) + for i := range dst.Nodes { + dst.Nodes[i] = src.Nodes[i].Clone() } return dst } @@ -468,11 +460,9 @@ func (src *SSHRule) Clone() *SSHRule { if dst.RuleExpires != nil { dst.RuleExpires = ptr.To(*src.RuleExpires) } - if src.Principals != nil { - dst.Principals = make([]*SSHPrincipal, len(src.Principals)) - for i := range dst.Principals { - dst.Principals[i] = src.Principals[i].Clone() - } + dst.Principals = append([]*SSHPrincipal(nil), make([]*SSHPrincipal, len(src.Principals))...) + for i := range dst.Principals { + dst.Principals[i] = src.Principals[i].Clone() } dst.SSHUsers = maps.Clone(src.SSHUsers) dst.Action = src.Action.Clone() diff --git a/wgengine/filter/filter_clone.go b/wgengine/filter/filter_clone.go index 97366d83c..9835c0a16 100644 --- a/wgengine/filter/filter_clone.go +++ b/wgengine/filter/filter_clone.go @@ -23,11 +23,9 @@ func (src *Match) Clone() *Match { dst.IPProto = append(src.IPProto[:0:0], src.IPProto...) dst.Srcs = append(src.Srcs[:0:0], src.Srcs...) dst.Dsts = append(src.Dsts[:0:0], src.Dsts...) - if src.Caps != nil { - dst.Caps = make([]CapMatch, len(src.Caps)) - for i := range dst.Caps { - dst.Caps[i] = *src.Caps[i].Clone() - } + dst.Caps = append([]CapMatch(nil), make([]CapMatch, len(src.Caps))...) + for i := range dst.Caps { + dst.Caps[i] = *src.Caps[i].Clone() } return dst } diff --git a/wgengine/wgcfg/wgcfg_clone.go b/wgengine/wgcfg/wgcfg_clone.go index 4a2288f1e..883aebc71 100644 --- a/wgengine/wgcfg/wgcfg_clone.go +++ b/wgengine/wgcfg/wgcfg_clone.go @@ -24,11 +24,9 @@ func (src *Config) Clone() *Config { *dst = *src dst.Addresses = append(src.Addresses[:0:0], src.Addresses...) dst.DNS = append(src.DNS[:0:0], src.DNS...) - if src.Peers != nil { - dst.Peers = make([]Peer, len(src.Peers)) - for i := range dst.Peers { - dst.Peers[i] = *src.Peers[i].Clone() - } + dst.Peers = append([]Peer(nil), make([]Peer, len(src.Peers))...) + for i := range dst.Peers { + dst.Peers[i] = *src.Peers[i].Clone() } return dst }