From 8a5ec72c85de6d84ccdbf97ff6ea288cacb5a0a9 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Sun, 20 Aug 2023 13:16:06 -0400 Subject: [PATCH] cmd/cloner: use maps.Clone and ptr.To Updates #cleanup Signed-off-by: Maisem Ali --- cmd/cloner/cloner.go | 26 ++++++++--------- cmd/derper/depaware.txt | 2 +- cmd/tailscale/depaware.txt | 2 +- cmd/tailscaled/depaware.txt | 2 +- cmd/viewer/tests/tests_clone.go | 50 ++++++++------------------------- ipn/ipn_clone.go | 8 ++---- tailcfg/tailcfg_clone.go | 47 +++++++++---------------------- wgengine/wgcfg/wgcfg_clone.go | 4 +-- 8 files changed, 45 insertions(+), 96 deletions(-) diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index a2620c8d7..4e8e2119b 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -126,8 +126,8 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("for i := range dst.%s {", fname) if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr { if _, isBasic := ptr.Elem().Underlying().(*types.Basic); isBasic { - writef("\tx := *src.%s[i]", fname) - writef("\tdst.%s[i] = &x", fname) + it.Import("tailscale.com/types/ptr") + writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) } else { writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) } @@ -145,41 +145,41 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("dst.%s = src.%s.Clone()", fname, fname) continue } - n := it.QualifiedName(ft.Elem()) + it.Import("tailscale.com/types/ptr") writef("if dst.%s != nil {", fname) - writef("\tdst.%s = new(%s)", fname, n) - writef("\t*dst.%s = *src.%s", fname, fname) + writef("\tdst.%s = ptr.To(*src.%s)", fname, fname) if codegen.ContainsPointers(ft.Elem()) { writef("\t" + `panic("TODO pointers in pointers")`) } writef("}") case *types.Map: elem := ft.Elem() - writef("if dst.%s != nil {", fname) - writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) if sliceType, isSlice := elem.(*types.Slice); isSlice { n := it.QualifiedName(sliceType.Elem()) + writef("if dst.%s != nil {", fname) + writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) writef("\tfor k := range src.%s {", fname) // use zero-length slice instead of nil to ensure // the key is always copied. writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) writef("\t}") + writef("}") } else if codegen.ContainsPointers(elem) { + writef("if dst.%s != nil {", fname) + writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) writef("\tfor k, v := range src.%s {", fname) switch elem.(type) { case *types.Pointer: writef("\t\tdst.%s[k] = v.Clone()", fname) default: - writef("\t\tv2 := v.Clone()") - writef("\t\tdst.%s[k] = *v2", fname) + writef("\t\tdst.%s[k] = *(v.Clone())", fname) } writef("\t}") + writef("}") } else { - writef("\tfor k, v := range src.%s {", fname) - writef("\t\tdst.%s[k] = v", fname) - writef("\t}") + it.Import("maps") + writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) } - writef("}") default: writef(`panic("TODO: %s (%T)")`, fname, ft) } diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 394066519..e105d4a83 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -240,7 +240,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa io/ioutil from github.com/mitchellh/go-ps+ log from expvar+ log/internal from log - maps from tailscale.com/types/views + maps from tailscale.com/types/views+ math from compress/flate+ math/big from crypto/dsa+ math/bits from compress/flate+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index ef29eeebe..7ce47f189 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -256,7 +256,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep io/ioutil from golang.org/x/sys/cpu+ log from expvar+ log/internal from log - maps from tailscale.com/types/views + maps from tailscale.com/types/views+ math from compress/flate+ math/big from crypto/dsa+ math/bits from compress/flate+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index c151878b9..2e4075e64 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -466,7 +466,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de log from expvar+ log/internal from log LD log/syslog from tailscale.com/ssh/tailssh - maps from tailscale.com/types/views + maps from tailscale.com/types/views+ math from compress/flate+ math/big from crypto/dsa+ math/bits from compress/flate+ diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 3ff914126..4d82c03bd 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -6,7 +6,10 @@ package tests import ( + "maps" "net/netip" + + "tailscale.com/types/ptr" ) // Clone makes a deep copy of StructWithPtrs. @@ -18,12 +21,10 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { dst := new(StructWithPtrs) *dst = *src if dst.Value != nil { - dst.Value = new(StructWithoutPtrs) - *dst.Value = *src.Value + dst.Value = ptr.To(*src.Value) } if dst.Int != nil { - dst.Int = new(int) - *dst.Int = *src.Int + dst.Int = ptr.To(*src.Int) } return dst } @@ -60,12 +61,7 @@ func (src *Map) Clone() *Map { } dst := new(Map) *dst = *src - if dst.Int != nil { - dst.Int = map[string]int{} - for k, v := range src.Int { - dst.Int[k] = v - } - } + dst.Int = maps.Clone(src.Int) if dst.SliceInt != nil { dst.SliceInt = map[string][]int{} for k := range src.SliceInt { @@ -84,12 +80,7 @@ func (src *Map) Clone() *Map { dst.StructPtrWithoutPtr[k] = v.Clone() } } - if dst.StructWithoutPtr != nil { - dst.StructWithoutPtr = map[string]StructWithoutPtrs{} - for k, v := range src.StructWithoutPtr { - dst.StructWithoutPtr[k] = v - } - } + dst.StructWithoutPtr = maps.Clone(src.StructWithoutPtr) if dst.SlicesWithPtrs != nil { dst.SlicesWithPtrs = map[string][]*StructWithPtrs{} for k := range src.SlicesWithPtrs { @@ -102,35 +93,19 @@ func (src *Map) Clone() *Map { dst.SlicesWithoutPtrs[k] = append([]*StructWithoutPtrs{}, src.SlicesWithoutPtrs[k]...) } } - if dst.StructWithoutPtrKey != nil { - dst.StructWithoutPtrKey = map[StructWithoutPtrs]int{} - for k, v := range src.StructWithoutPtrKey { - dst.StructWithoutPtrKey[k] = v - } - } + dst.StructWithoutPtrKey = maps.Clone(src.StructWithoutPtrKey) if dst.SliceIntPtr != nil { dst.SliceIntPtr = map[string][]*int{} for k := range src.SliceIntPtr { dst.SliceIntPtr[k] = append([]*int{}, src.SliceIntPtr[k]...) } } - if dst.PointerKey != nil { - dst.PointerKey = map[*string]int{} - for k, v := range src.PointerKey { - dst.PointerKey[k] = v - } - } - if dst.StructWithPtrKey != nil { - dst.StructWithPtrKey = map[StructWithPtrs]int{} - for k, v := range src.StructWithPtrKey { - dst.StructWithPtrKey[k] = v - } - } + dst.PointerKey = maps.Clone(src.PointerKey) + dst.StructWithPtrKey = maps.Clone(src.StructWithPtrKey) if dst.StructWithPtr != nil { dst.StructWithPtr = map[string]StructWithPtrs{} for k, v := range src.StructWithPtr { - v2 := v.Clone() - dst.StructWithPtr[k] = *v2 + dst.StructWithPtr[k] = *(v.Clone()) } } return dst @@ -175,8 +150,7 @@ func (src *StructWithSlices) Clone() *StructWithSlices { } dst.Ints = make([]*int, len(src.Ints)) for i := range dst.Ints { - x := *src.Ints[i] - dst.Ints[i] = &x + dst.Ints[i] = ptr.To(*src.Ints[i]) } dst.Slice = append(src.Slice[:0:0], src.Slice...) dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...) diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index 97207d039..5377705bb 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -6,6 +6,7 @@ package ipn import ( + "maps" "net/netip" "tailscale.com/tailcfg" @@ -73,12 +74,7 @@ func (src *ServeConfig) Clone() *ServeConfig { dst.Web[k] = v.Clone() } } - if dst.AllowFunnel != nil { - dst.AllowFunnel = map[HostPort]bool{} - for k, v := range src.AllowFunnel { - dst.AllowFunnel[k] = v - } - } + dst.AllowFunnel = maps.Clone(src.AllowFunnel) return dst } diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 577cfdb2c..9d3ca5221 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -6,12 +6,14 @@ package tailcfg import ( + "maps" "net/netip" "time" "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/opt" + "tailscale.com/types/ptr" "tailscale.com/types/structs" "tailscale.com/types/tkatype" ) @@ -54,17 +56,14 @@ func (src *Node) Clone() *Node { dst.Tags = append(src.Tags[:0:0], src.Tags...) dst.PrimaryRoutes = append(src.PrimaryRoutes[:0:0], src.PrimaryRoutes...) if dst.LastSeen != nil { - dst.LastSeen = new(time.Time) - *dst.LastSeen = *src.LastSeen + dst.LastSeen = ptr.To(*src.LastSeen) } if dst.Online != nil { - dst.Online = new(bool) - *dst.Online = *src.Online + dst.Online = ptr.To(*src.Online) } dst.Capabilities = append(src.Capabilities[:0:0], src.Capabilities...) if dst.SelfNodeV4MasqAddrForThisPeer != nil { - dst.SelfNodeV4MasqAddrForThisPeer = new(netip.Addr) - *dst.SelfNodeV4MasqAddrForThisPeer = *src.SelfNodeV4MasqAddrForThisPeer + dst.SelfNodeV4MasqAddrForThisPeer = ptr.To(*src.SelfNodeV4MasqAddrForThisPeer) } return dst } @@ -118,8 +117,7 @@ func (src *Hostinfo) Clone() *Hostinfo { dst.NetInfo = src.NetInfo.Clone() dst.SSH_HostKeys = append(src.SSH_HostKeys[:0:0], src.SSH_HostKeys...) if dst.Location != nil { - dst.Location = new(Location) - *dst.Location = *src.Location + dst.Location = ptr.To(*src.Location) } return dst } @@ -170,12 +168,7 @@ func (src *NetInfo) Clone() *NetInfo { } dst := new(NetInfo) *dst = *src - if dst.DERPLatency != nil { - dst.DERPLatency = map[string]float64{} - for k, v := range src.DERPLatency { - dst.DERPLatency[k] = v - } - } + dst.DERPLatency = maps.Clone(src.DERPLatency) return dst } @@ -295,8 +288,7 @@ func (src *RegisterResponseAuth) Clone() *RegisterResponseAuth { dst := new(RegisterResponseAuth) *dst = *src if dst.Oauth2Token != nil { - dst.Oauth2Token = new(Oauth2Token) - *dst.Oauth2Token = *src.Oauth2Token + dst.Oauth2Token = ptr.To(*src.Oauth2Token) } return dst } @@ -322,8 +314,7 @@ func (src *RegisterRequest) Clone() *RegisterRequest { dst.Hostinfo = src.Hostinfo.Clone() dst.NodeKeySignature = append(src.NodeKeySignature[:0:0], src.NodeKeySignature...) if dst.Timestamp != nil { - dst.Timestamp = new(time.Time) - *dst.Timestamp = *src.Timestamp + dst.Timestamp = ptr.To(*src.Timestamp) } dst.DeviceCert = append(src.DeviceCert[:0:0], src.DeviceCert...) dst.Signature = append(src.Signature[:0:0], src.Signature...) @@ -357,12 +348,7 @@ func (src *DERPHomeParams) Clone() *DERPHomeParams { } dst := new(DERPHomeParams) *dst = *src - if dst.RegionScore != nil { - dst.RegionScore = map[int]float64{} - for k, v := range src.RegionScore { - dst.RegionScore[k] = v - } - } + dst.RegionScore = maps.Clone(src.RegionScore) return dst } @@ -456,19 +442,13 @@ func (src *SSHRule) Clone() *SSHRule { dst := new(SSHRule) *dst = *src if dst.RuleExpires != nil { - dst.RuleExpires = new(time.Time) - *dst.RuleExpires = *src.RuleExpires + dst.RuleExpires = ptr.To(*src.RuleExpires) } dst.Principals = make([]*SSHPrincipal, len(src.Principals)) for i := range dst.Principals { dst.Principals[i] = src.Principals[i].Clone() } - if dst.SSHUsers != nil { - dst.SSHUsers = map[string]string{} - for k, v := range src.SSHUsers { - dst.SSHUsers[k] = v - } - } + dst.SSHUsers = maps.Clone(src.SSHUsers) dst.Action = src.Action.Clone() return dst } @@ -491,8 +471,7 @@ func (src *SSHAction) Clone() *SSHAction { *dst = *src dst.Recorders = append(src.Recorders[:0:0], src.Recorders...) if dst.OnRecordingFailure != nil { - dst.OnRecordingFailure = new(SSHRecorderFailureAction) - *dst.OnRecordingFailure = *src.OnRecordingFailure + dst.OnRecordingFailure = ptr.To(*src.OnRecordingFailure) } return dst } diff --git a/wgengine/wgcfg/wgcfg_clone.go b/wgengine/wgcfg/wgcfg_clone.go index 6887dd6cc..f96277cb3 100644 --- a/wgengine/wgcfg/wgcfg_clone.go +++ b/wgengine/wgcfg/wgcfg_clone.go @@ -11,6 +11,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logid" + "tailscale.com/types/ptr" ) // Clone makes a deep copy of Config. @@ -55,8 +56,7 @@ func (src *Peer) Clone() *Peer { *dst = *src dst.AllowedIPs = append(src.AllowedIPs[:0:0], src.AllowedIPs...) if dst.V4MasqAddr != nil { - dst.V4MasqAddr = new(netip.Addr) - *dst.V4MasqAddr = *src.V4MasqAddr + dst.V4MasqAddr = ptr.To(*src.V4MasqAddr) } return dst }