diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index 5dfdb00fb..a4ddd33b7 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -2826,7 +2826,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
 
 	if debugAlwaysDERP {
 		c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network)
-		ruc.pconn = newBlockForeverConn()
+		ruc.setConnLocked(newBlockForeverConn())
 		return nil
 	}
 
@@ -2860,7 +2860,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
 			continue
 		}
 		// Success.
-		ruc.pconn = pconn
+		ruc.setConnLocked(pconn)
 		if network == "udp4" {
 			health.SetUDP4Unbound(false)
 		}
@@ -2871,7 +2871,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
 	// Set pconn to a dummy conn whose reads block until closed.
 	// This keeps the receive funcs alive for a future in which
 	// we get a link change and we can try binding again.
-	ruc.pconn = newBlockForeverConn()
+	ruc.setConnLocked(newBlockForeverConn())
 	if network == "udp4" {
 		health.SetUDP4Unbound(true)
 	}
@@ -2974,11 +2974,48 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
 // RebindingUDPConn is a UDP socket that can be re-bound.
 // Unix has no notion of re-binding a socket, so we swap it out for a new one.
 type RebindingUDPConn struct {
-	mu    sync.Mutex
+	// pconnAtomic holds the same conn as pconn, but can be loaded without
+	// acquiring mu. It's used for reads/writes; only upon failure do the
+	// reads/writes then check pconn (after acquiring mu) to see if there's
+	// been a rebind meanwhile.
+	//
+	// The conn is stored wrapped in a pconnBox because atomic.Value panics
+	// if Store is called with values of differing concrete types, and a
+	// rebind can swap between (for example) a *net.UDPConn and the dummy
+	// conn returned by newBlockForeverConn.
+	//
+	// pconn isn't really needed, but keeping it in a type safe form makes
+	// some of the code simpler. TODO(bradfitz): really we should make a
+	// generic atomic.Value. Unfortunately Go 1.19's atomic.Pointer[T] is
+	// only for pointers, not interfaces.
+	pconnAtomic atomic.Value // of pconnBox
+
+	mu    sync.Mutex // held while changing pconn (and pconnAtomic)
 	pconn nettype.PacketConn
 }
 
-// currentConn returns c's current pconn.
+// pconnBox wraps a nettype.PacketConn in a single concrete type so that
+// conns of different underlying types can be stored in the same
+// atomic.Value.
+type pconnBox struct {
+	pc nettype.PacketConn
+}
+
+// setConnLocked sets c's current conn to p, updating both pconn and
+// pconnAtomic. c.mu must be held.
+func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn) {
+	c.pconn = p
+	c.pconnAtomic.Store(pconnBox{pc: p})
+}
+
+// currentConnAtomic returns c's current pconn without acquiring c.mu.
+// It returns nil if no conn has been set yet.
+func (c *RebindingUDPConn) currentConnAtomic() nettype.PacketConn {
+	box, _ := c.pconnAtomic.Load().(pconnBox)
+	return box.pc
+}
+
+// currentConn returns c's current pconn, acquiring c.mu in the process.
 func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -2989,7 +3026,7 @@ func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
 // It returns the number of bytes copied and the source address.
 func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
 	for {
-		pconn := c.currentConn()
+		pconn := c.currentConnAtomic()
 		n, addr, err := pconn.ReadFrom(b)
 		if err != nil && pconn != c.currentConn() {
 			continue
@@ -3007,7 +3044,7 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
 // when c's underlying connection is a net.UDPConn.
 func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) {
 	for {
-		pconn := c.currentConn()
+		pconn := c.currentConnAtomic()
 
 		// Optimization: Treat *net.UDPConn specially.
 		// This lets us avoid allocations by calling ReadFromUDPAddrPort.
@@ -3066,17 +3103,11 @@ func (c *RebindingUDPConn) closeLocked() error {
 
 func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
 	for {
-		c.mu.Lock()
-		pconn := c.pconn
-		c.mu.Unlock()
+		pconn := c.currentConnAtomic()
 
 		n, err := pconn.WriteTo(b, addr)
 		if err != nil {
-			c.mu.Lock()
-			pconn2 := c.pconn
-			c.mu.Unlock()
-
-			if pconn != pconn2 {
+			if pconn != c.currentConn() {
 				continue
 			}
 		}
@@ -3086,17 +3117,11 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
 
 func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
 	for {
-		c.mu.Lock()
-		pconn := c.pconn
-		c.mu.Unlock()
+		pconn := c.currentConnAtomic()
 
 		n, err := pconn.WriteToUDPAddrPort(b, addr)
 		if err != nil {
-			c.mu.Lock()
-			pconn2 := c.pconn
-			c.mu.Unlock()
-
-			if pconn != pconn2 {
+			if pconn != c.currentConn() {
 				continue
 			}
 		}