From 1f959edeb0c6ad3a82fcdaa4ca65a02571493cc6 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Fri, 2 Sep 2022 14:24:51 -0700 Subject: [PATCH] wgengine/magicksock: remove nullability of RebindingUDPConns Both RebindingUDPConns now always exist. the initial bind (which now just calls rebind) now ensures that bind is called for both, such that they both at least contain a blockForeverConn. Calling code no longer needs to assert their state. Signed-off-by: James Tucker --- wgengine/magicsock/magicsock.go | 60 +++++++-------------------------- 1 file changed, 13 insertions(+), 47 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 32d531b56..6cbe3c7ee 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -262,8 +262,8 @@ type Conn struct { // pconn4 and pconn6 are the underlying UDP sockets used to // send/receive packets for wireguard and other magicsock // protocols. - pconn4 *RebindingUDPConn - pconn6 *RebindingUDPConn + pconn4 RebindingUDPConn + pconn6 RebindingUDPConn // closeDisco4 and closeDisco6 are io.Closers to shut down the raw // disco packet receivers. If nil, no raw disco receiver is @@ -570,7 +570,7 @@ func NewConn(opts Options) (*Conn, error) { } c.linkMon = opts.LinkMonitor - if err := c.initialBind(); err != nil { + if err := c.rebind(keepCurrentPort); err != nil { return nil, err } @@ -578,15 +578,12 @@ func NewConn(opts Options) (*Conn, error) { c.donec = c.connCtx.Done() c.netChecker = &netcheck.Client{ Logf: logger.WithPrefix(c.logf, "netcheck: "), - GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 }, + GetSTUNConn4: func() netcheck.STUNConn { return &c.pconn4 }, + GetSTUNConn6: func() netcheck.STUNConn { return &c.pconn6 }, SkipExternalNetwork: inTest(), PortMapper: c.portMapper, } - if c.pconn6 != nil { - c.netChecker.GetSTUNConn6 = func() netcheck.STUNConn { return c.pconn6 } - } - c.ignoreSTUNPackets() if d4, err := c.listenRawDisco("ip4"); err == nil { @@ -1240,10 +1237,6 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) return false, nil } case addr.Addr().Is6(): - if c.pconn6 == nil { - // ignore IPv6 dest if we don't have an IPv6 address. - return false, nil - } _, err = c.pconn6.WriteToUDPAddrPort(b, addr) if err != nil && (c.noV6.Load() || neterror.TreatAsLostUDP(err)) { return false, nil @@ -2660,12 +2653,8 @@ func (c *connBind) Close() error { } c.closed = true // Unblock all outstanding receives. - if c.pconn4 != nil { - c.pconn4.Close() - } - if c.pconn6 != nil { - c.pconn6.Close() - } + c.pconn4.Close() + c.pconn6.Close() if c.closeDisco4 != nil { c.closeDisco4.Close() } @@ -2710,12 +2699,8 @@ func (c *Conn) Close() error { c.closeAllDerpLocked("conn-close") // Ignore errors from c.pconnN.Close. // They will frequently have been closed already by a call to connBind.Close. - if c.pconn6 != nil { - c.pconn6.Close() - } - if c.pconn4 != nil { - c.pconn4.Close() - } + c.pconn6.Close() + c.pconn4.Close() // Wait on goroutines updating right at the end, once everything is // already closed. We want everything else in the Conn to be @@ -2821,20 +2806,6 @@ func (c *Conn) ReSTUN(why string) { } } -func (c *Conn) initialBind() error { - if runtime.GOOS == "js" { - return nil - } - if err := c.bindSocket(&c.pconn4, "udp4", keepCurrentPort); err != nil { - return fmt.Errorf("magicsock: initialBind IPv4 failed: %w", err) - } - c.portMapper.SetLocalPort(c.LocalPort()) - if err := c.bindSocket(&c.pconn6, "udp6", keepCurrentPort); err != nil { - c.logf("magicsock: ignoring IPv6 bind failure: %v", err) - } - return nil -} - // listenPacket opens a packet listener. // The network must be "udp4" or "udp6". func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, error) { @@ -2852,12 +2823,7 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er // The caller is responsible for informing the portMapper of any changes. // If curPortFate is set to dropCurrentPort, no attempt is made to reuse // the current port. -func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate currentPortFate) error { - if *rucPtr == nil { - *rucPtr = new(RebindingUDPConn) - } - ruc := *rucPtr - +func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate currentPortFate) error { // Hold the ruc lock the entire time, so that the close+bind is atomic // from the perspective of ruc receive functions. ruc.mu.Lock() @@ -2930,13 +2896,13 @@ func (c *Conn) rebind(curPortFate currentPortFate) error { if runtime.GOOS == "js" { return nil } + if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil { + c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err) + } if err := c.bindSocket(&c.pconn4, "udp4", curPortFate); err != nil { return fmt.Errorf("magicsock: Rebind IPv4 failed: %w", err) } c.portMapper.SetLocalPort(c.LocalPort()) - if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil { - c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err) - } return nil }