From ec4feaf31c3398084718f5ee7dee1e8aaf6f5063 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 27 Jul 2020 10:40:34 -0700 Subject: [PATCH] cmd/cloner, tailcfg: fix nil vs len 0 issues, add tests, use for Hostinfo Also use go:generate and https://golang.org/s/generatedcode header style. Signed-off-by: Brad Fitzpatrick --- cmd/cloner/cloner.go | 6 +++--- tailcfg/tailcfg.go | 16 ++-------------- tailcfg/tailcfg_clone.go | 28 +++++++++++++++++++++------- tailcfg/tailcfg_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 66 insertions(+), 24 deletions(-) diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index a0bd04e6a..c7463b0f8 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -123,7 +123,7 @@ const header = `// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserve // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// AUTO-GENERATED by: tailscale.com/cmd/cloner -type %s +// Code generated by tailscale.com/cmd/cloner -type %s; DO NOT EDIT. package %s @@ -168,8 +168,8 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types } switch ft := ft.Underlying().(type) { case *types.Slice: - n := importedName(ft.Elem()) if containsPointers(ft.Elem()) { + n := importedName(ft.Elem()) writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("for i := range dst.%s {", fname) if _, isPtr := ft.Elem().(*types.Pointer); isPtr { @@ -179,7 +179,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types } writef("}") } else { - writef("dst.%s = append([]%s(nil), src.%s...)", fname, n, fname) + writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) } case *types.Pointer: if named, _ := ft.Elem().(*types.Named); named != nil && containsPointers(ft.Elem()) { diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 0e98f2a72..912005a6c 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -4,6 +4,8 @@ package tailcfg +//go:generate go run tailscale.com/cmd/cloner -type=User,Node,Hostinfo,NetInfo -output=tailcfg_clone.go + import ( "bytes" "errors" @@ -371,20 +373,6 @@ func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool { ni.LinkType == ni2.LinkType } -// Clone makes a deep copy of Hostinfo. -// The result aliases no memory with the original. -// -// TODO: use cmd/cloner, reconcile len(0) vs. nil. -func (h *Hostinfo) Clone() (res *Hostinfo) { - res = new(Hostinfo) - *res = *h - - res.RoutableIPs = append([]wgcfg.CIDR{}, h.RoutableIPs...) - res.Services = append([]Service{}, h.Services...) - res.NetInfo = h.NetInfo.Clone() - return res -} - // Equal reports whether h and h2 are equal. func (h *Hostinfo) Equal(h2 *Hostinfo) bool { if h == nil && h2 == nil { diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index cc936e51f..466871bbd 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -2,12 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// AUTO-GENERATED by tailscale.com/cmd/cloner -type User,Node,NetInfo +// Code generated by tailscale.com/cmd/cloner -type User,Node,Hostinfo,NetInfo; DO NOT EDIT. package tailcfg import ( - "github.com/tailscale/wireguard-go/wgcfg" "time" ) @@ -19,8 +18,8 @@ func (src *User) Clone() *User { } dst := new(User) *dst = *src - dst.Logins = append([]LoginID(nil), src.Logins...) - dst.Roles = append([]RoleID(nil), src.Roles...) + dst.Logins = append(src.Logins[:0:0], src.Logins...) + dst.Roles = append(src.Roles[:0:0], src.Roles...) return dst } @@ -32,9 +31,9 @@ func (src *Node) Clone() *Node { } dst := new(Node) *dst = *src - dst.Addresses = append([]wgcfg.CIDR(nil), src.Addresses...) - dst.AllowedIPs = append([]wgcfg.CIDR(nil), src.AllowedIPs...) - dst.Endpoints = append([]string(nil), src.Endpoints...) + dst.Addresses = append(src.Addresses[:0:0], src.Addresses...) + dst.AllowedIPs = append(src.AllowedIPs[:0:0], src.AllowedIPs...) + dst.Endpoints = append(src.Endpoints[:0:0], src.Endpoints...) dst.Hostinfo = *src.Hostinfo.Clone() if dst.LastSeen != nil { dst.LastSeen = new(time.Time) @@ -43,6 +42,21 @@ func (src *Node) Clone() *Node { return dst } +// Clone makes a deep copy of Hostinfo. +// The result aliases no memory with the original. +func (src *Hostinfo) Clone() *Hostinfo { + if src == nil { + return nil + } + dst := new(Hostinfo) + *dst = *src + dst.RoutableIPs = append(src.RoutableIPs[:0:0], src.RoutableIPs...) + dst.RequestTags = append(src.RequestTags[:0:0], src.RequestTags...) + dst.Services = append(src.Services[:0:0], src.Services...) + dst.NetInfo = src.NetInfo.Clone() + return dst +} + // Clone makes a deep copy of NetInfo. // The result aliases no memory with the original. func (src *NetInfo) Clone() *NetInfo { diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index cd35f7647..1ed974f98 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -390,3 +390,43 @@ func testKey(t *testing.T, prefix string, in keyIn, out encoding.TextUnmarshaler t.Errorf("mismatch after unmarshal") } } + +func TestCloneUser(t *testing.T) { + tests := []struct { + name string + u *User + }{ + {"nil_logins", &User{}}, + {"zero_logins", &User{Logins: make([]LoginID, 0)}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u2 := tt.u.Clone() + if !reflect.DeepEqual(tt.u, u2) { + t.Errorf("not equal") + } + }) + } +} + +func TestCloneNode(t *testing.T) { + tests := []struct { + name string + v *Node + }{ + {"nil_fields", &Node{}}, + {"zero_fields", &Node{ + Addresses: make([]wgcfg.CIDR, 0), + AllowedIPs: make([]wgcfg.CIDR, 0), + Endpoints: make([]string, 0), + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v2 := tt.v.Clone() + if !reflect.DeepEqual(tt.v, v2) { + t.Errorf("not equal") + } + }) + } +}