wgengine/magicksock: remove nullability of RebindingUDPConns

Both RebindingUDPConns now always exist. the initial bind (which now
just calls rebind) now ensures that bind is called for both, such that
they both at least contain a blockForeverConn. Calling code no longer
needs to assert their state.

Signed-off-by: James Tucker <james@tailscale.com>
pull/5535/head
James Tucker 2 years ago committed by James Tucker
parent 56f6fe204b
commit 1f959edeb0

@ -262,8 +262,8 @@ type Conn struct {
// pconn4 and pconn6 are the underlying UDP sockets used to // pconn4 and pconn6 are the underlying UDP sockets used to
// send/receive packets for wireguard and other magicsock // send/receive packets for wireguard and other magicsock
// protocols. // protocols.
pconn4 *RebindingUDPConn pconn4 RebindingUDPConn
pconn6 *RebindingUDPConn pconn6 RebindingUDPConn
// closeDisco4 and closeDisco6 are io.Closers to shut down the raw // closeDisco4 and closeDisco6 are io.Closers to shut down the raw
// disco packet receivers. If nil, no raw disco receiver is // disco packet receivers. If nil, no raw disco receiver is
@ -570,7 +570,7 @@ func NewConn(opts Options) (*Conn, error) {
} }
c.linkMon = opts.LinkMonitor c.linkMon = opts.LinkMonitor
if err := c.initialBind(); err != nil { if err := c.rebind(keepCurrentPort); err != nil {
return nil, err return nil, err
} }
@ -578,15 +578,12 @@ func NewConn(opts Options) (*Conn, error) {
c.donec = c.connCtx.Done() c.donec = c.connCtx.Done()
c.netChecker = &netcheck.Client{ c.netChecker = &netcheck.Client{
Logf: logger.WithPrefix(c.logf, "netcheck: "), Logf: logger.WithPrefix(c.logf, "netcheck: "),
GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 }, GetSTUNConn4: func() netcheck.STUNConn { return &c.pconn4 },
GetSTUNConn6: func() netcheck.STUNConn { return &c.pconn6 },
SkipExternalNetwork: inTest(), SkipExternalNetwork: inTest(),
PortMapper: c.portMapper, PortMapper: c.portMapper,
} }
if c.pconn6 != nil {
c.netChecker.GetSTUNConn6 = func() netcheck.STUNConn { return c.pconn6 }
}
c.ignoreSTUNPackets() c.ignoreSTUNPackets()
if d4, err := c.listenRawDisco("ip4"); err == nil { if d4, err := c.listenRawDisco("ip4"); err == nil {
@ -1240,10 +1237,6 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error)
return false, nil return false, nil
} }
case addr.Addr().Is6(): case addr.Addr().Is6():
if c.pconn6 == nil {
// ignore IPv6 dest if we don't have an IPv6 address.
return false, nil
}
_, err = c.pconn6.WriteToUDPAddrPort(b, addr) _, err = c.pconn6.WriteToUDPAddrPort(b, addr)
if err != nil && (c.noV6.Load() || neterror.TreatAsLostUDP(err)) { if err != nil && (c.noV6.Load() || neterror.TreatAsLostUDP(err)) {
return false, nil return false, nil
@ -2660,12 +2653,8 @@ func (c *connBind) Close() error {
} }
c.closed = true c.closed = true
// Unblock all outstanding receives. // Unblock all outstanding receives.
if c.pconn4 != nil { c.pconn4.Close()
c.pconn4.Close() c.pconn6.Close()
}
if c.pconn6 != nil {
c.pconn6.Close()
}
if c.closeDisco4 != nil { if c.closeDisco4 != nil {
c.closeDisco4.Close() c.closeDisco4.Close()
} }
@ -2710,12 +2699,8 @@ func (c *Conn) Close() error {
c.closeAllDerpLocked("conn-close") c.closeAllDerpLocked("conn-close")
// Ignore errors from c.pconnN.Close. // Ignore errors from c.pconnN.Close.
// They will frequently have been closed already by a call to connBind.Close. // They will frequently have been closed already by a call to connBind.Close.
if c.pconn6 != nil { c.pconn6.Close()
c.pconn6.Close() c.pconn4.Close()
}
if c.pconn4 != nil {
c.pconn4.Close()
}
// Wait on goroutines updating right at the end, once everything is // Wait on goroutines updating right at the end, once everything is
// already closed. We want everything else in the Conn to be // already closed. We want everything else in the Conn to be
@ -2821,20 +2806,6 @@ func (c *Conn) ReSTUN(why string) {
} }
} }
func (c *Conn) initialBind() error {
if runtime.GOOS == "js" {
return nil
}
if err := c.bindSocket(&c.pconn4, "udp4", keepCurrentPort); err != nil {
return fmt.Errorf("magicsock: initialBind IPv4 failed: %w", err)
}
c.portMapper.SetLocalPort(c.LocalPort())
if err := c.bindSocket(&c.pconn6, "udp6", keepCurrentPort); err != nil {
c.logf("magicsock: ignoring IPv6 bind failure: %v", err)
}
return nil
}
// listenPacket opens a packet listener. // listenPacket opens a packet listener.
// The network must be "udp4" or "udp6". // The network must be "udp4" or "udp6".
func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, error) { func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, error) {
@ -2852,12 +2823,7 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er
// The caller is responsible for informing the portMapper of any changes. // The caller is responsible for informing the portMapper of any changes.
// If curPortFate is set to dropCurrentPort, no attempt is made to reuse // If curPortFate is set to dropCurrentPort, no attempt is made to reuse
// the current port. // the current port.
func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate currentPortFate) error { func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate currentPortFate) error {
if *rucPtr == nil {
*rucPtr = new(RebindingUDPConn)
}
ruc := *rucPtr
// Hold the ruc lock the entire time, so that the close+bind is atomic // Hold the ruc lock the entire time, so that the close+bind is atomic
// from the perspective of ruc receive functions. // from the perspective of ruc receive functions.
ruc.mu.Lock() ruc.mu.Lock()
@ -2930,13 +2896,13 @@ func (c *Conn) rebind(curPortFate currentPortFate) error {
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
return nil return nil
} }
if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil {
c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err)
}
if err := c.bindSocket(&c.pconn4, "udp4", curPortFate); err != nil { if err := c.bindSocket(&c.pconn4, "udp4", curPortFate); err != nil {
return fmt.Errorf("magicsock: Rebind IPv4 failed: %w", err) return fmt.Errorf("magicsock: Rebind IPv4 failed: %w", err)
} }
c.portMapper.SetLocalPort(c.LocalPort()) c.portMapper.SetLocalPort(c.LocalPort())
if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil {
c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err)
}
return nil return nil
} }

Loading…
Cancel
Save