diff --git a/cmd/derper/bootstrap_dns.go b/cmd/derper/bootstrap_dns.go index a93acb812..6c909bc36 100644 --- a/cmd/derper/bootstrap_dns.go +++ b/cmd/derper/bootstrap_dns.go @@ -12,11 +12,12 @@ import ( "net" "net/http" "strings" - "sync/atomic" "time" + + "tailscale.com/syncs" ) -var dnsCache atomic.Value // of []byte +var dnsCache syncs.AtomicValue[[]byte] var bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests") @@ -58,7 +59,7 @@ func refreshBootstrapDNS() { func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) { bootstrapDNSRequests.Add(1) w.Header().Set("Content-Type", "application/json") - j, _ := dnsCache.Load().([]byte) + j := dnsCache.Load() // Bootstrap DNS requests occur cross-regions, // and are randomized per request, // so keeping a connection open is pointlessly expensive. diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 4351e60a9..365336ca9 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -71,7 +71,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep 💣 tailscale.com/net/tshttpproxy from tailscale.com/derp/derphttp+ tailscale.com/paths from tailscale.com/cmd/tailscale/cli+ tailscale.com/safesocket from tailscale.com/cmd/tailscale/cli+ - tailscale.com/syncs from tailscale.com/net/netcheck + tailscale.com/syncs from tailscale.com/net/netcheck+ tailscale.com/tailcfg from tailscale.com/cmd/tailscale/cli+ tailscale.com/tka from tailscale.com/types/key W tailscale.com/tsconst from tailscale.com/net/interfaces diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 96d49f13f..2966ac227 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -23,7 +23,6 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" "go4.org/mem" @@ -41,6 +40,7 @@ import ( "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" "tailscale.com/net/tshttpproxy" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -939,8 +939,8 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool if resp.Debug.GoroutineDumpURL != "" { go dumpGoroutinesToURL(c.httpc, resp.Debug.GoroutineDumpURL) } - setControlAtomic(&controlUseDERPRoute, resp.Debug.DERPRoute) - setControlAtomic(&controlTrimWGConfig, resp.Debug.TrimWGConfig) + controlUseDERPRoute.Store(resp.Debug.DERPRoute) + controlTrimWGConfig.Store(resp.Debug.TrimWGConfig) if sleep := time.Duration(resp.Debug.SleepSeconds * float64(time.Second)); sleep > 0 { if err := sleepAsRequested(ctx, c.logf, timeoutReset, sleep); err != nil { return err @@ -1151,29 +1151,20 @@ var clockNow = time.Now // opt.Bool configs from control. var ( - controlUseDERPRoute atomic.Value // of opt.Bool - controlTrimWGConfig atomic.Value // of opt.Bool + controlUseDERPRoute syncs.AtomicValue[opt.Bool] + controlTrimWGConfig syncs.AtomicValue[opt.Bool] ) -func setControlAtomic(dst *atomic.Value, v opt.Bool) { - old, ok := dst.Load().(opt.Bool) - if !ok || old != v { - dst.Store(v) - } -} - // DERPRouteFlag reports the last reported value from control for whether // DERP route optimization (Issue 150) should be enabled. func DERPRouteFlag() opt.Bool { - v, _ := controlUseDERPRoute.Load().(opt.Bool) - return v + return controlUseDERPRoute.Load() } // TrimWGConfig reports the last reported value from control for whether // we should do lazy wireguard configuration. func TrimWGConfig() opt.Bool { - v, _ := controlTrimWGConfig.Load().(opt.Bool) - return v + return controlTrimWGConfig.Load() } // ipForwardingBroken reports whether the system's IP forwarding is disabled diff --git a/derp/derp_client.go b/derp/derp_client.go index f7282f646..9f14dca70 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -13,11 +13,11 @@ import ( "io" "net/netip" "sync" - "sync/atomic" "time" "go4.org/mem" "golang.org/x/time/rate" + "tailscale.com/syncs" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -39,8 +39,8 @@ type Client struct { rate *rate.Limiter // if non-nil, rate limiter to use // Owned by Recv: - peeked int // bytes to discard on next Recv - readErr atomic.Value // of error; sticky (set by Recv) + peeked int // bytes to discard on next Recv + readErr syncs.AtomicValue[error] // sticky (set by Recv) } // ClientOpt is an option passed to NewClient. @@ -445,7 +445,7 @@ func (c *Client) Recv() (m ReceivedMessage, err error) { } func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err error) { - readErr, _ := c.readErr.Load().(error) + readErr := c.readErr.Load() if readErr != nil { return nil, readErr } diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 28d990bc4..3b6cfd9bd 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -27,7 +27,6 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" "go4.org/mem" @@ -37,6 +36,7 @@ import ( "tailscale.com/net/netns" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -69,7 +69,7 @@ type Client struct { // by SetAddressFamilySelector. It's an atomic because it needs // to be accessed by multiple racing routines started while // Client.conn holds mu. - addrFamSelAtomic atomic.Value // of AddressFamilySelector + addrFamSelAtomic syncs.AtomicValue[AddressFamilySelector] mu sync.Mutex preferred bool diff --git a/hostinfo/hostinfo_windows.go b/hostinfo/hostinfo_windows.go index 023f92463..4b8ba40ef 100644 --- a/hostinfo/hostinfo_windows.go +++ b/hostinfo/hostinfo_windows.go @@ -8,10 +8,10 @@ import ( "fmt" "os" "path/filepath" - "sync/atomic" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" + "tailscale.com/syncs" "tailscale.com/util/winutil" ) @@ -20,10 +20,10 @@ func init() { packageType = packageTypeWindows } -var winVerCache atomic.Value // of string +var winVerCache syncs.AtomicValue[string] func osVersionWindows() string { - if s, ok := winVerCache.Load().(string); ok { + if s, ok := winVerCache.LoadOk(); ok { return s } major, minor, build := windows.RtlGetNtVersionNumbers() diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 7e8c5d30f..af085045f 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -40,6 +40,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/paths" "tailscale.com/portlist" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/types/dnstype" @@ -130,7 +131,7 @@ type LocalBackend struct { shutdownCalled bool // if Shutdown has been called filterAtomic atomic.Pointer[filter.Filter] - containsViaIPFuncAtomic atomic.Value // of func(netip.Addr) bool + containsViaIPFuncAtomic syncs.AtomicValue[func(netip.Addr) bool] // The mutex protects the following elements. mu sync.Mutex @@ -1500,17 +1501,17 @@ func (b *LocalBackend) tellClientToBrowseToURL(url string) { var panicOnMachineKeyGeneration = envknob.Bool("TS_DEBUG_PANIC_MACHINE_KEY") func (b *LocalBackend) createGetMachinePrivateKeyFunc() func() (key.MachinePrivate, error) { - var cache atomic.Value + var cache syncs.AtomicValue[key.MachinePrivate] return func() (key.MachinePrivate, error) { if panicOnMachineKeyGeneration { panic("machine key generated") } - if v, ok := cache.Load().(key.MachinePrivate); ok { + if v, ok := cache.LoadOk(); ok { return v, nil } b.mu.Lock() defer b.mu.Unlock() - if v, ok := cache.Load().(key.MachinePrivate); ok { + if v, ok := cache.LoadOk(); ok { return v, nil } if err := b.initMachineKeyLocked(); err != nil { @@ -1522,11 +1523,11 @@ func (b *LocalBackend) createGetMachinePrivateKeyFunc() func() (key.MachinePriva } func (b *LocalBackend) createGetNLPublicKeyFunc() func() (key.NLPublic, error) { - var cache atomic.Value + var cache syncs.AtomicValue[key.NLPublic] return func() (key.NLPublic, error) { b.mu.Lock() defer b.mu.Unlock() - if v, ok := cache.Load().(key.NLPublic); ok { + if v, ok := cache.LoadOk(); ok { return v, nil } @@ -2524,8 +2525,7 @@ func (b *LocalBackend) TailscaleVarRoot() string { } switch runtime.GOOS { case "ios", "android", "darwin": - dir, _ := paths.AppSharedDir.Load().(string) - return dir + return paths.AppSharedDir.Load() } return "" } @@ -3058,7 +3058,7 @@ func (b *LocalBackend) ShouldRunSSH() bool { return b.sshAtomicBool.Load() && ca // Tailscale ULA's v6 "via" range embedding an IPv4 address to be forwarded to // by Tailscale. func (b *LocalBackend) ShouldHandleViaIP(ip netip.Addr) bool { - if f, ok := b.containsViaIPFuncAtomic.Load().(func(netip.Addr) bool); ok { + if f, ok := b.containsViaIPFuncAtomic.LoadOk(); ok { return f(ip) } return false diff --git a/net/dns/nrpt_windows.go b/net/dns/nrpt_windows.go index 4d2b641ff..7ec29b344 100644 --- a/net/dns/nrpt_windows.go +++ b/net/dns/nrpt_windows.go @@ -63,8 +63,8 @@ const _RP_FORCE = 1 // Flag for RefreshPolicyEx type nrptRuleDatabase struct { logf logger.Logf watcher *gpNotificationWatcher - isGPRefreshPending atomic.Value // of bool - mu sync.Mutex // protects the fields below + isGPRefreshPending atomic.Bool + mu sync.Mutex // protects the fields below ruleIDs []string isGPDirty bool writeAsGP bool diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 91cc2f96c..05b285a58 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -21,7 +21,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" dns "golang.org/x/net/dns/dnsmessage" @@ -29,6 +28,7 @@ import ( "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" + "tailscale.com/syncs" "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" @@ -495,7 +495,7 @@ type resolvConfCache struct { // resolvConfCacheValue contains the most recent stat metadata and parsed // version of /etc/resolv.conf. -var resolvConfCacheValue atomic.Value // of resolvConfCache +var resolvConfCacheValue syncs.AtomicValue[resolvConfCache] var errEmptyResolvConf = errors.New("resolv.conf has no nameservers") @@ -510,7 +510,7 @@ func stubResolverForOS() (ip netip.Addr, err error) { mod: fi.ModTime(), size: fi.Size(), } - if c, ok := resolvConfCacheValue.Load().(resolvConfCache); ok && c.mod == cur.mod && c.size == cur.size { + if c, ok := resolvConfCacheValue.LoadOk(); ok && c.mod == cur.mod && c.size == cur.size { return c.ip, nil } conf, err := resolvconffile.ParseFile(resolvconffile.Path) diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index a24394bc8..c6e817819 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -15,7 +15,6 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "syscall" "time" @@ -43,8 +42,6 @@ type Dialer struct { // If nil, it's not used. NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error) - peerDialControlFuncAtomic atomic.Value // of func() func(network, address string, c syscall.RawConn) error - peerClientOnce sync.Once peerClient *http.Client diff --git a/net/tshttpproxy/tshttpproxy_windows.go b/net/tshttpproxy/tshttpproxy_windows.go index 651970ae9..ff04c5dae 100644 --- a/net/tshttpproxy/tshttpproxy_windows.go +++ b/net/tshttpproxy/tshttpproxy_windows.go @@ -14,7 +14,6 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "syscall" "time" "unsafe" @@ -22,6 +21,7 @@ import ( "github.com/alexbrainman/sspi/negotiate" "golang.org/x/sys/windows" "tailscale.com/hostinfo" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/cmpver" ) @@ -155,10 +155,10 @@ const win8dot1Ver = "6.3" // accessType is the flag we must pass to WinHttpOpen for proxy resolution // depending on whether or not we're running Windows < 8.1 -var accessType atomic.Value // of uint32 +var accessType syncs.AtomicValue[uint32] func getAccessFlag() uint32 { - if flag, ok := accessType.Load().(uint32); ok { + if flag, ok := accessType.LoadOk(); ok { return flag } var flag uint32 diff --git a/net/tstun/tap_linux.go b/net/tstun/tap_linux.go index e6b699aec..643e997eb 100644 --- a/net/tstun/tap_linux.go +++ b/net/tstun/tap_linux.go @@ -341,8 +341,7 @@ func run(prog string, args ...string) error { } func (t *Wrapper) destMAC() [6]byte { - mac, _ := t.destMACAtomic.Load().([6]byte) - return mac + return t.destMACAtomic.Load() } func (t *Wrapper) tapWrite(buf []byte, offset int) (int, error) { diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 0fd3e65ec..658c85eff 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -24,6 +24,7 @@ import ( "tailscale.com/disco" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/tstime/mono" "tailscale.com/types/ipproto" "tailscale.com/types/key" @@ -82,9 +83,9 @@ type Wrapper struct { // you might need to add a pad32.Four field here. lastActivityAtomic mono.Time // time of last send or receive - destIPActivity atomic.Value // of map[netip.Addr]func() - destMACAtomic atomic.Value // of [6]byte - discoKey atomic.Value // of key.DiscoPublic + destIPActivity syncs.AtomicValue[map[netip.Addr]func()] + destMACAtomic syncs.AtomicValue[[6]byte] + discoKey syncs.AtomicValue[key.DiscoPublic] // buffer stores the oldest unconsumed packet from tdev. // It is made a static buffer in order to avoid allocations. @@ -247,8 +248,8 @@ func (t *Wrapper) isSelfDisco(p *packet.Parsed) bool { return false } discoSrc := key.DiscoPublicFromRaw32(mem.B(discobs)) - selfDiscoPub, ok := t.discoKey.Load().(key.DiscoPublic) - return ok && selfDiscoPub == discoSrc + selfDiscoPub := t.discoKey.Load() + return selfDiscoPub == discoSrc } func (t *Wrapper) Close() error { @@ -543,7 +544,7 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { defer parsedPacketPool.Put(p) p.Decode(buf[offset : offset+n]) - if m, ok := t.destIPActivity.Load().(map[netip.Addr]func()); ok { + if m := t.destIPActivity.Load(); m != nil { if fn := m[p.Dst.Addr()]; fn != nil { fn() } diff --git a/paths/paths.go b/paths/paths.go index 183e97ff6..778330277 100644 --- a/paths/paths.go +++ b/paths/paths.go @@ -10,14 +10,14 @@ import ( "os" "path/filepath" "runtime" - "sync/atomic" + "tailscale.com/syncs" "tailscale.com/version/distro" ) // AppSharedDir is a string set by the iOS or Android app on start // containing a directory we can read/write in. -var AppSharedDir atomic.Value // of string +var AppSharedDir syncs.AtomicValue[string] // DefaultTailscaledSocket returns the path to the tailscaled Unix socket // or the empty string if there's no reasonable default. diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 787bef758..c0ef0f071 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -35,6 +35,7 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store" "tailscale.com/safesocket" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/integration/testcontrol" @@ -46,7 +47,7 @@ var ( verboseTailscale = flag.Bool("verbose-tailscale", false, "verbose tailscale CLI logging") ) -var mainError atomic.Value // of error +var mainError syncs.AtomicValue[error] func TestMain(m *testing.M) { // Have to disable UPnP which hits the network, otherwise it fails due to HTTP proxy. @@ -57,7 +58,7 @@ func TestMain(m *testing.M) { if v != 0 { os.Exit(v) } - if err, ok := mainError.Load().(error); ok { + if err := mainError.Load(); err != nil { fmt.Fprintf(os.Stderr, "FAIL: %v\n", err) os.Exit(1) } @@ -936,14 +937,11 @@ func (n *testNode) MustStatus() *ipnstate.Status { // HTTP traffic tries to leave localhost from tailscaled. We don't // expect any, so any request triggers a failure. type trafficTrap struct { - atomicErr atomic.Value // of error + atomicErr syncs.AtomicValue[error] } func (tt *trafficTrap) Err() error { - if err, ok := tt.atomicErr.Load().(error); ok { - return err - } - return nil + return tt.atomicErr.Load() } func (tt *trafficTrap) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/types/key/nl.go b/types/key/nl.go index ceaf244e2..e9092d437 100644 --- a/types/key/nl.go +++ b/types/key/nl.go @@ -107,3 +107,13 @@ func (k NLPublic) MarshalText() ([]byte, error) { func (k NLPublic) Verifier() ed25519.PublicKey { return ed25519.PublicKey(k.k[:]) } + +// IsZero reports whether k is the zero value. +func (k NLPublic) IsZero() bool { + return k.Equal(NLPublic{}) +} + +// Equal reports whether k and other are the same key. +func (k NLPublic) Equal(other NLPublic) bool { + return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 +} diff --git a/util/cloudenv/cloudenv.go b/util/cloudenv/cloudenv.go index da53db464..b6bf60979 100644 --- a/util/cloudenv/cloudenv.go +++ b/util/cloudenv/cloudenv.go @@ -14,8 +14,9 @@ import ( "os" "runtime" "strings" - "sync/atomic" "time" + + "tailscale.com/syncs" ) // CommonNonRoutableMetadataIP is the IP address of the metadata server @@ -69,15 +70,14 @@ func (c Cloud) HasInternalTLD() bool { return false } -var cloudAtomic atomic.Value // of Cloud +var cloudAtomic syncs.AtomicValue[Cloud] // Get returns the current cloud, or the empty string if unknown. func Get() Cloud { - c, ok := cloudAtomic.Load().(Cloud) - if ok { + if c, ok := cloudAtomic.LoadOk(); ok { return c } - c = getCloud() + c := getCloud() cloudAtomic.Store(c) // even if empty return c } diff --git a/version/distro/distro.go b/version/distro/distro.go index df97327a4..3b1d36378 100644 --- a/version/distro/distro.go +++ b/version/distro/distro.go @@ -9,7 +9,8 @@ import ( "os" "runtime" "strconv" - "sync/atomic" + + "tailscale.com/syncs" ) type Distro string @@ -28,14 +29,14 @@ const ( WDMyCloud = Distro("wdmycloud") ) -var distroAtomic atomic.Value // of Distro +var distroAtomic syncs.AtomicValue[Distro] // Get returns the current distro, or the empty string if unknown. func Get() Distro { - d, ok := distroAtomic.Load().(Distro) - if ok { + if d, ok := distroAtomic.LoadOk(); ok { return d } + var d Distro switch runtime.GOOS { case "linux": d = linuxDistro() diff --git a/version/prop.go b/version/prop.go index fef10c4ff..68817c894 100644 --- a/version/prop.go +++ b/version/prop.go @@ -9,7 +9,8 @@ import ( "path/filepath" "runtime" "strings" - "sync/atomic" + + "tailscale.com/syncs" ) // IsMobile reports whether this is a mobile client build. @@ -43,7 +44,7 @@ func IsSandboxedMacOS() bool { return strings.HasSuffix(exe, "/Contents/MacOS/Tailscale") } -var isMacSysExt atomic.Value // of bool +var isMacSysExt syncs.AtomicValue[bool] // IsMacSysExt whether this binary is from the standalone "System // Extension" (a.k.a. "macsys") version of Tailscale for macOS. @@ -51,7 +52,7 @@ func IsMacSysExt() bool { if runtime.GOOS != "darwin" { return false } - if b, ok := isMacSysExt.Load().(bool); ok { + if b, ok := isMacSysExt.LoadOk(); ok { return b } exe, err := os.Executable() diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 1f25e8ac1..b7056fede 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -264,7 +264,7 @@ type Conn struct { // stunReceiveFunc holds the current STUN packet processing func. // Its Loaded value is always non-nil. - stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr) + stunReceiveFunc syncs.AtomicValue[func(p []byte, fromAddr netip.AddrPort)] // derpRecvCh is used by receiveDERP to read DERP messages. // It must have buffer size > 0; see issue 3736. @@ -300,7 +300,7 @@ type Conn struct { // havePrivateKey is whether privateKey is non-zero. havePrivateKey atomic.Bool - publicKeyAtomic atomic.Value // of key.NodePublic (or NodeKey zero value if !havePrivateKey) + publicKeyAtomic syncs.AtomicValue[key.NodePublic] // or NodeKey zero value if !havePrivateKey // derpMapAtomic is the same as derpMap, but without requiring // sync.Mutex. For use with NewRegionClient's callback, to avoid @@ -1668,7 +1668,7 @@ func (c *Conn) receiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { // caller). func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) (ep *endpoint, ok bool) { if stun.Is(b) { - c.stunReceiveFunc.Load().(func([]byte, netip.AddrPort))(b, ipp) + c.stunReceiveFunc.Load()(b, ipp) return nil, false } if c.handleDiscoMessage(b, ipp, key.NodePublic{}) { @@ -2979,10 +2979,8 @@ type RebindingUDPConn struct { // check pconn (after acquiring mu) to see if there's been a rebind // meanwhile. // pconn isn't really needed, but makes some of the code simpler - // to keep it in a type safe form. TODO(bradfitz): really we should make a generic - // atomic.Value. Unfortunately Go 1.19's atomic.Pointer[T] is only for pointers, - // not interfaces. - pconnAtomic atomic.Value // of nettype.PacketConn + // to keep it in a type safe form. + pconnAtomic syncs.AtomicValue[nettype.PacketConn] mu sync.Mutex // held while changing pconn (and pconnAtomic) pconn nettype.PacketConn @@ -3004,7 +3002,7 @@ func (c *RebindingUDPConn) currentConn() nettype.PacketConn { // It returns the number of bytes copied and the source address. func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { for { - pconn := c.pconnAtomic.Load().(nettype.PacketConn) + pconn := c.pconnAtomic.Load() n, addr, err := pconn.ReadFrom(b) if err != nil && pconn != c.currentConn() { continue @@ -3022,7 +3020,7 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { // when c's underlying connection is a net.UDPConn. func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) { for { - pconn := c.pconnAtomic.Load().(nettype.PacketConn) + pconn := c.pconnAtomic.Load() // Optimization: Treat *net.UDPConn specially. // This lets us avoid allocations by calling ReadFromUDPAddrPort. @@ -3081,7 +3079,7 @@ func (c *RebindingUDPConn) closeLocked() error { func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { for { - pconn := c.pconnAtomic.Load().(nettype.PacketConn) + pconn := c.pconnAtomic.Load() n, err := pconn.WriteTo(b, addr) if err != nil { @@ -3095,7 +3093,7 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { for { - pconn := c.pconnAtomic.Load().(nettype.PacketConn) + pconn := c.pconnAtomic.Load() n, err := pconn.WriteToUDPAddrPort(b, addr) if err != nil { @@ -3643,10 +3641,9 @@ func (de *endpoint) removeSentPingLocked(txid stun.TxID, sp sentPing) { // The caller should use de.discoKey as the discoKey argument. // It is passed in so that sendDiscoPing doesn't need to lock de.mu. func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, txid stun.TxID, logLevel discoLogLevel) { - selfPubKey, _ := de.c.publicKeyAtomic.Load().(key.NodePublic) sent, _ := de.c.sendDiscoMessage(ep, de.publicKey, discoKey, &disco.Ping{ TxID: [12]byte(txid), - NodeKey: selfPubKey, + NodeKey: de.c.publicKeyAtomic.Load(), }, logLevel) if !sent { de.forgetPing(txid) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 46e23561f..418840ea9 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -117,7 +117,7 @@ type Impl struct { // is a local (non-subnet) Tailscale IP address of this // machine. It's always a non-nil func. It's changed on netmap // updates. - atomicIsLocalIPFunc atomic.Value // of func(netip.Addr) bool + atomicIsLocalIPFunc syncs.AtomicValue[func(netip.Addr) bool] mu sync.Mutex // connsOpenBySubnetIP keeps track of number of connections open @@ -513,7 +513,7 @@ func (ns *Impl) inject() { // isLocalIP reports whether ip is a Tailscale IP assigned to this // node directly (but not a subnet-routed IP). func (ns *Impl) isLocalIP(ip netip.Addr) bool { - return ns.atomicIsLocalIPFunc.Load().(func(netip.Addr) bool)(ip) + return ns.atomicIsLocalIPFunc.Load()(ip) } func (ns *Impl) processSSH() bool { diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 949806c7b..2ec61c359 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -17,7 +17,6 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" "go4.org/mem" @@ -36,6 +35,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/net/tstun" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime/mono" "tailscale.com/types/dnstype" @@ -107,11 +107,11 @@ type userspaceEngine struct { // isLocalAddr reports the whether an IP is assigned to the local // tunnel interface. It's used to reflect local packets // incorrectly sent to us. - isLocalAddr atomic.Value // of func(netip.Addr)bool + isLocalAddr syncs.AtomicValue[func(netip.Addr) bool] // isDNSIPOverTailscale reports the whether a DNS resolver's IP // is being routed over Tailscale. - isDNSIPOverTailscale atomic.Value // of func(netip.Addr)bool + isDNSIPOverTailscale syncs.AtomicValue[func(netip.Addr) bool] wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below lastCfgFull wgcfg.Config @@ -497,7 +497,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) } if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { - isLocalAddr, ok := e.isLocalAddr.Load().(func(netip.Addr) bool) + isLocalAddr, ok := e.isLocalAddr.LoadOk() if !ok { e.logf("[unexpected] e.isLocalAddr was nil, can't check for loopback packet") } else if isLocalAddr(p.Dst.Addr()) { @@ -1621,7 +1621,7 @@ type fwdDNSLinkSelector struct { } func (ls fwdDNSLinkSelector) PickLink(ip netip.Addr) (linkName string) { - if ls.ue.isDNSIPOverTailscale.Load().(func(netip.Addr) bool)(ip) { + if ls.ue.isDNSIPOverTailscale.Load()(ip) { return ls.tunName } return "" diff --git a/wgengine/wglog/wglog.go b/wgengine/wglog/wglog.go index e78216e2b..779d1df15 100644 --- a/wgengine/wglog/wglog.go +++ b/wgengine/wglog/wglog.go @@ -9,9 +9,9 @@ import ( "fmt" "strings" "sync" - "sync/atomic" "golang.zx2c4.com/wireguard/device" + "tailscale.com/syncs" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/wgengine/wgcfg" @@ -21,7 +21,7 @@ import ( // It can be modified at run time to adjust to new wireguard-go configurations. type Logger struct { DeviceLogger *device.Logger - replace atomic.Value // of map[string]string + replace syncs.AtomicValue[map[string]string] mu sync.Mutex // protects strs strs map[key.NodePublic]*strCache // cached strs used to populate replace } @@ -52,7 +52,7 @@ func NewLogger(logf logger.Logf) *Logger { // See https://github.com/tailscale/tailscale/issues/1388. return } - replace, _ := ret.replace.Load().(map[string]string) + replace := ret.replace.Load() if replace == nil { // No replacements specified; log as originally planned. logf(format, args...)