util/mak: move tailssh's mapSet into a new package for reuse elsewhere

Change-Id: Idfe95db82275fd2be6ca88f245830731a0d5aecf
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4501/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent c2eff20008
commit 910ae68e0b

@ -264,6 +264,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
LW tailscale.com/util/endian from tailscale.com/net/dns+ LW tailscale.com/util/endian from tailscale.com/net/dns+
tailscale.com/util/groupmember from tailscale.com/ipn/ipnserver tailscale.com/util/groupmember from tailscale.com/ipn/ipnserver
tailscale.com/util/lineread from tailscale.com/hostinfo+ tailscale.com/util/lineread from tailscale.com/hostinfo+
tailscale.com/util/mak from tailscale.com/control/controlclient+
tailscale.com/util/multierr from tailscale.com/cmd/tailscaled+ tailscale.com/util/multierr from tailscale.com/cmd/tailscaled+
tailscale.com/util/netconv from tailscale.com/wgengine/magicsock tailscale.com/util/netconv from tailscale.com/wgengine/magicsock
tailscale.com/util/osshare from tailscale.com/cmd/tailscaled+ tailscale.com/util/osshare from tailscale.com/cmd/tailscaled+

@ -20,6 +20,7 @@ import (
"tailscale.com/control/controlhttp" "tailscale.com/control/controlhttp"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/util/mak"
"tailscale.com/util/multierr" "tailscale.com/util/multierr"
) )
@ -137,9 +138,6 @@ func (nc *noiseClient) Close() error {
func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
nc.mu.Lock() nc.mu.Lock()
connID := nc.nextID connID := nc.nextID
if nc.connPool == nil {
nc.connPool = make(map[int]*noiseConn)
}
nc.nextID++ nc.nextID++
nc.mu.Unlock() nc.mu.Unlock()
@ -161,6 +159,6 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
nc.mu.Lock() nc.mu.Lock()
defer nc.mu.Unlock() defer nc.mu.Unlock()
ncc := &noiseConn{Conn: conn, id: connID, pool: nc} ncc := &noiseConn{Conn: conn, id: connID, pool: nc}
nc.connPool[ncc.id] = ncc mak.Set(&nc.connPool, ncc.id, ncc)
return ncc, nil return ncc, nil
} }

@ -21,6 +21,7 @@ import (
"tailscale.com/ipn/store/mem" "tailscale.com/ipn/store/mem"
"tailscale.com/paths" "tailscale.com/paths"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/mak"
) )
// Provider returns a StateStore for the provided path. // Provider returns a StateStore for the provided path.
@ -82,10 +83,7 @@ func Register(prefix string, fn Provider) {
if _, ok := knownStores[prefix]; ok { if _, ok := knownStores[prefix]; ok {
panic(fmt.Sprintf("%q already registered", prefix)) panic(fmt.Sprintf("%q already registered", prefix))
} }
if knownStores == nil { mak.Set(&knownStores, prefix, fn)
knownStores = make(map[string]Provider)
}
knownStores[prefix] = fn
} }
// TryWindowsAppDataMigration attempts to copy the Windows state file // TryWindowsAppDataMigration attempts to copy the Windows state file

@ -40,6 +40,7 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/tempfork/gliderlabs/ssh"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/mak"
) )
var ( var (
@ -471,7 +472,7 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
mapSet(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{ mak.Set(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{
at: srv.now(), at: srv.now(),
lines: lines, lines: lines,
etag: etag, etag: etag,
@ -731,8 +732,8 @@ func (srv *server) startSession(ss *sshSession) {
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup { if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
panic("dup sharedID") panic("dup sharedID")
} }
mapSet(&srv.activeSessionByH, ss.idH, ss) mak.Set(&srv.activeSessionByH, ss.idH, ss)
mapSet(&srv.activeSessionBySharedID, ss.sharedID, ss) mak.Set(&srv.activeSessionBySharedID, ss.sharedID, ss)
} }
// endSession unregisters s from the list of active sessions. // endSession unregisters s from the list of active sessions.
@ -1248,11 +1249,3 @@ func envEq(a, b string) bool {
} }
return a == b return a == b
} }
// mapSet assigns m[k] = v, making m if necessary.
func mapSet[K comparable, V any](m *map[K]V, k K, v V) {
if *m == nil {
*m = make(map[K]V)
}
(*m)[k] = v
}

@ -0,0 +1,53 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package mak helps make maps. It contains generic helpers to make/assign
// things, notably to maps, but also slices.
package mak
import (
"fmt"
"reflect"
)
// Set populates an entry in a map, making the map if necessary.
//
// That is, it assigns (*m)[k] = v, making *m if it was nil.
func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) {
if *m == nil {
*m = make(map[K]V)
}
(*m)[k] = v
}
// NonNil takes a pointer to a Go data structure
// (currently only a slice or a map) and makes sure it's non-nil for
// JSON serialization. (In particular, JavaScript clients usually want
// the field to be defined after they decode the JSON.)
// MakeNonNil takes a pointer to a Go data structure
// (currently only a slice or a map) and makes sure it's non-nil for
// JSON serialization. (In particular, JavaScript clients usually want
// the field to be defined after they decode the JSON.)
func NonNil(ptr interface{}) {
if ptr == nil {
panic("nil interface")
}
rv := reflect.ValueOf(ptr)
if rv.Kind() != reflect.Ptr {
panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind()))
}
if rv.Pointer() == 0 {
panic("nil pointer")
}
rv = rv.Elem()
if rv.Pointer() != 0 {
return
}
switch rv.Type().Kind() {
case reflect.Slice:
rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
case reflect.Map:
rv.Set(reflect.MakeMap(rv.Type()))
}
}

@ -0,0 +1,71 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package mak contains code to help make things.
package mak
import (
"reflect"
"testing"
)
type M map[string]int
func TestSet(t *testing.T) {
t.Run("unnamed", func(t *testing.T) {
var m map[string]int
Set(&m, "foo", 42)
Set(&m, "bar", 1)
Set(&m, "bar", 2)
want := map[string]int{
"foo": 42,
"bar": 2,
}
if got := m; !reflect.DeepEqual(got, want) {
t.Errorf("got %v; want %v", got, want)
}
})
t.Run("named", func(t *testing.T) {
var m M
Set(&m, "foo", 1)
Set(&m, "bar", 1)
Set(&m, "bar", 2)
want := M{
"foo": 1,
"bar": 2,
}
if got := m; !reflect.DeepEqual(got, want) {
t.Errorf("got %v; want %v", got, want)
}
})
}
func TestNonNil(t *testing.T) {
var s []string
NonNil(&s)
if len(s) != 0 {
t.Errorf("slice len = %d; want 0", len(s))
}
if s == nil {
t.Error("slice still nil")
}
s = append(s, "foo")
NonNil(&s)
if len(s) != 1 {
t.Errorf("len = %d; want 1", len(s))
}
if s[0] != "foo" {
t.Errorf("value = %q; want foo", s)
}
var m map[string]string
NonNil(&m)
if len(m) != 0 {
t.Errorf("map len = %d; want 0", len(s))
}
if m == nil {
t.Error("map still nil")
}
}

@ -55,6 +55,7 @@ import (
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/netconv" "tailscale.com/util/netconv"
"tailscale.com/util/mak"
"tailscale.com/util/uniq" "tailscale.com/util/uniq"
"tailscale.com/version" "tailscale.com/version"
"tailscale.com/wgengine/monitor" "tailscale.com/wgengine/monitor"
@ -438,11 +439,7 @@ func (c *Conn) removeDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp
func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.derpRoute == nil { mak.Set(&c.derpRoute, peer, derpRoute{derpID, dc})
c.derpRoute = make(map[key.NodePublic]derpRoute)
}
r := derpRoute{derpID, dc}
c.derpRoute[peer] = r
} }
// DerpMagicIP is a fake WireGuard endpoint IP address that means // DerpMagicIP is a fake WireGuard endpoint IP address that means
@ -1050,7 +1047,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
}, nil }, nil
} }
already := make(map[netaddr.IPPort]tailcfg.EndpointType) // endpoint -> how it was found var already map[netaddr.IPPort]tailcfg.EndpointType // endpoint -> how it was found
var eps []tailcfg.Endpoint // unique endpoints var eps []tailcfg.Endpoint // unique endpoints
ipp := func(s string) (ipp netaddr.IPPort) { ipp := func(s string) (ipp netaddr.IPPort) {
@ -1062,7 +1059,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
return return
} }
if _, ok := already[ipp]; !ok { if _, ok := already[ipp]; !ok {
already[ipp] = et mak.Set(&already, ipp, et)
eps = append(eps, tailcfg.Endpoint{Addr: ipp, Type: et}) eps = append(eps, tailcfg.Endpoint{Addr: ipp, Type: et})
} }
} }
@ -3957,9 +3954,6 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) {
for ep := range de.isCallMeMaybeEP { for ep := range de.isCallMeMaybeEP {
de.isCallMeMaybeEP[ep] = false // mark for deletion de.isCallMeMaybeEP[ep] = false // mark for deletion
} }
if de.isCallMeMaybeEP == nil {
de.isCallMeMaybeEP = map[netaddr.IPPort]bool{}
}
var newEPs []netaddr.IPPort var newEPs []netaddr.IPPort
for _, ep := range m.MyNumber { for _, ep := range m.MyNumber {
if ep.IP().Is6() && ep.IP().IsLinkLocalUnicast() { if ep.IP().Is6() && ep.IP().IsLinkLocalUnicast() {
@ -3968,7 +3962,7 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) {
// for these. // for these.
continue continue
} }
de.isCallMeMaybeEP[ep] = true mak.Set(&de.isCallMeMaybeEP, ep, true)
if es, ok := de.endpointState[ep]; ok { if es, ok := de.endpointState[ep]; ok {
es.callMeMaybeTime = now es.callMeMaybeTime = now
} else { } else {

@ -15,6 +15,7 @@ import (
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/net/tstun" "tailscale.com/net/tstun"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/util/mak"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
) )
@ -115,14 +116,11 @@ func (e *userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.Wra
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()
if e.pendOpen == nil {
e.pendOpen = make(map[flowtrack.Tuple]*pendingOpenFlow)
}
if _, dup := e.pendOpen[flow]; dup { if _, dup := e.pendOpen[flow]; dup {
// Duplicates are expected when the OS retransmits. Ignore. // Duplicates are expected when the OS retransmits. Ignore.
return return
} }
e.pendOpen[flow] = &pendingOpenFlow{timer: timer} mak.Set(&e.pendOpen, flow, &pendingOpenFlow{timer: timer})
return filter.Accept return filter.Accept
} }

Loading…
Cancel
Save