From f475e5550cd92aee0cb81f96e808ae2f3b34ecbf Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 4 Apr 2023 16:32:16 -0700 Subject: [PATCH] net/neterror, wgengine/magicsock: use UDP GSO and GRO on Linux (#7791) This commit implements UDP offloading for Linux. GSO size is passed to and from the kernel via socket control messages. Support is probed at runtime. UDP GSO is dependent on checksum offload support on the egress netdev. UDP GSO will be disabled in the event sendmmsg() returns EIO, which is a strong signal that the egress netdev does not support checksum offload. Updates tailscale/corp#8734 Signed-off-by: Jordan Whited --- net/neterror/neterror.go | 23 + net/neterror/neterror_linux.go | 26 ++ wgengine/magicsock/magicsock.go | 579 ++++++++++++++++-------- wgengine/magicsock/magicsock_default.go | 18 + wgengine/magicsock/magicsock_linux.go | 89 +++- wgengine/magicsock/magicsock_test.go | 239 +++++++++- 6 files changed, 778 insertions(+), 196 deletions(-) create mode 100644 net/neterror/neterror_linux.go diff --git a/net/neterror/neterror.go b/net/neterror/neterror.go index c2d3269d5..e2387440d 100644 --- a/net/neterror/neterror.go +++ b/net/neterror/neterror.go @@ -6,6 +6,7 @@ package neterror import ( "errors" + "fmt" "runtime" "syscall" ) @@ -57,3 +58,25 @@ func PacketWasTruncated(err error) bool { } return packetWasTruncated(err) } + +var shouldDisableUDPGSO func(error) bool // non-nil on Linux + +func ShouldDisableUDPGSO(err error) bool { + if shouldDisableUDPGSO == nil { + return false + } + return shouldDisableUDPGSO(err) +} + +type ErrUDPGSODisabled struct { + OnLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go new file mode 100644 index 000000000..857367fe8 --- /dev/null +++ b/net/neterror/neterror_linux.go @@ -0,0 +1,26 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func init() { + shouldDisableUDPGSO = func(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not + // have tx checksumming enabled, which is a hard requirement of + // UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index d33cbd244..25ebcc97e 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -44,7 +44,6 @@ import ( "tailscale.com/net/connstats" "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" - "tailscale.com/net/netaddr" "tailscale.com/net/netcheck" "tailscale.com/net/neterror" "tailscale.com/net/netns" @@ -281,7 +280,6 @@ type Conn struct { pconn6 RebindingUDPConn receiveBatchPool sync.Pool - sendBatchPool sync.Pool // closeDisco4 and closeDisco6 are io.Closers to shut down the raw // disco packet receivers. If nil, no raw disco receiver is @@ -597,26 +595,13 @@ func newConn() *Conn { msgs := make([]ipv6.Message, c.bind.BatchSize()) for i := range msgs { msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, controlMessageSize) } batch := &receiveBatch{ msgs: msgs, } return batch }} - c.sendBatchPool = sync.Pool{New: func() any { - ua := &net.UDPAddr{ - IP: make([]byte, 16), - } - msgs := make([]ipv6.Message, c.bind.BatchSize()) - for i := range msgs { - msgs[i].Buffers = make([][]byte, 1) - msgs[i].Addr = ua - } - return &sendBatch{ - ua: ua, - msgs: msgs, - } - }} c.muCond = sync.NewCond(&c.mu) c.networkUp.Store(true) // assume up until told otherwise return c @@ -1301,19 +1286,11 @@ var errNoUDP = errors.New("no UDP available on platform") var ( // This acts as a compile-time check for our usage of ipv6.Message in - // udpConnWithBatchOps for both IPv6 and IPv4 operations. + // batchingUDPConn for both IPv6 and IPv4 operations. _ ipv6.Message = ipv4.Message{} ) -type sendBatch struct { - ua *net.UDPAddr - msgs []ipv6.Message // ipv4.Message and ipv6.Message are the same underlying type -} - func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) { - batch := c.sendBatchPool.Get().(*sendBatch) - defer c.sendBatchPool.Put(batch) - isIPv6 := false switch { case addr.Addr().Is4(): @@ -1322,19 +1299,17 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err default: panic("bogus sendUDPBatch addr type") } - - as16 := addr.Addr().As16() - copy(batch.ua.IP, as16[:]) - batch.ua.Port = int(addr.Port()) - for i, buff := range buffs { - batch.msgs[i].Buffers[0] = buff - batch.msgs[i].Addr = batch.ua - } - if isIPv6 { - _, err = c.pconn6.WriteBatch(batch.msgs[:len(buffs)], 0) + err = c.pconn6.WriteBatchTo(buffs, addr) } else { - _, err = c.pconn4.WriteBatch(batch.msgs[:len(buffs)], 0) + err = c.pconn4.WriteBatchTo(buffs, addr) + } + if err != nil { + var errGSO neterror.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + c.logf("magicsock: %s", errGSO.Error()) + err = errGSO.RetryErr + } } return err == nil, err } @@ -1844,14 +1819,18 @@ type receiveBatch struct { msgs []ipv6.Message } -func (c *Conn) getReceiveBatch() *receiveBatch { +func (c *Conn) getReceiveBatchForBuffs(buffs [][]byte) *receiveBatch { batch := c.receiveBatchPool.Get().(*receiveBatch) + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + batch.msgs[i].OOB = batch.msgs[i].OOB[:cap(batch.msgs[i].OOB)] + } return batch } func (c *Conn) putReceiveBatch(batch *receiveBatch) { for i := range batch.msgs { - batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers} + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} } c.receiveBatchPool.Put(batch) } @@ -1860,13 +1839,10 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in health.ReceiveIPv6.Enter() defer health.ReceiveIPv6.Exit() - batch := c.getReceiveBatch() + batch := c.getReceiveBatchForBuffs(buffs) defer c.putReceiveBatch(batch) for { - for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] - } - numMsgs, err := c.pconn6.ReadBatch(batch.msgs, 0) + numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0) if err != nil { if neterror.PacketWasTruncated(err) { // TODO(raggi): discuss whether to log? @@ -1877,6 +1853,10 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in reportToCaller := false for i, msg := range batch.msgs[:numMsgs] { + if msg.N == 0 { + sizes[i] = 0 + continue + } ipp := msg.Addr.(*net.UDPAddr).AddrPort() if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok { metricRecvDataIPv6.Add(1) @@ -1898,13 +1878,10 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in health.ReceiveIPv4.Enter() defer health.ReceiveIPv4.Exit() - batch := c.getReceiveBatch() + batch := c.getReceiveBatchForBuffs(buffs) defer c.putReceiveBatch(batch) for { - for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] - } - numMsgs, err := c.pconn4.ReadBatch(batch.msgs, 0) + numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0) if err != nil { if neterror.PacketWasTruncated(err) { // TODO(raggi): discuss whether to log? @@ -1915,6 +1892,10 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in reportToCaller := false for i, msg := range batch.msgs[:numMsgs] { + if msg.N == 0 { + sizes[i] = 0 + continue + } ipp := msg.Addr.(*net.UDPAddr).AddrPort() if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok { metricRecvDataIPv4.Add(1) @@ -1940,7 +1921,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) c.stunReceiveFunc.Load()(b, ipp) return nil, false } - if c.handleDiscoMessage(b, ipp, key.NodePublic{}) { + if c.handleDiscoMessage(b, ipp, key.NodePublic{}, discoRXPathUDP) { return nil, false } if !c.havePrivateKey.Load() { @@ -2005,7 +1986,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en } ipp := netip.AddrPortFrom(derpMagicIPAddr, uint16(regionID)) - if c.handleDiscoMessage(b[:n], ipp, dm.src) { + if c.handleDiscoMessage(b[:n], ipp, dm.src, discoRXPathDERP) { return 0, nil } @@ -2139,6 +2120,14 @@ func discoPcapFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []by return b.Bytes() } +type discoRXPath string + +const ( + discoRXPathUDP discoRXPath = "UDP socket" + discoRXPathDERP discoRXPath = "DERP" + discoRXPathRawSocket discoRXPath = "raw socket" +) + // handleDiscoMessage handles a discovery message and reports whether // msg was a Tailscale inter-node discovery message. // @@ -2153,7 +2142,7 @@ func discoPcapFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []by // src.Port() being the region ID) and the derpNodeSrc will be the node key // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. -func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic) (isDiscoMsg bool) { +func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) { const headerLen = len(disco.Magic) + key.DiscoPublicRawLen if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { return false @@ -2174,7 +2163,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } if debugDisco() { - c.logf("magicsock: disco: got disco-looking frame from %v", sender.ShortString()) + c.logf("magicsock: disco: got disco-looking frame from %v via %s", sender.ShortString(), via) } if c.privateKey.IsZero() { // Ignore disco messages when we're stopped. @@ -2210,12 +2199,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // disco key. When we restart we get a new disco key // and old packets might've still been in flight (or // scheduled). This is particularly the case for LANs - // or non-NATed endpoints. - // Don't log in normal case. Pass on to wireguard, in case - // it's actually a wireguard packet (super unlikely, - // but). + // or non-NATed endpoints. UDP offloading on Linux + // can also cause this when a disco message is + // received via raw socket at the head of a coalesced + // group of messages. Don't log in normal case. + // Callers may choose to pass on to wireguard, in case + // it's actually a wireguard packet (super unlikely, but). if debugDisco() { - c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?)", sender) + c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?) via %s", sender, via) } metricRecvDiscoBadKey.Add(1) return @@ -3205,13 +3196,13 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur defer ruc.mu.Unlock() if runtime.GOOS == "js" { - ruc.setConnLocked(newBlockForeverConn(), "") + ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) return nil } if debugAlwaysDERP() { c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network) - ruc.setConnLocked(newBlockForeverConn(), "") + ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) return nil } @@ -3253,7 +3244,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur if debugBindSocket() { c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port) } - ruc.setConnLocked(pconn, network) + ruc.setConnLocked(pconn, network, c.bind.BatchSize()) if network == "udp4" { health.SetUDP4Unbound(false) } @@ -3264,7 +3255,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur // Set pconn to a dummy conn whose reads block until closed. // This keeps the receive funcs alive for a future in which // we get a link change and we can try binding again. - ruc.setConnLocked(newBlockForeverConn(), "") + ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) if network == "udp4" { health.SetUDP4Unbound(true) } @@ -3361,49 +3352,332 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) { return ep, nil } -type batchReaderWriter interface { - batchReader - batchWriter +// xnetBatchReaderWriter defines the batching i/o methods of +// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn). +// TODO(jwhited): This should eventually be replaced with the standard library +// implementation of https://github.com/golang/go/issues/45886 +type xnetBatchReaderWriter interface { + xnetBatchReader + xnetBatchWriter +} + +type xnetBatchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) } -type batchWriter interface { +type xnetBatchWriter interface { WriteBatch([]ipv6.Message, int) (int, error) } -type batchReader interface { - ReadBatch([]ipv6.Message, int) (int, error) +// batchingUDPConn is a UDP socket that provides batched i/o. +type batchingUDPConn struct { + pc nettype.PacketConn + xpc xnetBatchReaderWriter + rxOffload bool // supports UDP GRO or similar + txOffload atomic.Bool // supports UDP GSO or similar + setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing + getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing + sendBatchPool sync.Pool } -// udpConnWithBatchOps wraps a *net.UDPConn in order to extend it to support -// batch operations. -// -// TODO(jwhited): This wrapping is temporary. https://github.com/golang/go/issues/45886 -type udpConnWithBatchOps struct { - *net.UDPConn - xpc batchReaderWriter +func (c *batchingUDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.rxOffload { + // UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may + // receive a "monster datagram" from any read call. The ReadFrom() API + // does not support passing the GSO size and is unsafe to use in such a + // case. Other platforms may vary in behavior, but we go with the most + // conservative approach to prevent this from becoming a footgun in the + // future. + return 0, nil, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable") + } + return c.pc.ReadFrom(p) +} + +func (c *batchingUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.pc.WriteTo(b, addr) +} + +func (c *batchingUDPConn) SetDeadline(t time.Time) error { + return c.pc.SetDeadline(t) +} + +func (c *batchingUDPConn) SetReadDeadline(t time.Time) error { + return c.pc.SetReadDeadline(t) +} + +func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error { + return c.pc.SetWriteDeadline(t) +} + +const ( + // This was initially established for Linux, but may split out to + // GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the + // kernel's TX path, and UDP_GRO_CNT_MAX for RX. + udpSegmentMaxDatagrams = 64 +) + +const ( + // Exceeding these values results in EMSGSIZE. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 +) + +// coalesceMessages iterates msgs, coalescing them where possible while +// maintaining datagram order. All msgs have their Addr field set to addr. +func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int { + var ( + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of buffs + ) + maxPayloadLen := maxIPv4PayloadLen + if addr.IP.To4() == nil { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buff := range buffs { + if i > 0 { + msgLen := len(buff) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...) + copy(msgs[base].Buffers[0][baseLenBefore:], buff) + if i == len(buffs)-1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue + } + } + if dgramCnt > 1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buff) + msgs[base].OOB = msgs[base].OOB[:0] + msgs[base].Buffers[0] = buff + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type sendBatch struct { + msgs []ipv6.Message + ua *net.UDPAddr +} + +func (c *batchingUDPConn) getSendBatch() *sendBatch { + batch := c.sendBatchPool.Get().(*sendBatch) + return batch +} + +func (c *batchingUDPConn) putSendBatch(batch *sendBatch) { + for i := range batch.msgs { + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} + } + c.sendBatchPool.Put(batch) +} + +func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { + batch := c.getSendBatch() + defer c.putSendBatch(batch) + if addr.Addr().Is6() { + as16 := addr.Addr().As16() + copy(batch.ua.IP, as16[:]) + batch.ua.IP = batch.ua.IP[:16] + } else { + as4 := addr.Addr().As4() + copy(batch.ua.IP, as4[:]) + batch.ua.IP = batch.ua.IP[:4] + } + batch.ua.Port = int(addr.Port()) + var ( + n int + retried bool + ) +retry: + if c.txOffload.Load() { + n = c.coalesceMessages(batch.ua, buffs, batch.msgs) + } else { + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + batch.msgs[i].Addr = batch.ua + batch.msgs[i].OOB = batch.msgs[i].OOB[:0] + } + n = len(buffs) + } + + err := c.writeBatch(batch.msgs[:n]) + if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) { + c.txOffload.Store(false) + retried = true + goto retry + } + if retried { + return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err} + } + return err +} + +func (c *batchingUDPConn) writeBatch(msgs []ipv6.Message) error { + var head int + for { + n, err := c.xpc.WriteBatch(msgs[head:], 0) + if err != nil || n == len(msgs[head:]) { + // Returning the number of packets written would require + // unraveling individual msg len and gso size during a coalesced + // write. The top of the call stack disregards partial success, + // so keep this simple for now. + return err + } + head += n + } +} + +// splitCoalescedMessages splits coalesced messages from the tail of dst +// beginning at index 'firstMsgAt' into the head of the same slice. It reports +// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An +// error is returned if a socket control message cannot be parsed or a split +// operation would overflow msgs. +func (c *batchingUDPConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") + } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 + } + } + return n, nil } -func newUDPConnWithBatchOps(conn *net.UDPConn, network string) udpConnWithBatchOps { - ucbo := udpConnWithBatchOps{ - UDPConn: conn, +func (c *batchingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) { + if !c.rxOffload || len(msgs) < 2 { + return c.xpc.ReadBatch(msgs, flags) + } + // Read into the tail of msgs, split into the head. + readAt := len(msgs) - 2 + numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0) + if err != nil || numRead == 0 { + return 0, err + } + return c.splitCoalescedMessages(msgs, readAt) +} + +func (c *batchingUDPConn) LocalAddr() net.Addr { + return c.pc.LocalAddr().(*net.UDPAddr) +} + +func (c *batchingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + return c.pc.WriteToUDPAddrPort(b, addr) +} + +func (c *batchingUDPConn) Close() error { + return c.pc.Close() +} + +// tryUpgradeToBatchingUDPConn probes the capabilities of the OS and pconn, and +// upgrades pconn to a *batchingUDPConn if appropriate. +func tryUpgradeToBatchingUDPConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { + if network != "udp4" && network != "udp6" { + return pconn + } + if runtime.GOOS != "linux" { + return pconn + } + if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") { + // recvmmsg/sendmmsg were added in 2.6.33, but we support down to + // 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807. + // As a cheap heuristic: if the Linux kernel starts with "2", just + // consider it too old for mmsg. Nobody who cares about performance runs + // such ancient kernels. UDP offload was added much later, so no + // upgrades are available. + return pconn + } + uc, ok := pconn.(*net.UDPConn) + if !ok { + return pconn + } + b := &batchingUDPConn{ + pc: pconn, + getGSOSizeFromControl: getGSOSizeFromControl, + setGSOSizeInControl: setGSOSizeInControl, + sendBatchPool: sync.Pool{ + New: func() any { + ua := &net.UDPAddr{ + IP: make([]byte, 16), + } + msgs := make([]ipv6.Message, batchSize) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + msgs[i].OOB = make([]byte, controlMessageSize) + } + return &sendBatch{ + ua: ua, + msgs: msgs, + } + }, + }, } switch network { case "udp4": - ucbo.xpc = ipv4.NewPacketConn(conn) + b.xpc = ipv4.NewPacketConn(uc) case "udp6": - ucbo.xpc = ipv6.NewPacketConn(conn) + b.xpc = ipv6.NewPacketConn(uc) default: panic("bogus network") } - return ucbo -} - -func (u udpConnWithBatchOps) WriteBatch(ms []ipv6.Message, flags int) (int, error) { - return u.xpc.WriteBatch(ms, flags) -} - -func (u udpConnWithBatchOps) ReadBatch(ms []ipv6.Message, flags int) (int, error) { - return u.xpc.ReadBatch(ms, flags) + var txOffload bool + txOffload, b.rxOffload = tryEnableUDPOffload(uc) + b.txOffload.Store(txOffload) + return b } // RebindingUDPConn is a UDP socket that can be re-bound. @@ -3423,34 +3697,14 @@ type RebindingUDPConn struct { port uint16 } -// upgradePacketConn may upgrade a nettype.PacketConn to a udpConnWithBatchOps. -func upgradePacketConn(p nettype.PacketConn, network string) nettype.PacketConn { - uc, ok := p.(*net.UDPConn) - if ok && runtime.GOOS == "linux" && (network == "udp4" || network == "udp6") { - // recvmmsg/sendmmsg were added in 2.6.33 but we support down to 2.6.32 - // for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807. - // As a cheap heuristic: if the Linux kernel starts with "2", just consider - // it too old for the fast paths. Nobody who cares about performance runs such - // ancient kernels. - if strings.HasPrefix(hostinfo.GetOSVersion(), "2") { - return p - } - // Non-Linux does not support batch operations. x/net will fall back to - // recv/sendmsg, but not all platforms have recv/sendmsg support. Keep - // this simple for now. - return newUDPConnWithBatchOps(uc, network) - } - return p -} - // setConnLocked sets the provided nettype.PacketConn. It should be called only // after acquiring RebindingUDPConn.mu. It upgrades the provided -// nettype.PacketConn to a udpConnWithBatchOps when appropriate. This upgrade +// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade // is intentionally pushed closest to where read/write ops occur in order to // avoid disrupting surrounding code that assumes nettype.PacketConn is a // *net.UDPConn. -func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string) { - upc := upgradePacketConn(p, network) +func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { + upc := tryUpgradeToBatchingUDPConn(p, network, batchSize) c.pconn = upc c.pconnAtomic.Store(&upc) c.port = uint16(c.localAddrLocked().Port) @@ -3480,83 +3734,38 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b) } -// ReadFromNetaddr reads a packet from c into b. -// It returns the number of bytes copied and the return address. -// It is identical to c.ReadFrom, except that it returns a netip.AddrPort instead of a net.Addr. -// ReadFromNetaddr is designed to work with specific underlying connection types. -// If c's underlying connection returns a non-*net.UPDAddr return address, ReadFromNetaddr will return an error. -// ReadFromNetaddr exists because it removes an allocation per read, -// when c's underlying connection is a net.UDPConn. -func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) { +// WriteBatchTo writes buffs to addr. +func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { for { pconn := *c.pconnAtomic.Load() - - // Optimization: Treat *net.UDPConn specially. - // This lets us avoid allocations by calling ReadFromUDPAddrPort. - // The non-*net.UDPConn case works, but it allocates. - if udpConn, ok := pconn.(*net.UDPConn); ok { - n, ipp, err = udpConn.ReadFromUDPAddrPort(b) - } else { - var addr net.Addr - n, addr, err = pconn.ReadFrom(b) - pAddr, ok := addr.(*net.UDPAddr) - if addr != nil && !ok { - return 0, netip.AddrPort{}, fmt.Errorf("RebindingUDPConn.ReadFromNetaddr: underlying connection returned address of type %T, want *netaddr.UDPAddr", addr) - } - if pAddr != nil { - ipp = netaddr.Unmap(pAddr.AddrPort()) - if !ipp.IsValid() { - return 0, netip.AddrPort{}, errors.New("netaddr.FromStdAddr failed") - } - } - } - - if err != nil && pconn != c.currentConn() { - // The connection changed underfoot. Try again. - continue - } - return n, ipp, err - } -} - -func (c *RebindingUDPConn) WriteBatch(msgs []ipv6.Message, flags int) (int, error) { - var ( - n int - err error - start int - ) - for { - pconn := *c.pconnAtomic.Load() - bw, ok := pconn.(batchWriter) + b, ok := pconn.(*batchingUDPConn) if !ok { - for _, msg := range msgs { - _, err = c.writeToWithInitPconn(pconn, msg.Buffers[0], msg.Addr) + for _, buf := range buffs { + _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr) if err != nil { - return n, err + return err } - n++ } - return n, nil + return nil } - - n, err = bw.WriteBatch(msgs[start:], flags) + err := b.WriteBatchTo(buffs, addr) if err != nil { if pconn != c.currentConn() { continue } - return n, err - } else if n == len(msgs[start:]) { - return len(msgs), nil - } else { - start += n + return err } + return err } } +// ReadBatch reads messages from c into msgs. It returns the number of messages +// the caller should evaluate for nonzero len, as a zero len message may fall +// on either side of a nonzero. func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) { for { pconn := *c.pconnAtomic.Load() - br, ok := pconn.(batchReader) + b, ok := pconn.(*batchingUDPConn) if !ok { var err error msgs[0].N, msgs[0].Addr, err = c.readFromWithInitPconn(pconn, msgs[0].Buffers[0]) @@ -3565,7 +3774,7 @@ func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error } return 0, err } - n, err := br.ReadBatch(msgs, flags) + n, err := b.ReadBatch(msgs, flags) if err != nil && pconn != c.currentConn() { continue } @@ -3607,9 +3816,9 @@ func (c *RebindingUDPConn) closeLocked() error { return c.pconn.Close() } -func (c *RebindingUDPConn) writeToWithInitPconn(pconn nettype.PacketConn, b []byte, addr net.Addr) (int, error) { +func (c *RebindingUDPConn) writeToUDPAddrPortWithInitPconn(pconn nettype.PacketConn, b []byte, addr netip.AddrPort) (int, error) { for { - n, err := pconn.WriteTo(b, addr) + n, err := pconn.WriteToUDPAddrPort(b, addr) if err != nil && pconn != c.currentConn() { pconn = *c.pconnAtomic.Load() continue @@ -3619,13 +3828,9 @@ func (c *RebindingUDPConn) writeToWithInitPconn(pconn nettype.PacketConn, b []by } func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { - return c.writeToWithInitPconn(*c.pconnAtomic.Load(), b, addr) -} - -func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { for { pconn := *c.pconnAtomic.Load() - n, err := pconn.WriteToUDPAddrPort(b, addr) + n, err := pconn.WriteTo(b, addr) if err != nil && pconn != c.currentConn() { continue } @@ -3633,6 +3838,10 @@ func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (in } } +func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + return c.writeToUDPAddrPortWithInitPconn(*c.pconnAtomic.Load(), b, addr) +} + func newBlockForeverConn() *blockForeverConn { c := new(blockForeverConn) c.cond = sync.NewCond(&c.mu) @@ -3665,20 +3874,6 @@ func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (in return len(p), nil } -func (c *blockForeverConn) ReadBatch(p []ipv6.Message, flags int) (int, error) { - c.mu.Lock() - for !c.closed { - c.cond.Wait() - } - c.mu.Unlock() - return 0, net.ErrClosed -} - -func (c *blockForeverConn) WriteBatch(p []ipv6.Message, flags int) (int, error) { - // Silently drop writes. - return len(p), nil -} - func (c *blockForeverConn) LocalAddr() net.Addr { // Return a *net.UDPAddr because lots of code assumes that it will. return new(net.UDPAddr) diff --git a/wgengine/magicsock/magicsock_default.go b/wgengine/magicsock/magicsock_default.go index 8cdfd09e1..4dda3c8a6 100644 --- a/wgengine/magicsock/magicsock_default.go +++ b/wgengine/magicsock/magicsock_default.go @@ -20,3 +20,21 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { portableTrySetSocketBuffer(pconn, logf) } + +func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { + return false, false +} + +func getGSOSizeFromControl(control []byte) (int, error) { + return 0, nil +} + +func setGSOSizeInControl(control *[]byte, gso uint16) {} + +func errShouldDisableOffload(err error) bool { + return false +} + +const ( + controlMessageSize = 0 +) diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go index 024fac8a5..cdfbeb759 100644 --- a/wgengine/magicsock/magicsock_linux.go +++ b/wgengine/magicsock/magicsock_linux.go @@ -258,7 +258,7 @@ func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) { metricRecvDiscoPacketIPv6.Add(1) } - c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}) + c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}, discoRXPathRawSocket) } } @@ -317,3 +317,90 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { } } } + +const ( + // TODO(jwhited): upstream to unix? + socketOptionLevelUDP = 17 + socketOptionUDPSegment = 103 + socketOptionUDPGRO = 104 +) + +// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn, +// and returns two booleans indicating TX and RX UDP offload support. +func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { + if c, ok := pconn.(*net.UDPConn); ok { + rc, err := c.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPSegment) + if errSyscall != nil { + // no point in checking RX, TX support was added first. + return + } + hasTX = true + errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO, 1) + hasRX = errSyscall == nil + }) + if err != nil { + return false, false + } + } + return hasTX, hasRX +} + +// getGSOSizeFromControl returns the GSO size found in control. If no GSO size +// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0. +// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but +// its contents cannot be parsed as a socket control message. +func getGSOSizeFromControl(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == socketOptionLevelUDP && hdr.Type == socketOptionUDPGRO && len(data) >= 2 { + var gso uint16 + // TODO(jwhited): replace with encoding/binary.NativeEndian when it's available + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), 2), data[:2]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSizeInControl sets a socket control message in control containing +// gsoSize. If len(control) < controlMessageSize control's len will be set to 0. +func setGSOSizeInControl(control *[]byte, gsoSize uint16) { + *control = (*control)[:0] + if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { + return + } + if cap(*control) < controlMessageSize { + return + } + *control = (*control)[:cap(*control)] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) + hdr.Level = socketOptionLevelUDP + hdr.Type = socketOptionUDPSegment + hdr.SetLen(unix.CmsgLen(2)) + // TODO(jwhited): replace with encoding/binary.NativeEndian when it's available + copy((*control)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), 2)) + *control = (*control)[:unix.CmsgSpace(2)] +} + +var controlMessageSize = -1 // bomb if used for allocation before init + +func init() { + // controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control + // message. These contain a single uint16 of data. + controlMessageSize = unix.CmsgSpace(2) +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 18851d28d..45fac2b98 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -31,6 +31,7 @@ import ( "github.com/tailscale/wireguard-go/tun/tuntest" "go4.org/mem" "golang.org/x/exp/maps" + "golang.org/x/net/ipv6" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/derp" "tailscale.com/derp/derphttp" @@ -1156,7 +1157,7 @@ func TestDiscoMessage(t *testing.T) { box := peer1Priv.Shared(c.discoPrivate.Public()).Seal([]byte(payload)) pkt = append(pkt, box...) - got := c.handleDiscoMessage(pkt, netip.AddrPort{}, key.NodePublic{}) + got := c.handleDiscoMessage(pkt, netip.AddrPort{}, key.NodePublic{}, discoRXPathUDP) if !got { t.Error("failed to open it") } @@ -1832,8 +1833,8 @@ func TestRebindingUDPConn(t *testing.T) { t.Fatal(err) } defer realConn.Close() - c.setConnLocked(realConn.(nettype.PacketConn), "udp4") - c.setConnLocked(newBlockForeverConn(), "") + c.setConnLocked(realConn.(nettype.PacketConn), "udp4", 1) + c.setConnLocked(newBlockForeverConn(), "", 1) } // https://github.com/tailscale/tailscale/issues/6680: don't ignore @@ -1861,3 +1862,235 @@ func TestBufferedDerpWritesBeforeDrop(t *testing.T) { } t.Logf("bufferedDerpWritesBeforeDrop = %d", vv) } + +func setGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func getGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_batchingUDPConn_splitCoalescedMessages(t *testing.T) { + c := &batchingUDPConn{ + setGSOSizeInControl: setGSOSize, + getGSOSizeFromControl: getGSOSize, + } + + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1024)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := c.splitCoalescedMessages(tt.msgs, 2) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} + +func Test_batchingUDPConn_coalesceMessages(t *testing.T) { + c := &batchingUDPConn{ + setGSOSizeInControl: setGSOSize, + getGSOSizeFromControl: getGSOSize, + } + + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := c.coalesceMessages(addr, tt.buffs, msgs) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := getGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +}