diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 32d531b56..6cbe3c7ee 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -262,8 +262,8 @@ type Conn struct { // pconn4 and pconn6 are the underlying UDP sockets used to // send/receive packets for wireguard and other magicsock // protocols. - pconn4 *RebindingUDPConn - pconn6 *RebindingUDPConn + pconn4 RebindingUDPConn + pconn6 RebindingUDPConn // closeDisco4 and closeDisco6 are io.Closers to shut down the raw // disco packet receivers. If nil, no raw disco receiver is @@ -570,7 +570,7 @@ func NewConn(opts Options) (*Conn, error) { } c.linkMon = opts.LinkMonitor - if err := c.initialBind(); err != nil { + if err := c.rebind(keepCurrentPort); err != nil { return nil, err } @@ -578,15 +578,12 @@ func NewConn(opts Options) (*Conn, error) { c.donec = c.connCtx.Done() c.netChecker = &netcheck.Client{ Logf: logger.WithPrefix(c.logf, "netcheck: "), - GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 }, + GetSTUNConn4: func() netcheck.STUNConn { return &c.pconn4 }, + GetSTUNConn6: func() netcheck.STUNConn { return &c.pconn6 }, SkipExternalNetwork: inTest(), PortMapper: c.portMapper, } - if c.pconn6 != nil { - c.netChecker.GetSTUNConn6 = func() netcheck.STUNConn { return c.pconn6 } - } - c.ignoreSTUNPackets() if d4, err := c.listenRawDisco("ip4"); err == nil { @@ -1240,10 +1237,6 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) return false, nil } case addr.Addr().Is6(): - if c.pconn6 == nil { - // ignore IPv6 dest if we don't have an IPv6 address. - return false, nil - } _, err = c.pconn6.WriteToUDPAddrPort(b, addr) if err != nil && (c.noV6.Load() || neterror.TreatAsLostUDP(err)) { return false, nil @@ -2660,12 +2653,8 @@ func (c *connBind) Close() error { } c.closed = true // Unblock all outstanding receives. - if c.pconn4 != nil { - c.pconn4.Close() - } - if c.pconn6 != nil { - c.pconn6.Close() - } + c.pconn4.Close() + c.pconn6.Close() if c.closeDisco4 != nil { c.closeDisco4.Close() } @@ -2710,12 +2699,8 @@ func (c *Conn) Close() error { c.closeAllDerpLocked("conn-close") // Ignore errors from c.pconnN.Close. // They will frequently have been closed already by a call to connBind.Close. - if c.pconn6 != nil { - c.pconn6.Close() - } - if c.pconn4 != nil { - c.pconn4.Close() - } + c.pconn6.Close() + c.pconn4.Close() // Wait on goroutines updating right at the end, once everything is // already closed. We want everything else in the Conn to be @@ -2821,20 +2806,6 @@ func (c *Conn) ReSTUN(why string) { } } -func (c *Conn) initialBind() error { - if runtime.GOOS == "js" { - return nil - } - if err := c.bindSocket(&c.pconn4, "udp4", keepCurrentPort); err != nil { - return fmt.Errorf("magicsock: initialBind IPv4 failed: %w", err) - } - c.portMapper.SetLocalPort(c.LocalPort()) - if err := c.bindSocket(&c.pconn6, "udp6", keepCurrentPort); err != nil { - c.logf("magicsock: ignoring IPv6 bind failure: %v", err) - } - return nil -} - // listenPacket opens a packet listener. // The network must be "udp4" or "udp6". func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, error) { @@ -2852,12 +2823,7 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er // The caller is responsible for informing the portMapper of any changes. // If curPortFate is set to dropCurrentPort, no attempt is made to reuse // the current port. -func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate currentPortFate) error { - if *rucPtr == nil { - *rucPtr = new(RebindingUDPConn) - } - ruc := *rucPtr - +func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate currentPortFate) error { // Hold the ruc lock the entire time, so that the close+bind is atomic // from the perspective of ruc receive functions. ruc.mu.Lock() @@ -2930,13 +2896,13 @@ func (c *Conn) rebind(curPortFate currentPortFate) error { if runtime.GOOS == "js" { return nil } + if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil { + c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err) + } if err := c.bindSocket(&c.pconn4, "udp4", curPortFate); err != nil { return fmt.Errorf("magicsock: Rebind IPv4 failed: %w", err) } c.portMapper.SetLocalPort(c.LocalPort()) - if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil { - c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err) - } return nil }