diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index fc3fea48f..17dd73618 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2630,28 +2630,38 @@ func (c *Conn) Rebind() { host = "127.0.0.1" } listenCtx := context.Background() // unused without DNS name to resolve + if c.port != 0 { c.pconn4.mu.Lock() + oldPort := c.pconn4.localAddrLocked().Port if err := c.pconn4.pconn.Close(); err != nil { c.logf("magicsock: link change close failed: %v", err) } - packetConn, err := c.listenPacket(listenCtx, "udp4", fmt.Sprintf("%s:%d", host, c.port)) - if err == nil { + packetConn, err := c.listenPacket(listenCtx, "udp4", net.JoinHostPort(host, fmt.Sprint(c.port))) + if err != nil { + c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err) + packetConn, err = c.listenPacket(listenCtx, "udp4", net.JoinHostPort(host, "0")) + if err != nil { + c.logf("magicsock: link change failed to bind random port: %v", err) + c.pconn4.mu.Unlock() + return + } + newPort := c.pconn4.localAddrLocked().Port + c.logf("magicsock: link change rebound port: from %v to %v (failed to get %v)", oldPort, newPort, c.port) + } else { c.logf("magicsock: link change rebound port: %d", c.port) - c.pconn4.pconn = packetConn.(*net.UDPConn) - c.pconn4.mu.Unlock() - return } - c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err) + c.pconn4.pconn = packetConn.(*net.UDPConn) c.pconn4.mu.Unlock() + } else { + c.logf("magicsock: link change, binding new port") + packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0") + if err != nil { + c.logf("magicsock: link change failed to bind new port: %v", err) + return + } + c.pconn4.Reset(packetConn.(*net.UDPConn)) } - c.logf("magicsock: link change, binding new port") - packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0") - if err != nil { - c.logf("magicsock: link change failed to bind new port: %v", err) - return - } - c.pconn4.Reset(packetConn.(*net.UDPConn)) c.portMapper.SetLocalPort(c.LocalPort()) c.mu.Lock() @@ -2833,6 +2843,10 @@ func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort, func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr { c.mu.Lock() defer c.mu.Unlock() + return c.localAddrLocked() +} + +func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr { return c.pconn.LocalAddr().(*net.UDPAddr) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 914533a73..8e64a2696 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1662,6 +1662,15 @@ func BenchmarkReceiveFrom_Native(b *testing.B) { } } +func logBufWriter(buf *bytes.Buffer) logger.Logf { + return func(format string, a ...interface{}) { + fmt.Fprintf(buf, format, a...) + if !bytes.HasSuffix(buf.Bytes(), []byte("\n")) { + buf.WriteByte('\n') + } + } +} + // Test that a netmap update where node changes its node key but // doesn't change its disco key doesn't result in a broken state. // @@ -1670,12 +1679,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { conn := newNonLegacyTestConn(t) t.Cleanup(func() { conn.Close() }) var logBuf bytes.Buffer - conn.logf = func(format string, a ...interface{}) { - fmt.Fprintf(&logBuf, format, a...) - if !bytes.HasSuffix(logBuf.Bytes(), []byte("\n")) { - logBuf.WriteByte('\n') - } - } + conn.logf = logBufWriter(&logBuf) conn.SetPrivateKey(wgkey.Private{0: 1}) @@ -1729,3 +1733,63 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { t.Logf("log output: %s", log) } } + +func TestRebindStress(t *testing.T) { + conn := newNonLegacyTestConn(t) + + var logBuf bytes.Buffer + conn.logf = logBufWriter(&logBuf) + + closed := false + t.Cleanup(func() { + if !closed { + conn.Close() + } + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errc := make(chan error, 1) + go func() { + buf := make([]byte, 1500) + for { + _, _, err := conn.ReceiveIPv4(buf) + if ctx.Err() != nil { + errc <- nil + return + } + if err != nil { + errc <- err + return + } + } + }() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < 2000; i++ { + conn.Rebind() + } + }() + go func() { + defer wg.Done() + for i := 0; i < 2000; i++ { + conn.Rebind() + } + }() + wg.Wait() + + cancel() + if err := conn.Close(); err != nil { + t.Fatal(err) + } + closed = true + + err := <-errc + if err != nil { + t.Fatalf("Got ReceiveIPv4 error: %v (is closed = %v). Log:\n%s", err, errors.Is(err, net.ErrClosed), logBuf.Bytes()) + } +}