all: use set.Set consistently instead of map[T]struct{}

I didn't clean up the more idiomatic map[T]bool with true values, at
least yet.  I just converted the relatively awkward struct{}-valued
maps.

Updates #cleanup

Change-Id: I758abebd2bb1f64bc7a9d0f25c32298f4679c14f
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/9325/head
Brad Fitzpatrick 1 year ago committed by Brad Fitzpatrick
parent d506a55c8a
commit dc7aa98b76

@ -27,7 +27,7 @@ var (
sysErr = map[Subsystem]error{} // error key => err (or nil for no error) sysErr = map[Subsystem]error{} // error key => err (or nil for no error)
watchers = set.HandleSet[func(Subsystem, error)]{} // opt func to run if error state changes watchers = set.HandleSet[func(Subsystem, error)]{} // opt func to run if error state changes
warnables = map[*Warnable]struct{}{} // set of warnables warnables = set.Set[*Warnable]{}
timer *time.Timer timer *time.Timer
debugHandler = map[string]http.Handler{} debugHandler = map[string]http.Handler{}
@ -84,7 +84,7 @@ func NewWarnable(opts ...WarnableOpt) *Warnable {
} }
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
warnables[w] = struct{}{} warnables.Add(w)
return w return w
} }

@ -8,6 +8,8 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
"tailscale.com/util/set"
) )
func TestAppendWarnableDebugFlags(t *testing.T) { func TestAppendWarnableDebugFlags(t *testing.T) {
@ -35,5 +37,5 @@ func TestAppendWarnableDebugFlags(t *testing.T) {
func resetWarnables() { func resetWarnables() {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
warnables = make(map[*Warnable]struct{}) warnables = set.Set[*Warnable]{}
} }

@ -13,6 +13,7 @@ import (
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
"tailscale.com/util/set"
"tailscale.com/util/winutil" "tailscale.com/util/winutil"
) )
@ -158,14 +159,14 @@ func (db *nrptRuleDatabase) detectWriteAsGP() {
} }
// Add *all* rules from the GP subkey into a set. // Add *all* rules from the GP subkey into a set.
gpSubkeyMap := make(map[string]struct{}, len(gpSubkeyNames)) gpSubkeyMap := make(set.Set[string], len(gpSubkeyNames))
for _, gpSubkey := range gpSubkeyNames { for _, gpSubkey := range gpSubkeyNames {
gpSubkeyMap[strings.ToUpper(gpSubkey)] = struct{}{} gpSubkeyMap.Add(strings.ToUpper(gpSubkey))
} }
// Remove *our* rules from the set. // Remove *our* rules from the set.
for _, ourRuleID := range db.ruleIDs { for _, ourRuleID := range db.ruleIDs {
delete(gpSubkeyMap, strings.ToUpper(ourRuleID)) gpSubkeyMap.Delete(strings.ToUpper(ourRuleID))
} }
// Any leftover rules do not belong to us. When group policy is being used // Any leftover rules do not belong to us. When group policy is being used

@ -35,6 +35,7 @@ import (
"tailscale.com/types/views" "tailscale.com/types/views"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/wgengine/capture" "tailscale.com/wgengine/capture"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
@ -589,7 +590,7 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config) *natV4Config {
var ( var (
rt table.RoutingTableBuilder rt table.RoutingTableBuilder
dstMasqAddrs map[key.NodePublic]netip.Addr dstMasqAddrs map[key.NodePublic]netip.Addr
listenAddrs map[netip.Addr]struct{} listenAddrs set.Set[netip.Addr]
) )
// When using an exit node that requires masquerading, we need to // When using an exit node that requires masquerading, we need to

@ -14,6 +14,7 @@ import (
"github.com/fxamacker/cbor/v2" "github.com/fxamacker/cbor/v2"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"tailscale.com/types/tkatype" "tailscale.com/types/tkatype"
"tailscale.com/util/set"
) )
// AUMHash represents the BLAKE2s digest of an Authority Update Message (AUM). // AUMHash represents the BLAKE2s digest of an Authority Update Message (AUM).
@ -326,7 +327,7 @@ func (a *AUM) Weight(state State) uint {
// Despite the wire encoding being []byte, all KeyIDs are // Despite the wire encoding being []byte, all KeyIDs are
// 32 bytes. As such, we use that as the key for the map, // 32 bytes. As such, we use that as the key for the map,
// because map keys cannot be slices. // because map keys cannot be slices.
seenKeys := make(map[[32]byte]struct{}, 6) seenKeys := make(set.Set[[32]byte], 6)
for _, sig := range a.Signatures { for _, sig := range a.Signatures {
if len(sig.KeyID) != 32 { if len(sig.KeyID) != 32 {
panic("unexpected: keyIDs are 32 bytes") panic("unexpected: keyIDs are 32 bytes")
@ -344,12 +345,12 @@ func (a *AUM) Weight(state State) uint {
} }
panic(err) panic(err)
} }
if _, seen := seenKeys[keyID]; seen { if seenKeys.Contains(keyID) {
continue continue
} }
weight += key.Votes weight += key.Votes
seenKeys[keyID] = struct{}{} seenKeys.Add(keyID)
} }
return weight return weight

@ -14,6 +14,7 @@ import (
"github.com/fxamacker/cbor/v2" "github.com/fxamacker/cbor/v2"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/tkatype" "tailscale.com/types/tkatype"
"tailscale.com/util/set"
) )
// Strict settings for the CBOR decoder. // Strict settings for the CBOR decoder.
@ -260,13 +261,13 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error)
var ( var (
curs = topAUM curs = topAUM
state State state State
path = make(map[AUMHash]struct{}, 32) // 32 chosen arbitrarily. path = make(set.Set[AUMHash], 32) // 32 chosen arbitrarily.
) )
for i := 0; true; i++ { for i := 0; true; i++ {
if i > maxIter { if i > maxIter {
return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter) return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter)
} }
path[curs.Hash()] = struct{}{} path.Add(curs.Hash())
// Checkpoints encapsulate the state at that point, dope. // Checkpoints encapsulate the state at that point, dope.
if curs.MessageKind == AUMCheckpoint { if curs.MessageKind == AUMCheckpoint {
@ -307,7 +308,7 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error)
// such, we use a custom advancer here. // such, we use a custom advancer here.
advancer := func(state State, candidates []AUM) (next *AUM, out State, err error) { advancer := func(state State, candidates []AUM) (next *AUM, out State, err error) {
for _, c := range candidates { for _, c := range candidates {
if _, inPath := path[c.Hash()]; inPath { if path.Contains(c.Hash()) {
if state, err = state.applyVerifiedAUM(c); err != nil { if state, err = state.applyVerifiedAUM(c); err != nil {
return nil, State{}, fmt.Errorf("advancing state: %v", err) return nil, State{}, fmt.Errorf("advancing state: %v", err)
} }

@ -10,6 +10,9 @@ type Set[T comparable] map[T]struct{}
// Add adds e to the set. // Add adds e to the set.
func (s Set[T]) Add(e T) { s[e] = struct{}{} } func (s Set[T]) Add(e T) { s[e] = struct{}{} }
// Delete removes e from the set.
func (s Set[T]) Delete(e T) { delete(s, e) }
// Contains reports whether s contains e. // Contains reports whether s contains e.
func (s Set[T]) Contains(e T) bool { func (s Set[T]) Contains(e T) bool {
_, ok := s[e] _, ok := s[e]

@ -14,6 +14,7 @@ import (
"golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr" "golang.org/x/sys/windows/svc/mgr"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/set"
) )
// LogSvcState obtains the state of the Windows service named rootSvcName and // LogSvcState obtains the state of the Windows service named rootSvcName and
@ -78,7 +79,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error {
} }
}() }()
seen := make(map[string]struct{}) seen := set.Set[string]{}
for err == nil && len(deps) > 0 { for err == nil && len(deps) > 0 {
err = func() error { err = func() error {
@ -87,7 +88,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error {
deps = deps[:len(deps)-1] deps = deps[:len(deps)-1]
seen[curSvc.Name] = struct{}{} seen.Add(curSvc.Name)
curCfg, err := curSvc.Config() curCfg, err := curSvc.Config()
if err != nil { if err != nil {
@ -97,7 +98,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error {
callback(curSvc, curCfg) callback(curSvc, curCfg)
for _, depName := range curCfg.Dependencies { for _, depName := range curCfg.Dependencies {
if _, ok := seen[depName]; ok { if seen.Contains(depName) {
continue continue
} }

@ -54,6 +54,7 @@ import (
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/ringbuffer" "tailscale.com/util/ringbuffer"
"tailscale.com/util/set"
"tailscale.com/util/uniq" "tailscale.com/util/uniq"
"tailscale.com/wgengine/capture" "tailscale.com/wgengine/capture"
) )
@ -229,7 +230,7 @@ type Conn struct {
// WireGuard. These are not used to filter inbound or outbound // WireGuard. These are not used to filter inbound or outbound
// traffic at all, but only to track what state can be cleaned up // traffic at all, but only to track what state can be cleaned up
// in other maps below that are keyed by peer public key. // in other maps below that are keyed by peer public key.
peerSet map[key.NodePublic]struct{} peerSet set.Set[key.NodePublic]
// nodeOfDisco tracks the networkmap Node entity for each peer // nodeOfDisco tracks the networkmap Node entity for each peer
// discovery key. // discovery key.
@ -1708,7 +1709,7 @@ func (c *Conn) SetPrivateKey(privateKey key.NodePrivate) error {
// then removes any state for old peers. // then removes any state for old peers.
// //
// The caller passes ownership of newPeers map to UpdatePeers. // The caller passes ownership of newPeers map to UpdatePeers.
func (c *Conn) UpdatePeers(newPeers map[key.NodePublic]struct{}) { func (c *Conn) UpdatePeers(newPeers set.Set[key.NodePublic]) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -1718,7 +1719,7 @@ func (c *Conn) UpdatePeers(newPeers map[key.NodePublic]struct{}) {
// Clean up any key.NodePublic-keyed maps for peers that no longer // Clean up any key.NodePublic-keyed maps for peers that no longer
// exist. // exist.
for peer := range oldPeers { for peer := range oldPeers {
if _, ok := newPeers[peer]; !ok { if !newPeers.Contains(peer) {
delete(c.derpRoute, peer) delete(c.derpRoute, peer)
delete(c.peerLastDerp, peer) delete(c.peerLastDerp, peer)
} }

@ -58,6 +58,7 @@ import (
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
"tailscale.com/util/cibuild" "tailscale.com/util/cibuild"
"tailscale.com/util/racebuild" "tailscale.com/util/racebuild"
"tailscale.com/util/set"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
"tailscale.com/wgengine/wgcfg/nmcfg" "tailscale.com/wgengine/wgcfg/nmcfg"
@ -306,9 +307,9 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM
for i, m := range ms { for i, m := range ms {
nm := buildNetmapLocked(i) nm := buildNetmapLocked(i)
m.conn.SetNetworkMap(nm) m.conn.SetNetworkMap(nm)
peerSet := make(map[key.NodePublic]struct{}, len(nm.Peers)) peerSet := make(set.Set[key.NodePublic], len(nm.Peers))
for _, peer := range nm.Peers { for _, peer := range nm.Peers {
peerSet[peer.Key()] = struct{}{} peerSet.Add(peer.Key())
} }
m.conn.UpdatePeers(peerSet) m.conn.UpdatePeers(peerSet)
wg, err := nmcfg.WGCfg(nm, logf, netmap.AllowSingleHosts, "") wg, err := nmcfg.WGCfg(nm, logf, netmap.AllowSingleHosts, "")

@ -14,6 +14,7 @@ import (
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/set"
) )
// For now this router only supports the WireGuard userspace implementation. // For now this router only supports the WireGuard userspace implementation.
@ -26,7 +27,7 @@ type openbsdRouter struct {
tunname string tunname string
local4 netip.Prefix local4 netip.Prefix
local6 netip.Prefix local6 netip.Prefix
routes map[netip.Prefix]struct{} routes set.Set[netip.Prefix]
} }
func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor) (Router, error) { func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor) (Router, error) {
@ -173,9 +174,9 @@ func (r *openbsdRouter) Set(cfg *Config) error {
} }
} }
newRoutes := make(map[netip.Prefix]struct{}) newRoutes := set.Set[netip.Prefix]{}
for _, route := range cfg.Routes { for _, route := range cfg.Routes {
newRoutes[route] = struct{}{} newRoutes.Add(route)
} }
for route := range r.routes { for route := range r.routes {
if _, keep := newRoutes[route]; !keep { if _, keep := newRoutes[route]; !keep {

@ -44,6 +44,7 @@ import (
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/deephash" "tailscale.com/util/deephash"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/wgengine/capture" "tailscale.com/wgengine/capture"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/magicsock"
@ -782,12 +783,12 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config,
e.tundev.SetWGConfig(cfg) e.tundev.SetWGConfig(cfg)
e.lastDNSConfig = dnsCfg e.lastDNSConfig = dnsCfg
peerSet := make(map[key.NodePublic]struct{}, len(cfg.Peers)) peerSet := make(set.Set[key.NodePublic], len(cfg.Peers))
e.mu.Lock() e.mu.Lock()
e.peerSequence = e.peerSequence[:0] e.peerSequence = e.peerSequence[:0]
for _, p := range cfg.Peers { for _, p := range cfg.Peers {
e.peerSequence = append(e.peerSequence, p.PublicKey) e.peerSequence = append(e.peerSequence, p.PublicKey)
peerSet[p.PublicKey] = struct{}{} peerSet.Add(p.PublicKey)
} }
nm := e.netMap nm := e.netMap
e.mu.Unlock() e.mu.Unlock()

Loading…
Cancel
Save