wgengine/magicsock: fix Conn.Rebind race that let ErrClosed errors be read

There was a logical race where Conn.Rebind could acquire the
RebindingUDPConn mutex, close the connection, fail to rebind, release
the mutex, and then because the mutex was no longer held, ReceiveIPv4
wouldn't retry reads that failed with net.ErrClosed, letting that
error propagate back to wireguard-go, which would then stop running that receive
IP goroutine.

Instead, keep the RebindingUDPConn mutex held for the entirety of the
replacement in all cases.

Updates tailscale/corp#1289

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/1472/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent fee74e7ea7
commit 387e83c8fe

@ -2630,28 +2630,38 @@ func (c *Conn) Rebind() {
host = "127.0.0.1"
}
listenCtx := context.Background() // unused without DNS name to resolve
if c.port != 0 {
c.pconn4.mu.Lock()
oldPort := c.pconn4.localAddrLocked().Port
if err := c.pconn4.pconn.Close(); err != nil {
c.logf("magicsock: link change close failed: %v", err)
}
packetConn, err := c.listenPacket(listenCtx, "udp4", fmt.Sprintf("%s:%d", host, c.port))
if err == nil {
packetConn, err := c.listenPacket(listenCtx, "udp4", net.JoinHostPort(host, fmt.Sprint(c.port)))
if err != nil {
c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err)
packetConn, err = c.listenPacket(listenCtx, "udp4", net.JoinHostPort(host, "0"))
if err != nil {
c.logf("magicsock: link change failed to bind random port: %v", err)
c.pconn4.mu.Unlock()
return
}
newPort := c.pconn4.localAddrLocked().Port
c.logf("magicsock: link change rebound port: from %v to %v (failed to get %v)", oldPort, newPort, c.port)
} else {
c.logf("magicsock: link change rebound port: %d", c.port)
c.pconn4.pconn = packetConn.(*net.UDPConn)
c.pconn4.mu.Unlock()
return
}
c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err)
c.pconn4.pconn = packetConn.(*net.UDPConn)
c.pconn4.mu.Unlock()
} else {
c.logf("magicsock: link change, binding new port")
packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0")
if err != nil {
c.logf("magicsock: link change failed to bind new port: %v", err)
return
}
c.pconn4.Reset(packetConn.(*net.UDPConn))
}
c.logf("magicsock: link change, binding new port")
packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0")
if err != nil {
c.logf("magicsock: link change failed to bind new port: %v", err)
return
}
c.pconn4.Reset(packetConn.(*net.UDPConn))
c.portMapper.SetLocalPort(c.LocalPort())
c.mu.Lock()
@ -2833,6 +2843,10 @@ func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort,
// LocalAddr returns the local UDP address of the connection currently held
// in c.pconn. It takes c.mu for the duration of the lookup and delegates the
// actual read to localAddrLocked.
func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr {
c.mu.Lock()
defer c.mu.Unlock()
return c.localAddrLocked()
}
// localAddrLocked reports the local address of c.pconn as a *net.UDPAddr.
// It does no locking of its own; the visible callers invoke it with c.mu
// already held. It panics if the address is not a *net.UDPAddr.
func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr {
	addr := c.pconn.LocalAddr()
	return addr.(*net.UDPAddr)
}

@ -1662,6 +1662,15 @@ func BenchmarkReceiveFrom_Native(b *testing.B) {
}
}
// logBufWriter returns a logger.Logf that appends each formatted message to
// buf. After every write it makes sure the buffer ends in a newline, adding
// one when the formatted message did not supply it.
func logBufWriter(buf *bytes.Buffer) logger.Logf {
	appendLine := func(format string, a ...interface{}) {
		fmt.Fprintf(buf, format, a...)
		if bytes.HasSuffix(buf.Bytes(), []byte("\n")) {
			return
		}
		buf.WriteByte('\n')
	}
	return appendLine
}
// Test that a netmap update where node changes its node key but
// doesn't change its disco key doesn't result in a broken state.
//
@ -1670,12 +1679,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) {
conn := newNonLegacyTestConn(t)
t.Cleanup(func() { conn.Close() })
var logBuf bytes.Buffer
conn.logf = func(format string, a ...interface{}) {
fmt.Fprintf(&logBuf, format, a...)
if !bytes.HasSuffix(logBuf.Bytes(), []byte("\n")) {
logBuf.WriteByte('\n')
}
}
conn.logf = logBufWriter(&logBuf)
conn.SetPrivateKey(wgkey.Private{0: 1})
@ -1729,3 +1733,63 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) {
t.Logf("log output: %s", log)
}
}
// TestRebindStress calls Conn.Rebind repeatedly from two goroutines while a
// third goroutine loops on ReceiveIPv4. The test fails if ReceiveIPv4 returns
// an error before the test has cancelled its context; errors observed after
// cancellation are ignored.
func TestRebindStress(t *testing.T) {
	conn := newNonLegacyTestConn(t)

	var logBuf bytes.Buffer
	conn.logf = logBufWriter(&logBuf)

	// The test closes conn itself on the success path; the cleanup only
	// closes it when the test exited before reaching that point.
	closed := false
	t.Cleanup(func() {
		if closed {
			return
		}
		conn.Close()
	})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Capacity 1 lets the reader goroutine deliver its single result and
	// exit without waiting for the receiver.
	errc := make(chan error, 1)
	go func() {
		pkt := make([]byte, 1500)
		for {
			_, _, err := conn.ReceiveIPv4(pkt)
			switch {
			case ctx.Err() != nil:
				// Cancellation is checked first so that errors caused by
				// the final Close are not reported as failures.
				errc <- nil
				return
			case err != nil:
				errc <- err
				return
			}
		}
	}()

	const (
		rebinders = 2
		rounds    = 2000
	)
	var wg sync.WaitGroup
	wg.Add(rebinders)
	for r := 0; r < rebinders; r++ {
		go func() {
			defer wg.Done()
			for n := 0; n < rounds; n++ {
				conn.Rebind()
			}
		}()
	}
	wg.Wait()

	cancel()
	if err := conn.Close(); err != nil {
		t.Fatal(err)
	}
	closed = true

	if err := <-errc; err != nil {
		t.Fatalf("Got ReceiveIPv4 error: %v (is closed = %v). Log:\n%s", err, errors.Is(err, net.ErrClosed), logBuf.Bytes())
	}
}

Loading…
Cancel
Save