wgengine/magicsock: don't discard UDP packet on UDP+DERP race

Fixes #155

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/164/head
Brad Fitzpatrick 5 years ago committed by Brad Fitzpatrick
parent 96e0f86263
commit f42b9b6c9a

@ -52,6 +52,12 @@ type Conn struct {
logf func(format string, args ...interface{}) logf func(format string, args ...interface{})
sendLogLimit *rate.Limiter sendLogLimit *rate.Limiter
// bufferedIPv4From and bufferedIPv4Packet are owned by
// ReceiveIPv4, and used when both a DERP and IPv4 packet arrive
// at the same time. It stores the IPv4 packet for use in the next call.
bufferedIPv4From *net.UDPAddr // if non-nil, then bufferedIPv4Packet is valid
bufferedIPv4Packet []byte // the received packet (reused, owned by ReceiveIPv4)
connCtx context.Context // closed on Conn.Close connCtx context.Context // closed on Conn.Close
connCtxCancel func() // closes connCtx connCtxCancel func() // closes connCtx
@ -805,6 +811,19 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
} }
} }
// findEndpoint maps from a UDP address to a WireGuard endpoint, for
// ReceiveIPv4/ReceiveIPv6.
func (c *Conn) findEndpoint(addr *net.UDPAddr) conn.Endpoint {
if as := c.findAddrSet(addr); as != nil {
return as
}
// The peer that sent this packet has roamed beyond the
// knowledge provided by the control server.
// If the packet is valid wireguard will call UpdateDst
// on the original endpoint using this addr.
return (*singleEndpoint)(addr)
}
func (c *Conn) findAddrSet(addr *net.UDPAddr) *AddrSet { func (c *Conn) findAddrSet(addr *net.UDPAddr) *AddrSet {
var epAddr udpAddr var epAddr udpAddr
copy(epAddr.ip.Addr[:], addr.IP.To16()) copy(epAddr.ip.Addr[:], addr.IP.To16())
@ -827,6 +846,12 @@ type udpReadResult struct {
var aLongTimeAgo = time.Unix(233431200, 0) var aLongTimeAgo = time.Unix(233431200, 0)
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) { func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) {
// First, process any buffered packet from earlier.
if addr := c.bufferedIPv4From; addr != nil {
c.bufferedIPv4From = nil
return copy(b, c.bufferedIPv4Packet), c.findEndpoint(addr), addr, nil
}
go func() { go func() {
// Read a packet, and process any STUN packets before returning. // Read a packet, and process any STUN packets before returning.
for { for {
@ -863,21 +888,21 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
case dm := <-c.derpRecvCh: case dm := <-c.derpRecvCh:
// Cancel the pconn read goroutine // Cancel the pconn read goroutine
c.pconn.SetReadDeadline(aLongTimeAgo) c.pconn.SetReadDeadline(aLongTimeAgo)
// Wait for the UDP-reading goroutine to be done, since it's currently
// the owner of the b []byte buffer:
select { select {
case <-c.udpRecvCh: case um := <-c.udpRecvCh:
// It's likely an error, since we just canceled the read. if um.err != nil {
// But there's a small window where the pconn.ReadFrom // The normal case. The SetReadDeadline interrupted
// could've succeeded but not yet sent, and we got into // the read and we get an error which we now ignore.
// the derp recv path first. In that case this } else {
// udpReadResult is a real non-err packet and we need to // The pconn.ReadFrom succeeded and was about to send,
// choose which to use. Currently, arbitrarily, we // but DERP sent first. So now we have both ready.
// currently select DERP and discard this result entirely. // Save the UDP packet away for use by the next
// // ReceiveIPv4 call.
// TODO(danderson): don't just discard packets here, it c.bufferedIPv4From = um.addr
// makes the stack unreliable and harder to test. c.bufferedIPv4Packet = append(c.bufferedIPv4Packet[:0], b[:um.n]...)
// }
// The main point of this receive, though, is to make sure
// that the goroutine is done with our b []byte buf.
c.pconn.SetReadDeadline(time.Time{}) c.pconn.SetReadDeadline(time.Time{})
case <-c.donec(): case <-c.donec():
return 0, nil, nil, errors.New("Conn closed") return 0, nil, nil, errors.New("Conn closed")
@ -919,17 +944,12 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
return 0, nil, nil, errors.New("socket closed") return 0, nil, nil, errors.New("socket closed")
} }
if addrSet == nil { if addrSet != nil {
addrSet = c.findAddrSet(addr) ep = addrSet
} } else {
if addrSet == nil { ep = c.findEndpoint(addr)
// The peer that sent this packet has roamed beyond the
// knowledge provided by the control server.
// If the packet is valid wireguard will call UpdateDst
// on the original endpoint using this addr.
return n, (*singleEndpoint)(addr), addr, nil
} }
return n, addrSet, addr, nil return n, ep, addr, nil
} }
func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error) { func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error) {

Loading…
Cancel
Save