diff --git a/net/neterror/neterror.go b/net/neterror/neterror.go index c2d3269d5..e2387440d 100644 --- a/net/neterror/neterror.go +++ b/net/neterror/neterror.go @@ -6,6 +6,7 @@ package neterror import ( "errors" + "fmt" "runtime" "syscall" ) @@ -57,3 +58,25 @@ func PacketWasTruncated(err error) bool { } return packetWasTruncated(err) } + +var shouldDisableUDPGSO func(error) bool // non-nil on Linux + +func ShouldDisableUDPGSO(err error) bool { + if shouldDisableUDPGSO == nil { + return false + } + return shouldDisableUDPGSO(err) +} + +type ErrUDPGSODisabled struct { + OnLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go new file mode 100644 index 000000000..857367fe8 --- /dev/null +++ b/net/neterror/neterror_linux.go @@ -0,0 +1,26 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func init() { + shouldDisableUDPGSO = func(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not + // have tx checksumming enabled, which is a hard requirement of + // UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index d33cbd244..25ebcc97e 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -44,7 +44,6 @@ import ( "tailscale.com/net/connstats" "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" - "tailscale.com/net/netaddr" "tailscale.com/net/netcheck" "tailscale.com/net/neterror" "tailscale.com/net/netns" @@ -281,7 +280,6 @@ type Conn struct { pconn6 RebindingUDPConn receiveBatchPool sync.Pool - sendBatchPool sync.Pool // closeDisco4 and closeDisco6 are io.Closers to shut down the raw // disco packet receivers. If nil, no raw disco receiver is @@ -597,26 +595,13 @@ func newConn() *Conn { msgs := make([]ipv6.Message, c.bind.BatchSize()) for i := range msgs { msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, controlMessageSize) } batch := &receiveBatch{ msgs: msgs, } return batch }} - c.sendBatchPool = sync.Pool{New: func() any { - ua := &net.UDPAddr{ - IP: make([]byte, 16), - } - msgs := make([]ipv6.Message, c.bind.BatchSize()) - for i := range msgs { - msgs[i].Buffers = make([][]byte, 1) - msgs[i].Addr = ua - } - return &sendBatch{ - ua: ua, - msgs: msgs, - } - }} c.muCond = sync.NewCond(&c.mu) c.networkUp.Store(true) // assume up until told otherwise return c @@ -1301,19 +1286,11 @@ var errNoUDP = errors.New("no UDP available on platform") var ( // This acts as a compile-time check for our usage of ipv6.Message in - // udpConnWithBatchOps for both IPv6 and IPv4 operations. + // batchingUDPConn for both IPv6 and IPv4 operations. _ ipv6.Message = ipv4.Message{} ) -type sendBatch struct { - ua *net.UDPAddr - msgs []ipv6.Message // ipv4.Message and ipv6.Message are the same underlying type -} - func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) { - batch := c.sendBatchPool.Get().(*sendBatch) - defer c.sendBatchPool.Put(batch) - isIPv6 := false switch { case addr.Addr().Is4(): @@ -1322,19 +1299,17 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err default: panic("bogus sendUDPBatch addr type") } - - as16 := addr.Addr().As16() - copy(batch.ua.IP, as16[:]) - batch.ua.Port = int(addr.Port()) - for i, buff := range buffs { - batch.msgs[i].Buffers[0] = buff - batch.msgs[i].Addr = batch.ua - } - if isIPv6 { - _, err = c.pconn6.WriteBatch(batch.msgs[:len(buffs)], 0) + err = c.pconn6.WriteBatchTo(buffs, addr) } else { - _, err = c.pconn4.WriteBatch(batch.msgs[:len(buffs)], 0) + err = c.pconn4.WriteBatchTo(buffs, addr) + } + if err != nil { + var errGSO neterror.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + c.logf("magicsock: %s", errGSO.Error()) + err = errGSO.RetryErr + } } return err == nil, err } @@ -1844,14 +1819,18 @@ type receiveBatch struct { msgs []ipv6.Message } -func (c *Conn) getReceiveBatch() *receiveBatch { +func (c *Conn) getReceiveBatchForBuffs(buffs [][]byte) *receiveBatch { batch := c.receiveBatchPool.Get().(*receiveBatch) + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + batch.msgs[i].OOB = batch.msgs[i].OOB[:cap(batch.msgs[i].OOB)] + } return batch } func (c *Conn) putReceiveBatch(batch *receiveBatch) { for i := range batch.msgs { - batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers} + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} } c.receiveBatchPool.Put(batch) } @@ -1860,13 +1839,10 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in health.ReceiveIPv6.Enter() defer health.ReceiveIPv6.Exit() - batch := c.getReceiveBatch() + batch := c.getReceiveBatchForBuffs(buffs) defer c.putReceiveBatch(batch) for { - for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] - } - numMsgs, err := c.pconn6.ReadBatch(batch.msgs, 0) + numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0) if err != nil { if neterror.PacketWasTruncated(err) { // TODO(raggi): discuss whether to log? @@ -1877,6 +1853,10 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in reportToCaller := false for i, msg := range batch.msgs[:numMsgs] { + if msg.N == 0 { + sizes[i] = 0 + continue + } ipp := msg.Addr.(*net.UDPAddr).AddrPort() if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok { metricRecvDataIPv6.Add(1) @@ -1898,13 +1878,10 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in health.ReceiveIPv4.Enter() defer health.ReceiveIPv4.Exit() - batch := c.getReceiveBatch() + batch := c.getReceiveBatchForBuffs(buffs) defer c.putReceiveBatch(batch) for { - for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] - } - numMsgs, err := c.pconn4.ReadBatch(batch.msgs, 0) + numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0) if err != nil { if neterror.PacketWasTruncated(err) { // TODO(raggi): discuss whether to log? @@ -1915,6 +1892,10 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in reportToCaller := false for i, msg := range batch.msgs[:numMsgs] { + if msg.N == 0 { + sizes[i] = 0 + continue + } ipp := msg.Addr.(*net.UDPAddr).AddrPort() if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok { metricRecvDataIPv4.Add(1) @@ -1940,7 +1921,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) c.stunReceiveFunc.Load()(b, ipp) return nil, false } - if c.handleDiscoMessage(b, ipp, key.NodePublic{}) { + if c.handleDiscoMessage(b, ipp, key.NodePublic{}, discoRXPathUDP) { return nil, false } if !c.havePrivateKey.Load() { @@ -2005,7 +1986,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en } ipp := netip.AddrPortFrom(derpMagicIPAddr, uint16(regionID)) - if c.handleDiscoMessage(b[:n], ipp, dm.src) { + if c.handleDiscoMessage(b[:n], ipp, dm.src, discoRXPathDERP) { return 0, nil } @@ -2139,6 +2120,14 @@ func discoPcapFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []by return b.Bytes() } +type discoRXPath string + +const ( + discoRXPathUDP discoRXPath = "UDP socket" + discoRXPathDERP discoRXPath = "DERP" + discoRXPathRawSocket discoRXPath = "raw socket" +) + // handleDiscoMessage handles a discovery message and reports whether // msg was a Tailscale inter-node discovery message. // @@ -2153,7 +2142,7 @@ func discoPcapFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []by // src.Port() being the region ID) and the derpNodeSrc will be the node key // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. -func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic) (isDiscoMsg bool) { +func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) { const headerLen = len(disco.Magic) + key.DiscoPublicRawLen if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { return false @@ -2174,7 +2163,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } if debugDisco() { - c.logf("magicsock: disco: got disco-looking frame from %v", sender.ShortString()) + c.logf("magicsock: disco: got disco-looking frame from %v via %s", sender.ShortString(), via) } if c.privateKey.IsZero() { // Ignore disco messages when we're stopped. @@ -2210,12 +2199,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // disco key. When we restart we get a new disco key // and old packets might've still been in flight (or // scheduled). This is particularly the case for LANs - // or non-NATed endpoints. - // Don't log in normal case. Pass on to wireguard, in case - // it's actually a wireguard packet (super unlikely, - // but). + // or non-NATed endpoints. UDP offloading on Linux + // can also cause this when a disco message is + // received via raw socket at the head of a coalesced + // group of messages. Don't log in normal case. + // Callers may choose to pass on to wireguard, in case + // it's actually a wireguard packet (super unlikely, but). if debugDisco() { - c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?)", sender) + c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?) via %s", sender, via) } metricRecvDiscoBadKey.Add(1) return @@ -3205,13 +3196,13 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur defer ruc.mu.Unlock() if runtime.GOOS == "js" { - ruc.setConnLocked(newBlockForeverConn(), "") + ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) return nil } if debugAlwaysDERP() { c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network) - ruc.setConnLocked(newBlockForeverConn(), "") + ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) return nil } @@ -3253,7 +3244,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur if debugBindSocket() { c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port) } - ruc.setConnLocked(pconn, network) + ruc.setConnLocked(pconn, network, c.bind.BatchSize()) if network == "udp4" { health.SetUDP4Unbound(false) } @@ -3264,7 +3255,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur // Set pconn to a dummy conn whose reads block until closed. // This keeps the receive funcs alive for a future in which // we get a link change and we can try binding again. - ruc.setConnLocked(newBlockForeverConn(), "") + ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) if network == "udp4" { health.SetUDP4Unbound(true) } @@ -3361,49 +3352,332 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) { return ep, nil } -type batchReaderWriter interface { - batchReader - batchWriter +// xnetBatchReaderWriter defines the batching i/o methods of +// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn). +// TODO(jwhited): This should eventually be replaced with the standard library +// implementation of https://github.com/golang/go/issues/45886 +type xnetBatchReaderWriter interface { + xnetBatchReader + xnetBatchWriter +} + +type xnetBatchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) } -type batchWriter interface { +type xnetBatchWriter interface { WriteBatch([]ipv6.Message, int) (int, error) } -type batchReader interface { - ReadBatch([]ipv6.Message, int) (int, error) +// batchingUDPConn is a UDP socket that provides batched i/o. +type batchingUDPConn struct { + pc nettype.PacketConn + xpc xnetBatchReaderWriter + rxOffload bool // supports UDP GRO or similar + txOffload atomic.Bool // supports UDP GSO or similar + setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing + getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing + sendBatchPool sync.Pool } -// udpConnWithBatchOps wraps a *net.UDPConn in order to extend it to support -// batch operations. -// -// TODO(jwhited): This wrapping is temporary. https://github.com/golang/go/issues/45886 -type udpConnWithBatchOps struct { - *net.UDPConn - xpc batchReaderWriter +func (c *batchingUDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.rxOffload { + // UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may + // receive a "monster datagram" from any read call. The ReadFrom() API + // does not support passing the GSO size and is unsafe to use in such a + // case. Other platforms may vary in behavior, but we go with the most + // conservative approach to prevent this from becoming a footgun in the + // future. + return 0, nil, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable") + } + return c.pc.ReadFrom(p) +} + +func (c *batchingUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.pc.WriteTo(b, addr) +} + +func (c *batchingUDPConn) SetDeadline(t time.Time) error { + return c.pc.SetDeadline(t) +} + +func (c *batchingUDPConn) SetReadDeadline(t time.Time) error { + return c.pc.SetReadDeadline(t) +} + +func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error { + return c.pc.SetWriteDeadline(t) +} + +const ( + // This was initially established for Linux, but may split out to + // GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the + // kernel's TX path, and UDP_GRO_CNT_MAX for RX. + udpSegmentMaxDatagrams = 64 +) + +const ( + // Exceeding these values results in EMSGSIZE. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 +) + +// coalesceMessages iterates msgs, coalescing them where possible while +// maintaining datagram order. All msgs have their Addr field set to addr. +func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int { + var ( + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of buffs + ) + maxPayloadLen := maxIPv4PayloadLen + if addr.IP.To4() == nil { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buff := range buffs { + if i > 0 { + msgLen := len(buff) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...) + copy(msgs[base].Buffers[0][baseLenBefore:], buff) + if i == len(buffs)-1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue + } + } + if dgramCnt > 1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buff) + msgs[base].OOB = msgs[base].OOB[:0] + msgs[base].Buffers[0] = buff + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type sendBatch struct { + msgs []ipv6.Message + ua *net.UDPAddr +} + +func (c *batchingUDPConn) getSendBatch() *sendBatch { + batch := c.sendBatchPool.Get().(*sendBatch) + return batch +} + +func (c *batchingUDPConn) putSendBatch(batch *sendBatch) { + for i := range batch.msgs { + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} + } + c.sendBatchPool.Put(batch) +} + +func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { + batch := c.getSendBatch() + defer c.putSendBatch(batch) + if addr.Addr().Is6() { + as16 := addr.Addr().As16() + copy(batch.ua.IP, as16[:]) + batch.ua.IP = batch.ua.IP[:16] + } else { + as4 := addr.Addr().As4() + copy(batch.ua.IP, as4[:]) + batch.ua.IP = batch.ua.IP[:4] + } + batch.ua.Port = int(addr.Port()) + var ( + n int + retried bool + ) +retry: + if c.txOffload.Load() { + n = c.coalesceMessages(batch.ua, buffs, batch.msgs) + } else { + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + batch.msgs[i].Addr = batch.ua + batch.msgs[i].OOB = batch.msgs[i].OOB[:0] + } + n = len(buffs) + } + + err := c.writeBatch(batch.msgs[:n]) + if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) { + c.txOffload.Store(false) + retried = true + goto retry + } + if retried { + return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err} + } + return err +} + +func (c *batchingUDPConn) writeBatch(msgs []ipv6.Message) error { + var head int + for { + n, err := c.xpc.WriteBatch(msgs[head:], 0) + if err != nil || n == len(msgs[head:]) { + // Returning the number of packets written would require + // unraveling individual msg len and gso size during a coalesced + // write. The top of the call stack disregards partial success, + // so keep this simple for now. + return err + } + head += n + } +} + +// splitCoalescedMessages splits coalesced messages from the tail of dst +// beginning at index 'firstMsgAt' into the head of the same slice. It reports +// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An +// error is returned if a socket control message cannot be parsed or a split +// operation would overflow msgs. +func (c *batchingUDPConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") + } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 + } + } + return n, nil } -func newUDPConnWithBatchOps(conn *net.UDPConn, network string) udpConnWithBatchOps { - ucbo := udpConnWithBatchOps{ - UDPConn: conn, +func (c *batchingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) { + if !c.rxOffload || len(msgs) < 2 { + return c.xpc.ReadBatch(msgs, flags) + } + // Read into the tail of msgs, split into the head. + readAt := len(msgs) - 2 + numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0) + if err != nil || numRead == 0 { + return 0, err + } + return c.splitCoalescedMessages(msgs, readAt) +} + +func (c *batchingUDPConn) LocalAddr() net.Addr { + return c.pc.LocalAddr().(*net.UDPAddr) +} + +func (c *batchingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + return c.pc.WriteToUDPAddrPort(b, addr) +} + +func (c *batchingUDPConn) Close() error { + return c.pc.Close() +} + +// tryUpgradeToBatchingUDPConn probes the capabilities of the OS and pconn, and +// upgrades pconn to a *batchingUDPConn if appropriate. +func tryUpgradeToBatchingUDPConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { + if network != "udp4" && network != "udp6" { + return pconn + } + if runtime.GOOS != "linux" { + return pconn + } + if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") { + // recvmmsg/sendmmsg were added in 2.6.33, but we support down to + // 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807. + // As a cheap heuristic: if the Linux kernel starts with "2", just + // consider it too old for mmsg. Nobody who cares about performance runs + // such ancient kernels. UDP offload was added much later, so no + // upgrades are available. + return pconn + } + uc, ok := pconn.(*net.UDPConn) + if !ok { + return pconn + } + b := &batchingUDPConn{ + pc: pconn, + getGSOSizeFromControl: getGSOSizeFromControl, + setGSOSizeInControl: setGSOSizeInControl, + sendBatchPool: sync.Pool{ + New: func() any { + ua := &net.UDPAddr{ + IP: make([]byte, 16), + } + msgs := make([]ipv6.Message, batchSize) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + msgs[i].OOB = make([]byte, controlMessageSize) + } + return &sendBatch{ + ua: ua, + msgs: msgs, + } + }, + }, } switch network { case "udp4": - ucbo.xpc = ipv4.NewPacketConn(conn) + b.xpc = ipv4.NewPacketConn(uc) case "udp6": - ucbo.xpc = ipv6.NewPacketConn(conn) + b.xpc = ipv6.NewPacketConn(uc) default: panic("bogus network") } - return ucbo -} - -func (u udpConnWithBatchOps) WriteBatch(ms []ipv6.Message, flags int) (int, error) { - return u.xpc.WriteBatch(ms, flags) -} - -func (u udpConnWithBatchOps) ReadBatch(ms []ipv6.Message, flags int) (int, error) { - return u.xpc.ReadBatch(ms, flags) + var txOffload bool + txOffload, b.rxOffload = tryEnableUDPOffload(uc) + b.txOffload.Store(txOffload) + return b } // RebindingUDPConn is a UDP socket that can be re-bound. @@ -3423,34 +3697,14 @@ type RebindingUDPConn struct { port uint16 } -// upgradePacketConn may upgrade a nettype.PacketConn to a udpConnWithBatchOps. -func upgradePacketConn(p nettype.PacketConn, network string) nettype.PacketConn { - uc, ok := p.(*net.UDPConn) - if ok && runtime.GOOS == "linux" && (network == "udp4" || network == "udp6") { - // recvmmsg/sendmmsg were added in 2.6.33 but we support down to 2.6.32 - // for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807. - // As a cheap heuristic: if the Linux kernel starts with "2", just consider - // it too old for the fast paths. Nobody who cares about performance runs such - // ancient kernels. - if strings.HasPrefix(hostinfo.GetOSVersion(), "2") { - return p - } - // Non-Linux does not support batch operations. x/net will fall back to - // recv/sendmsg, but not all platforms have recv/sendmsg support. Keep - // this simple for now. - return newUDPConnWithBatchOps(uc, network) - } - return p -} - // setConnLocked sets the provided nettype.PacketConn. It should be called only // after acquiring RebindingUDPConn.mu. It upgrades the provided -// nettype.PacketConn to a udpConnWithBatchOps when appropriate. This upgrade +// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade // is intentionally pushed closest to where read/write ops occur in order to // avoid disrupting surrounding code that assumes nettype.PacketConn is a // *net.UDPConn. -func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string) { - upc := upgradePacketConn(p, network) +func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { + upc := tryUpgradeToBatchingUDPConn(p, network, batchSize) c.pconn = upc c.pconnAtomic.Store(&upc) c.port = uint16(c.localAddrLocked().Port) @@ -3480,83 +3734,38 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b) } -// ReadFromNetaddr reads a packet from c into b. -// It returns the number of bytes copied and the return address. -// It is identical to c.ReadFrom, except that it returns a netip.AddrPort instead of a net.Addr. -// ReadFromNetaddr is designed to work with specific underlying connection types. -// If c's underlying connection returns a non-*net.UPDAddr return address, ReadFromNetaddr will return an error. -// ReadFromNetaddr exists because it removes an allocation per read, -// when c's underlying connection is a net.UDPConn. -func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) { +// WriteBatchTo writes buffs to addr. +func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { for { pconn := *c.pconnAtomic.Load() - - // Optimization: Treat *net.UDPConn specially. - // This lets us avoid allocations by calling ReadFromUDPAddrPort. - // The non-*net.UDPConn case works, but it allocates. - if udpConn, ok := pconn.(*net.UDPConn); ok { - n, ipp, err = udpConn.ReadFromUDPAddrPort(b) - } else { - var addr net.Addr - n, addr, err = pconn.ReadFrom(b) - pAddr, ok := addr.(*net.UDPAddr) - if addr != nil && !ok { - return 0, netip.AddrPort{}, fmt.Errorf("RebindingUDPConn.ReadFromNetaddr: underlying connection returned address of type %T, want *netaddr.UDPAddr", addr) - } - if pAddr != nil { - ipp = netaddr.Unmap(pAddr.AddrPort()) - if !ipp.IsValid() { - return 0, netip.AddrPort{}, errors.New("netaddr.FromStdAddr failed") - } - } - } - - if err != nil && pconn != c.currentConn() { - // The connection changed underfoot. Try again. - continue - } - return n, ipp, err - } -} - -func (c *RebindingUDPConn) WriteBatch(msgs []ipv6.Message, flags int) (int, error) { - var ( - n int - err error - start int - ) - for { - pconn := *c.pconnAtomic.Load() - bw, ok := pconn.(batchWriter) + b, ok := pconn.(*batchingUDPConn) if !ok { - for _, msg := range msgs { - _, err = c.writeToWithInitPconn(pconn, msg.Buffers[0], msg.Addr) + for _, buf := range buffs { + _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr) if err != nil { - return n, err + return err } - n++ } - return n, nil + return nil } - - n, err = bw.WriteBatch(msgs[start:], flags) + err := b.WriteBatchTo(buffs, addr) if err != nil { if pconn != c.currentConn() { continue } - return n, err - } else if n == len(msgs[start:]) { - return len(msgs), nil - } else { - start += n + return err } + return err } } +// ReadBatch reads messages from c into msgs. It returns the number of messages +// the caller should evaluate for nonzero len, as a zero len message may fall +// on either side of a nonzero. func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) { for { pconn := *c.pconnAtomic.Load() - br, ok := pconn.(batchReader) + b, ok := pconn.(*batchingUDPConn) if !ok { var err error msgs[0].N, msgs[0].Addr, err = c.readFromWithInitPconn(pconn, msgs[0].Buffers[0]) @@ -3565,7 +3774,7 @@ func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error } return 0, err } - n, err := br.ReadBatch(msgs, flags) + n, err := b.ReadBatch(msgs, flags) if err != nil && pconn != c.currentConn() { continue } @@ -3607,9 +3816,9 @@ func (c *RebindingUDPConn) closeLocked() error { return c.pconn.Close() } -func (c *RebindingUDPConn) writeToWithInitPconn(pconn nettype.PacketConn, b []byte, addr net.Addr) (int, error) { +func (c *RebindingUDPConn) writeToUDPAddrPortWithInitPconn(pconn nettype.PacketConn, b []byte, addr netip.AddrPort) (int, error) { for { - n, err := pconn.WriteTo(b, addr) + n, err := pconn.WriteToUDPAddrPort(b, addr) if err != nil && pconn != c.currentConn() { pconn = *c.pconnAtomic.Load() continue @@ -3619,13 +3828,9 @@ func (c *RebindingUDPConn) writeToWithInitPconn(pconn nettype.PacketConn, b []by } func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { - return c.writeToWithInitPconn(*c.pconnAtomic.Load(), b, addr) -} - -func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { for { pconn := *c.pconnAtomic.Load() - n, err := pconn.WriteToUDPAddrPort(b, addr) + n, err := pconn.WriteTo(b, addr) if err != nil && pconn != c.currentConn() { continue } @@ -3633,6 +3838,10 @@ func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (in } } +func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + return c.writeToUDPAddrPortWithInitPconn(*c.pconnAtomic.Load(), b, addr) +} + func newBlockForeverConn() *blockForeverConn { c := new(blockForeverConn) c.cond = sync.NewCond(&c.mu) @@ -3665,20 +3874,6 @@ func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (in return len(p), nil } -func (c *blockForeverConn) ReadBatch(p []ipv6.Message, flags int) (int, error) { - c.mu.Lock() - for !c.closed { - c.cond.Wait() - } - c.mu.Unlock() - return 0, net.ErrClosed -} - -func (c *blockForeverConn) WriteBatch(p []ipv6.Message, flags int) (int, error) { - // Silently drop writes. - return len(p), nil -} - func (c *blockForeverConn) LocalAddr() net.Addr { // Return a *net.UDPAddr because lots of code assumes that it will. return new(net.UDPAddr) diff --git a/wgengine/magicsock/magicsock_default.go b/wgengine/magicsock/magicsock_default.go index 8cdfd09e1..4dda3c8a6 100644 --- a/wgengine/magicsock/magicsock_default.go +++ b/wgengine/magicsock/magicsock_default.go @@ -20,3 +20,21 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) { func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { portableTrySetSocketBuffer(pconn, logf) } + +func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { + return false, false +} + +func getGSOSizeFromControl(control []byte) (int, error) { + return 0, nil +} + +func setGSOSizeInControl(control *[]byte, gso uint16) {} + +func errShouldDisableOffload(err error) bool { + return false +} + +const ( + controlMessageSize = 0 +) diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go index 024fac8a5..cdfbeb759 100644 --- a/wgengine/magicsock/magicsock_linux.go +++ b/wgengine/magicsock/magicsock_linux.go @@ -258,7 +258,7 @@ func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) { metricRecvDiscoPacketIPv6.Add(1) } - c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}) + c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}, discoRXPathRawSocket) } } @@ -317,3 +317,90 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { } } } + +const ( + // TODO(jwhited): upstream to unix? + socketOptionLevelUDP = 17 + socketOptionUDPSegment = 103 + socketOptionUDPGRO = 104 +) + +// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn, +// and returns two booleans indicating TX and RX UDP offload support. +func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { + if c, ok := pconn.(*net.UDPConn); ok { + rc, err := c.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPSegment) + if errSyscall != nil { + // no point in checking RX, TX support was added first. + return + } + hasTX = true + errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO, 1) + hasRX = errSyscall == nil + }) + if err != nil { + return false, false + } + } + return hasTX, hasRX +} + +// getGSOSizeFromControl returns the GSO size found in control. If no GSO size +// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0. +// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but +// its contents cannot be parsed as a socket control message. +func getGSOSizeFromControl(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == socketOptionLevelUDP && hdr.Type == socketOptionUDPGRO && len(data) >= 2 { + var gso uint16 + // TODO(jwhited): replace with encoding/binary.NativeEndian when it's available + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), 2), data[:2]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSizeInControl sets a socket control message in control containing +// gsoSize. If len(control) < controlMessageSize control's len will be set to 0. +func setGSOSizeInControl(control *[]byte, gsoSize uint16) { + *control = (*control)[:0] + if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { + return + } + if cap(*control) < controlMessageSize { + return + } + *control = (*control)[:cap(*control)] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) + hdr.Level = socketOptionLevelUDP + hdr.Type = socketOptionUDPSegment + hdr.SetLen(unix.CmsgLen(2)) + // TODO(jwhited): replace with encoding/binary.NativeEndian when it's available + copy((*control)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), 2)) + *control = (*control)[:unix.CmsgSpace(2)] +} + +var controlMessageSize = -1 // bomb if used for allocation before init + +func init() { + // controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control + // message. These contain a single uint16 of data. + controlMessageSize = unix.CmsgSpace(2) +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 18851d28d..45fac2b98 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -31,6 +31,7 @@ import ( "github.com/tailscale/wireguard-go/tun/tuntest" "go4.org/mem" "golang.org/x/exp/maps" + "golang.org/x/net/ipv6" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/derp" "tailscale.com/derp/derphttp" @@ -1156,7 +1157,7 @@ func TestDiscoMessage(t *testing.T) { box := peer1Priv.Shared(c.discoPrivate.Public()).Seal([]byte(payload)) pkt = append(pkt, box...) - got := c.handleDiscoMessage(pkt, netip.AddrPort{}, key.NodePublic{}) + got := c.handleDiscoMessage(pkt, netip.AddrPort{}, key.NodePublic{}, discoRXPathUDP) if !got { t.Error("failed to open it") } @@ -1832,8 +1833,8 @@ func TestRebindingUDPConn(t *testing.T) { t.Fatal(err) } defer realConn.Close() - c.setConnLocked(realConn.(nettype.PacketConn), "udp4") - c.setConnLocked(newBlockForeverConn(), "") + c.setConnLocked(realConn.(nettype.PacketConn), "udp4", 1) + c.setConnLocked(newBlockForeverConn(), "", 1) } // https://github.com/tailscale/tailscale/issues/6680: don't ignore @@ -1861,3 +1862,235 @@ func TestBufferedDerpWritesBeforeDrop(t *testing.T) { } t.Logf("bufferedDerpWritesBeforeDrop = %d", vv) } + +func setGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func getGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_batchingUDPConn_splitCoalescedMessages(t *testing.T) { + c := &batchingUDPConn{ + setGSOSizeInControl: setGSOSize, + getGSOSizeFromControl: getGSOSize, + } + + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1024)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := c.splitCoalescedMessages(tt.msgs, 2) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} + +func Test_batchingUDPConn_coalesceMessages(t *testing.T) { + c := &batchingUDPConn{ + setGSOSizeInControl: setGSOSize, + getGSOSizeFromControl: getGSOSize, + } + + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := c.coalesceMessages(addr, tt.buffs, msgs) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := getGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +}