diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 69414a08e..28a72351e 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -264,6 +264,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de LW tailscale.com/util/endian from tailscale.com/net/dns+ tailscale.com/util/groupmember from tailscale.com/ipn/ipnserver tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/mak from tailscale.com/control/controlclient+ tailscale.com/util/multierr from tailscale.com/cmd/tailscaled+ tailscale.com/util/netconv from tailscale.com/wgengine/magicsock tailscale.com/util/osshare from tailscale.com/cmd/tailscaled+ diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 1a5f4a73f..54c1e0ce6 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -20,6 +20,7 @@ import ( "tailscale.com/control/controlhttp" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/util/mak" "tailscale.com/util/multierr" ) @@ -137,9 +138,6 @@ func (nc *noiseClient) Close() error { func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { nc.mu.Lock() connID := nc.nextID - if nc.connPool == nil { - nc.connPool = make(map[int]*noiseConn) - } nc.nextID++ nc.mu.Unlock() @@ -161,6 +159,6 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { nc.mu.Lock() defer nc.mu.Unlock() ncc := &noiseConn{Conn: conn, id: connID, pool: nc} - nc.connPool[ncc.id] = ncc + mak.Set(&nc.connPool, ncc.id, ncc) return ncc, nil } diff --git a/ipn/store/stores.go b/ipn/store/stores.go index d74060858..98c1d2aa0 100644 --- a/ipn/store/stores.go +++ b/ipn/store/stores.go @@ -21,6 +21,7 @@ import ( "tailscale.com/ipn/store/mem" "tailscale.com/paths" "tailscale.com/types/logger" + "tailscale.com/util/mak" ) // Provider returns a StateStore for the provided path. @@ -82,10 +83,7 @@ func Register(prefix string, fn Provider) { if _, ok := knownStores[prefix]; ok { panic(fmt.Sprintf("%q already registered", prefix)) } - if knownStores == nil { - knownStores = make(map[string]Provider) - } - knownStores[prefix] = fn + mak.Set(&knownStores, prefix, fn) } // TryWindowsAppDataMigration attempts to copy the Windows state file diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index a53f67970..9d71f1ef0 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -40,6 +40,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/logger" + "tailscale.com/util/mak" ) var ( @@ -471,7 +472,7 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { srv.mu.Lock() defer srv.mu.Unlock() - mapSet(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{ + mak.Set(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{ at: srv.now(), lines: lines, etag: etag, @@ -731,8 +732,8 @@ func (srv *server) startSession(ss *sshSession) { if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup { panic("dup sharedID") } - mapSet(&srv.activeSessionByH, ss.idH, ss) - mapSet(&srv.activeSessionBySharedID, ss.sharedID, ss) + mak.Set(&srv.activeSessionByH, ss.idH, ss) + mak.Set(&srv.activeSessionBySharedID, ss.sharedID, ss) } // endSession unregisters s from the list of active sessions. @@ -1248,11 +1249,3 @@ func envEq(a, b string) bool { } return a == b } - -// mapSet assigns m[k] = v, making m if necessary. -func mapSet[K comparable, V any](m *map[K]V, k K, v V) { - if *m == nil { - *m = make(map[K]V) - } - (*m)[k] = v -} diff --git a/util/mak/mak.go b/util/mak/mak.go new file mode 100644 index 000000000..e0f0d8d03 --- /dev/null +++ b/util/mak/mak.go @@ -0,0 +1,53 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package mak helps make maps. It contains generic helpers to make/assign +// things, notably to maps, but also slices. +package mak + +import ( + "fmt" + "reflect" +) + +// Set populates an entry in a map, making the map if necessary. +// +// That is, it assigns (*m)[k] = v, making *m if it was nil. +func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { + if *m == nil { + *m = make(map[K]V) + } + (*m)[k] = v +} + +// NonNil takes a pointer to a Go data structure +// (currently only a slice or a map) and makes sure it's non-nil for +// JSON serialization. (In particular, JavaScript clients usually want +// the field to be defined after they decode the JSON.) +// MakeNonNil takes a pointer to a Go data structure +// (currently only a slice or a map) and makes sure it's non-nil for +// JSON serialization. (In particular, JavaScript clients usually want +// the field to be defined after they decode the JSON.) +func NonNil(ptr interface{}) { + if ptr == nil { + panic("nil interface") + } + rv := reflect.ValueOf(ptr) + if rv.Kind() != reflect.Ptr { + panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) + } + if rv.Pointer() == 0 { + panic("nil pointer") + } + rv = rv.Elem() + if rv.Pointer() != 0 { + return + } + switch rv.Type().Kind() { + case reflect.Slice: + rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) + case reflect.Map: + rv.Set(reflect.MakeMap(rv.Type())) + } +} diff --git a/util/mak/mak_test.go b/util/mak/mak_test.go new file mode 100644 index 000000000..fae40e220 --- /dev/null +++ b/util/mak/mak_test.go @@ -0,0 +1,71 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package mak contains code to help make things. +package mak + +import ( + "reflect" + "testing" +) + +type M map[string]int + +func TestSet(t *testing.T) { + t.Run("unnamed", func(t *testing.T) { + var m map[string]int + Set(&m, "foo", 42) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := map[string]int{ + "foo": 42, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) + t.Run("named", func(t *testing.T) { + var m M + Set(&m, "foo", 1) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := M{ + "foo": 1, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) +} + +func TestNonNil(t *testing.T) { + var s []string + NonNil(&s) + if len(s) != 0 { + t.Errorf("slice len = %d; want 0", len(s)) + } + if s == nil { + t.Error("slice still nil") + } + + s = append(s, "foo") + NonNil(&s) + if len(s) != 1 { + t.Errorf("len = %d; want 1", len(s)) + } + if s[0] != "foo" { + t.Errorf("value = %q; want foo", s) + } + + var m map[string]string + NonNil(&m) + if len(m) != 0 { + t.Errorf("map len = %d; want 0", len(s)) + } + if m == nil { + t.Error("map still nil") + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 5236268a9..66ebc8dea 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -55,6 +55,7 @@ import ( "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" "tailscale.com/util/netconv" + "tailscale.com/util/mak" "tailscale.com/util/uniq" "tailscale.com/version" "tailscale.com/wgengine/monitor" @@ -438,11 +439,7 @@ func (c *Conn) removeDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { c.mu.Lock() defer c.mu.Unlock() - if c.derpRoute == nil { - c.derpRoute = make(map[key.NodePublic]derpRoute) - } - r := derpRoute{derpID, dc} - c.derpRoute[peer] = r + mak.Set(&c.derpRoute, peer, derpRoute{derpID, dc}) } // DerpMagicIP is a fake WireGuard endpoint IP address that means @@ -1050,7 +1047,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro }, nil } - already := make(map[netaddr.IPPort]tailcfg.EndpointType) // endpoint -> how it was found + var already map[netaddr.IPPort]tailcfg.EndpointType // endpoint -> how it was found var eps []tailcfg.Endpoint // unique endpoints ipp := func(s string) (ipp netaddr.IPPort) { @@ -1062,7 +1059,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro return } if _, ok := already[ipp]; !ok { - already[ipp] = et + mak.Set(&already, ipp, et) eps = append(eps, tailcfg.Endpoint{Addr: ipp, Type: et}) } } @@ -3957,9 +3954,6 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { for ep := range de.isCallMeMaybeEP { de.isCallMeMaybeEP[ep] = false // mark for deletion } - if de.isCallMeMaybeEP == nil { - de.isCallMeMaybeEP = map[netaddr.IPPort]bool{} - } var newEPs []netaddr.IPPort for _, ep := range m.MyNumber { if ep.IP().Is6() && ep.IP().IsLinkLocalUnicast() { @@ -3968,7 +3962,7 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { // for these. continue } - de.isCallMeMaybeEP[ep] = true + mak.Set(&de.isCallMeMaybeEP, ep, true) if es, ok := de.endpointState[ep]; ok { es.callMeMaybeTime = now } else { diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index f0e3bcb0a..adf446973 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -15,6 +15,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" "tailscale.com/types/ipproto" + "tailscale.com/util/mak" "tailscale.com/wgengine/filter" ) @@ -115,14 +116,11 @@ func (e *userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.Wra e.mu.Lock() defer e.mu.Unlock() - if e.pendOpen == nil { - e.pendOpen = make(map[flowtrack.Tuple]*pendingOpenFlow) - } if _, dup := e.pendOpen[flow]; dup { // Duplicates are expected when the OS retransmits. Ignore. return } - e.pendOpen[flow] = &pendingOpenFlow{timer: timer} + mak.Set(&e.pendOpen, flow, &pendingOpenFlow{timer: timer}) return filter.Accept }