diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 3e320a196..2f3c47bc2 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -52,6 +52,12 @@ type Conn struct { logf func(format string, args ...interface{}) sendLogLimit *rate.Limiter + // bufferedIPv4From and bufferedIPv4Packet are owned by + // ReceiveIPv4, and used when both a DERP and IPv4 packet arrive + // at the same time. It stores the IPv4 packet for use in the next call. + bufferedIPv4From *net.UDPAddr // if non-nil, then bufferedIPv4Packet is valid + bufferedIPv4Packet []byte // the received packet (reused, owned by ReceiveIPv4) + connCtx context.Context // closed on Conn.Close connCtxCancel func() // closes connCtx @@ -805,6 +811,19 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc } } +// findEndpoint maps from a UDP address to a WireGuard endpoint, for +// ReceiveIPv4/ReceiveIPv6. +func (c *Conn) findEndpoint(addr *net.UDPAddr) conn.Endpoint { + if as := c.findAddrSet(addr); as != nil { + return as + } + // The peer that sent this packet has roamed beyond the + // knowledge provided by the control server. + // If the packet is valid wireguard will call UpdateDst + // on the original endpoint using this addr. + return (*singleEndpoint)(addr) +} + func (c *Conn) findAddrSet(addr *net.UDPAddr) *AddrSet { var epAddr udpAddr copy(epAddr.ip.Addr[:], addr.IP.To16()) @@ -827,6 +846,12 @@ type udpReadResult struct { var aLongTimeAgo = time.Unix(233431200, 0) func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) { + // First, process any buffered packet from earlier. + if addr := c.bufferedIPv4From; addr != nil { + c.bufferedIPv4From = nil + return copy(b, c.bufferedIPv4Packet), c.findEndpoint(addr), addr, nil + } + go func() { // Read a packet, and process any STUN packets before returning. for { @@ -863,21 +888,21 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr case dm := <-c.derpRecvCh: // Cancel the pconn read goroutine c.pconn.SetReadDeadline(aLongTimeAgo) + // Wait for the UDP-reading goroutine to be done, since it's currently + // the owner of the b []byte buffer: select { - case <-c.udpRecvCh: - // It's likely an error, since we just canceled the read. - // But there's a small window where the pconn.ReadFrom - // could've succeeded but not yet sent, and we got into - // the derp recv path first. In that case this - // udpReadResult is a real non-err packet and we need to - // choose which to use. Currently, arbitrarily, we - // currently select DERP and discard this result entirely. - // - // TODO(danderson): don't just discard packets here, it - // makes the stack unreliable and harder to test. - // - // The main point of this receive, though, is to make sure - // that the goroutine is done with our b []byte buf. + case um := <-c.udpRecvCh: + if um.err != nil { + // The normal case. The SetReadDeadline interrupted + // the read and we get an error which we now ignore. + } else { + // The pconn.ReadFrom succeeded and was about to send, + // but DERP sent first. So now we have both ready. + // Save the UDP packet away for use by the next + // ReceiveIPv4 call. + c.bufferedIPv4From = um.addr + c.bufferedIPv4Packet = append(c.bufferedIPv4Packet[:0], b[:um.n]...) + } c.pconn.SetReadDeadline(time.Time{}) case <-c.donec(): return 0, nil, nil, errors.New("Conn closed") @@ -919,17 +944,12 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr return 0, nil, nil, errors.New("socket closed") } - if addrSet == nil { - addrSet = c.findAddrSet(addr) - } - if addrSet == nil { - // The peer that sent this packet has roamed beyond the - // knowledge provided by the control server. - // If the packet is valid wireguard will call UpdateDst - // on the original endpoint using this addr. - return n, (*singleEndpoint)(addr), addr, nil + if addrSet != nil { + ep = addrSet + } else { + ep = c.findEndpoint(addr) } - return n, addrSet, addr, nil + return n, ep, addr, nil } func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error) {