diff --git a/wgengine/magicsock/batching_conn.go b/wgengine/magicsock/batching_conn.go index 242f31c37..5320d1caf 100644 --- a/wgengine/magicsock/batching_conn.go +++ b/wgengine/magicsock/batching_conn.go @@ -4,200 +4,22 @@ package magicsock import ( - "errors" - "net" "net/netip" - "sync" - "sync/atomic" - "syscall" - "time" + "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "tailscale.com/net/neterror" "tailscale.com/types/nettype" ) -// xnetBatchReaderWriter defines the batching i/o methods of -// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn). -// TODO(jwhited): This should eventually be replaced with the standard library -// implementation of https://github.com/golang/go/issues/45886 -type xnetBatchReaderWriter interface { - xnetBatchReader - xnetBatchWriter -} - -type xnetBatchReader interface { - ReadBatch([]ipv6.Message, int) (int, error) -} - -type xnetBatchWriter interface { - WriteBatch([]ipv6.Message, int) (int, error) -} - -// batchingUDPConn is a UDP socket that provides batched i/o. -type batchingUDPConn struct { - pc nettype.PacketConn - xpc xnetBatchReaderWriter - rxOffload bool // supports UDP GRO or similar - txOffload atomic.Bool // supports UDP GSO or similar - setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing - getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing - sendBatchPool sync.Pool -} - -func (c *batchingUDPConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - if c.rxOffload { - // UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may - // receive a "monster datagram" from any read call. The ReadFrom() API - // does not support passing the GSO size and is unsafe to use in such a - // case. Other platforms may vary in behavior, but we go with the most - // conservative approach to prevent this from becoming a footgun in the - // future. - return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable") - } - return c.pc.ReadFromUDPAddrPort(p) -} - -func (c *batchingUDPConn) SetDeadline(t time.Time) error { - return c.pc.SetDeadline(t) -} - -func (c *batchingUDPConn) SetReadDeadline(t time.Time) error { - return c.pc.SetReadDeadline(t) -} - -func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error { - return c.pc.SetWriteDeadline(t) -} - -const ( - // This was initially established for Linux, but may split out to - // GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the - // kernel's TX path, and UDP_GRO_CNT_MAX for RX. - udpSegmentMaxDatagrams = 64 -) - -const ( - // Exceeding these values results in EMSGSIZE. - maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 - maxIPv6PayloadLen = 1<<16 - 1 - 8 +var ( + // This acts as a compile-time check for our usage of ipv6.Message in + // batchingConn for both IPv6 and IPv4 operations. + _ ipv6.Message = ipv4.Message{} ) -// coalesceMessages iterates msgs, coalescing them where possible while -// maintaining datagram order. All msgs have their Addr field set to addr. 
-func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int { - var ( - base = -1 // index of msg we are currently coalescing into - gsoSize int // segmentation size of msgs[base] - dgramCnt int // number of dgrams coalesced into msgs[base] - endBatch bool // tracking flag to start a new batch on next iteration of buffs - ) - maxPayloadLen := maxIPv4PayloadLen - if addr.IP.To4() == nil { - maxPayloadLen = maxIPv6PayloadLen - } - for i, buff := range buffs { - if i > 0 { - msgLen := len(buff) - baseLenBefore := len(msgs[base].Buffers[0]) - freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore - if msgLen+baseLenBefore <= maxPayloadLen && - msgLen <= gsoSize && - msgLen <= freeBaseCap && - dgramCnt < udpSegmentMaxDatagrams && - !endBatch { - msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...) - copy(msgs[base].Buffers[0][baseLenBefore:], buff) - if i == len(buffs)-1 { - c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) - } - dgramCnt++ - if msgLen < gsoSize { - // A smaller than gsoSize packet on the tail is legal, but - // it must end the batch. - endBatch = true - } - continue - } - } - if dgramCnt > 1 { - c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) - } - // Reset prior to incrementing base since we are preparing to start a - // new potential batch. - endBatch = false - base++ - gsoSize = len(buff) - msgs[base].OOB = msgs[base].OOB[:0] - msgs[base].Buffers[0] = buff - msgs[base].Addr = addr - dgramCnt = 1 - } - return base + 1 -} - -type sendBatch struct { - msgs []ipv6.Message - ua *net.UDPAddr -} - -func (c *batchingUDPConn) getSendBatch() *sendBatch { - batch := c.sendBatchPool.Get().(*sendBatch) - return batch -} - -func (c *batchingUDPConn) putSendBatch(batch *sendBatch) { - for i := range batch.msgs { - batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} - } - c.sendBatchPool.Put(batch) -} - -func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { - batch := c.getSendBatch() - defer c.putSendBatch(batch) - if addr.Addr().Is6() { - as16 := addr.Addr().As16() - copy(batch.ua.IP, as16[:]) - batch.ua.IP = batch.ua.IP[:16] - } else { - as4 := addr.Addr().As4() - copy(batch.ua.IP, as4[:]) - batch.ua.IP = batch.ua.IP[:4] - } - batch.ua.Port = int(addr.Port()) - var ( - n int - retried bool - ) -retry: - if c.txOffload.Load() { - n = c.coalesceMessages(batch.ua, buffs, batch.msgs) - } else { - for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] - batch.msgs[i].Addr = batch.ua - batch.msgs[i].OOB = batch.msgs[i].OOB[:0] - } - n = len(buffs) - } - - err := c.writeBatch(batch.msgs[:n]) - if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) { - c.txOffload.Store(false) - retried = true - goto retry - } - if retried { - return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err} - } - return err -} - -func (c *batchingUDPConn) SyscallConn() (syscall.RawConn, error) { - sc, ok := c.pc.(syscall.Conn) - if !ok { - return nil, errUnsupportedConnType - } - return sc.SyscallConn() +// batchingConn is a nettype.PacketConn that provides batched i/o. 
+type batchingConn interface { + nettype.PacketConn + ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) + WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error } diff --git a/wgengine/magicsock/batching_conn_default.go b/wgengine/magicsock/batching_conn_default.go new file mode 100644 index 000000000..519cf8082 --- /dev/null +++ b/wgengine/magicsock/batching_conn_default.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package magicsock + +import ( + "tailscale.com/types/nettype" +) + +func tryUpgradeToBatchingConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn { + return pconn +} diff --git a/wgengine/magicsock/batching_conn_linux.go b/wgengine/magicsock/batching_conn_linux.go new file mode 100644 index 000000000..2b58256b2 --- /dev/null +++ b/wgengine/magicsock/batching_conn_linux.go @@ -0,0 +1,419 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + "unsafe" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" + "tailscale.com/hostinfo" + "tailscale.com/net/neterror" + "tailscale.com/types/nettype" +) + +// xnetBatchReaderWriter defines the batching i/o methods of +// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn). +// TODO(jwhited): This should eventually be replaced with the standard library +// implementation of https://github.com/golang/go/issues/45886 +type xnetBatchReaderWriter interface { + xnetBatchReader + xnetBatchWriter +} + +type xnetBatchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type xnetBatchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +// linuxBatchingConn is a UDP socket that provides batched i/o. It implements +// batchingConn. +type linuxBatchingConn struct { + pc nettype.PacketConn + xpc xnetBatchReaderWriter + rxOffload bool // supports UDP GRO or similar + txOffload atomic.Bool // supports UDP GSO or similar + setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing + getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing + sendBatchPool sync.Pool +} + +func (c *linuxBatchingConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { + if c.rxOffload { + // UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may + // receive a "monster datagram" from any read call. The ReadFrom() API + // does not support passing the GSO size and is unsafe to use in such a + // case. Other platforms may vary in behavior, but we go with the most + // conservative approach to prevent this from becoming a footgun in the + // future. + return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable") + } + return c.pc.ReadFromUDPAddrPort(p) +} + +func (c *linuxBatchingConn) SetDeadline(t time.Time) error { + return c.pc.SetDeadline(t) +} + +func (c *linuxBatchingConn) SetReadDeadline(t time.Time) error { + return c.pc.SetReadDeadline(t) +} + +func (c *linuxBatchingConn) SetWriteDeadline(t time.Time) error { + return c.pc.SetWriteDeadline(t) +} + +const ( + // This was initially established for Linux, but may split out to + // GOOS-specific values later. 
It originates as UDP_MAX_SEGMENTS in the + // kernel's TX path, and UDP_GRO_CNT_MAX for RX. + udpSegmentMaxDatagrams = 64 +) + +const ( + // Exceeding these values results in EMSGSIZE. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 +) + +// coalesceMessages iterates msgs, coalescing them where possible while +// maintaining datagram order. All msgs have their Addr field set to addr. +func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int { + var ( + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of buffs + ) + maxPayloadLen := maxIPv4PayloadLen + if addr.IP.To4() == nil { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buff := range buffs { + if i > 0 { + msgLen := len(buff) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...) + copy(msgs[base].Buffers[0][baseLenBefore:], buff) + if i == len(buffs)-1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue + } + } + if dgramCnt > 1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. 
+ endBatch = false + base++ + gsoSize = len(buff) + msgs[base].OOB = msgs[base].OOB[:0] + msgs[base].Buffers[0] = buff + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type sendBatch struct { + msgs []ipv6.Message + ua *net.UDPAddr +} + +func (c *linuxBatchingConn) getSendBatch() *sendBatch { + batch := c.sendBatchPool.Get().(*sendBatch) + return batch +} + +func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) { + for i := range batch.msgs { + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} + } + c.sendBatchPool.Put(batch) +} + +func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { + batch := c.getSendBatch() + defer c.putSendBatch(batch) + if addr.Addr().Is6() { + as16 := addr.Addr().As16() + copy(batch.ua.IP, as16[:]) + batch.ua.IP = batch.ua.IP[:16] + } else { + as4 := addr.Addr().As4() + copy(batch.ua.IP, as4[:]) + batch.ua.IP = batch.ua.IP[:4] + } + batch.ua.Port = int(addr.Port()) + var ( + n int + retried bool + ) +retry: + if c.txOffload.Load() { + n = c.coalesceMessages(batch.ua, buffs, batch.msgs) + } else { + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + batch.msgs[i].Addr = batch.ua + batch.msgs[i].OOB = batch.msgs[i].OOB[:0] + } + n = len(buffs) + } + + err := c.writeBatch(batch.msgs[:n]) + if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) { + c.txOffload.Store(false) + retried = true + goto retry + } + if retried { + return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err} + } + return err +} + +func (c *linuxBatchingConn) SyscallConn() (syscall.RawConn, error) { + sc, ok := c.pc.(syscall.Conn) + if !ok { + return nil, errUnsupportedConnType + } + return sc.SyscallConn() +} + +func (c *linuxBatchingConn) writeBatch(msgs []ipv6.Message) error { + var head int + for { + n, err := c.xpc.WriteBatch(msgs[head:], 0) + if err != nil || n == len(msgs[head:]) { + // Returning the number of packets written would require + // unraveling individual msg len and gso size during a coalesced + // write. The top of the call stack disregards partial success, + // so keep this simple for now. + return err + } + head += n + } +} + +// splitCoalescedMessages splits coalesced messages from the tail of dst +// beginning at index 'firstMsgAt' into the head of the same slice. It reports +// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An +// error is returned if a socket control message cannot be parsed or a split +// operation would overflow msgs. 
+func (c *linuxBatchingConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") + } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 + } + } + return n, nil +} + +func (c *linuxBatchingConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) { + if !c.rxOffload || len(msgs) < 2 { + return c.xpc.ReadBatch(msgs, flags) + } + // Read into the tail of msgs, split into the head. + readAt := len(msgs) - 2 + numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0) + if err != nil || numRead == 0 { + return 0, err + } + return c.splitCoalescedMessages(msgs, readAt) +} + +func (c *linuxBatchingConn) LocalAddr() net.Addr { + return c.pc.LocalAddr().(*net.UDPAddr) +} + +func (c *linuxBatchingConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + return c.pc.WriteToUDPAddrPort(b, addr) +} + +func (c *linuxBatchingConn) Close() error { + return c.pc.Close() +} + +// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn, +// and returns two booleans indicating TX and RX UDP offload support. +func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { + if c, ok := pconn.(*net.UDPConn); ok { + rc, err := c.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + hasTX = errSyscall == nil + errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + hasRX = errSyscall == nil + }) + if err != nil { + return false, false + } + } + return hasTX, hasRX +} + +// getGSOSizeFromControl returns the GSO size found in control. If no GSO size +// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0. +// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but +// its contents cannot be parsed as a socket control message. +func getGSOSizeFromControl(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 { + return int(binary.NativeEndian.Uint16(data[:2])), nil + } + } + return 0, nil +} + +// setGSOSizeInControl sets a socket control message in control containing +// gsoSize. If len(control) < controlMessageSize control's len will be set to 0. 
+func setGSOSizeInControl(control *[]byte, gsoSize uint16) { + *control = (*control)[:0] + if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { + return + } + if cap(*control) < controlMessageSize { + return + } + *control = (*control)[:cap(*control)] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(2)) + binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize) + *control = (*control)[:unix.CmsgSpace(2)] +} + +// tryUpgradeToBatchingConn probes the capabilities of the OS and pconn, and +// upgrades pconn to a *linuxBatchingConn if appropriate. +func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { + if network != "udp4" && network != "udp6" { + return pconn + } + if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") { + // recvmmsg/sendmmsg were added in 2.6.33, but we support down to + // 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807. + // As a cheap heuristic: if the Linux kernel starts with "2", just + // consider it too old for mmsg. Nobody who cares about performance runs + // such ancient kernels. UDP offload was added much later, so no + // upgrades are available. + return pconn + } + uc, ok := pconn.(*net.UDPConn) + if !ok { + return pconn + } + b := &linuxBatchingConn{ + pc: pconn, + getGSOSizeFromControl: getGSOSizeFromControl, + setGSOSizeInControl: setGSOSizeInControl, + sendBatchPool: sync.Pool{ + New: func() any { + ua := &net.UDPAddr{ + IP: make([]byte, 16), + } + msgs := make([]ipv6.Message, batchSize) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + msgs[i].OOB = make([]byte, controlMessageSize) + } + return &sendBatch{ + ua: ua, + msgs: msgs, + } + }, + }, + } + switch network { + case "udp4": + b.xpc = ipv4.NewPacketConn(uc) + case "udp6": + b.xpc = ipv6.NewPacketConn(uc) + default: + panic("bogus network") + } + var txOffload bool + txOffload, b.rxOffload = tryEnableUDPOffload(uc) + b.txOffload.Store(txOffload) + return b +} diff --git a/wgengine/magicsock/batching_conn_linux_test.go b/wgengine/magicsock/batching_conn_linux_test.go new file mode 100644 index 000000000..5c22bf1c7 --- /dev/null +++ b/wgengine/magicsock/batching_conn_linux_test.go @@ -0,0 +1,244 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) + +func setGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func getGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { + c := &linuxBatchingConn{ + setGSOSizeInControl: setGSOSize, + getGSOSizeFromControl: getGSOSize, + } + + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1024)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + 
wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := c.splitCoalescedMessages(tt.msgs, 2) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} + +func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { + c := &linuxBatchingConn{ + setGSOSizeInControl: setGSOSize, + getGSOSizeFromControl: getGSOSize, + } + + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := c.coalesceMessages(addr, tt.buffs, msgs) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := range got { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := getGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] 
getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 5ac53c771..1d0fa58c3 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -25,7 +25,6 @@ import ( "github.com/tailscale/wireguard-go/conn" "go4.org/mem" - "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "tailscale.com/control/controlknobs" @@ -1101,12 +1100,6 @@ var errNoUDP = errors.New("no UDP available on platform") var errUnsupportedConnType = errors.New("unsupported connection type") -var ( - // This acts as a compile-time check for our usage of ipv6.Message in - // batchingUDPConn for both IPv6 and IPv4 operations. - _ ipv6.Message = ipv4.Message{} -) - func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) { isIPv6 := false switch { @@ -2656,153 +2649,6 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) { return ep, nil } -func (c *batchingUDPConn) writeBatch(msgs []ipv6.Message) error { - var head int - for { - n, err := c.xpc.WriteBatch(msgs[head:], 0) - if err != nil || n == len(msgs[head:]) { - // Returning the number of packets written would require - // unraveling individual msg len and gso size during a coalesced - // write. The top of the call stack disregards partial success, - // so keep this simple for now. - return err - } - head += n - } -} - -// splitCoalescedMessages splits coalesced messages from the tail of dst -// beginning at index 'firstMsgAt' into the head of the same slice. It reports -// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An -// error is returned if a socket control message cannot be parsed or a split -// operation would overflow msgs. -func (c *batchingUDPConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) { - for i := firstMsgAt; i < len(msgs); i++ { - msg := &msgs[i] - if msg.N == 0 { - return n, err - } - var ( - gsoSize int - start int - end = msg.N - numToSplit = 1 - ) - gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN]) - if err != nil { - return n, err - } - if gsoSize > 0 { - numToSplit = (msg.N + gsoSize - 1) / gsoSize - end = gsoSize - } - for j := 0; j < numToSplit; j++ { - if n > i { - return n, errors.New("splitting coalesced packet resulted in overflow") - } - copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) - msgs[n].N = copied - msgs[n].Addr = msg.Addr - start = end - end += gsoSize - if end > msg.N { - end = msg.N - } - n++ - } - if i != n-1 { - // It is legal for bytes to move within msg.Buffers[0] as a result - // of splitting, so we only zero the source msg len when it is not - // the destination of the last split operation above. - msg.N = 0 - } - } - return n, nil -} - -func (c *batchingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) { - if !c.rxOffload || len(msgs) < 2 { - return c.xpc.ReadBatch(msgs, flags) - } - // Read into the tail of msgs, split into the head. 
- readAt := len(msgs) - 2 - numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0) - if err != nil || numRead == 0 { - return 0, err - } - return c.splitCoalescedMessages(msgs, readAt) -} - -func (c *batchingUDPConn) LocalAddr() net.Addr { - return c.pc.LocalAddr().(*net.UDPAddr) -} - -func (c *batchingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - return c.pc.WriteToUDPAddrPort(b, addr) -} - -func (c *batchingUDPConn) Close() error { - return c.pc.Close() -} - -// tryUpgradeToBatchingUDPConn probes the capabilities of the OS and pconn, and -// upgrades pconn to a *batchingUDPConn if appropriate. -func tryUpgradeToBatchingUDPConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { - if network != "udp4" && network != "udp6" { - return pconn - } - if runtime.GOOS != "linux" { - return pconn - } - if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") { - // recvmmsg/sendmmsg were added in 2.6.33, but we support down to - // 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807. - // As a cheap heuristic: if the Linux kernel starts with "2", just - // consider it too old for mmsg. Nobody who cares about performance runs - // such ancient kernels. UDP offload was added much later, so no - // upgrades are available. - return pconn - } - uc, ok := pconn.(*net.UDPConn) - if !ok { - return pconn - } - b := &batchingUDPConn{ - pc: pconn, - getGSOSizeFromControl: getGSOSizeFromControl, - setGSOSizeInControl: setGSOSizeInControl, - sendBatchPool: sync.Pool{ - New: func() any { - ua := &net.UDPAddr{ - IP: make([]byte, 16), - } - msgs := make([]ipv6.Message, batchSize) - for i := range msgs { - msgs[i].Buffers = make([][]byte, 1) - msgs[i].Addr = ua - msgs[i].OOB = make([]byte, controlMessageSize) - } - return &sendBatch{ - ua: ua, - msgs: msgs, - } - }, - }, - } - switch network { - case "udp4": - b.xpc = ipv4.NewPacketConn(uc) - case "udp6": - b.xpc = ipv6.NewPacketConn(uc) - default: - panic("bogus network") - } - var txOffload bool - txOffload, b.rxOffload = tryEnableUDPOffload(uc) - b.txOffload.Store(txOffload) - return b -} - func newBlockForeverConn() *blockForeverConn { c := new(blockForeverConn) c.cond = sync.NewCond(&c.mu) diff --git a/wgengine/magicsock/magicsock_default.go b/wgengine/magicsock/magicsock_default.go index 87075e522..321765b8c 100644 --- a/wgengine/magicsock/magicsock_default.go +++ b/wgengine/magicsock/magicsock_default.go @@ -21,16 +21,6 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { portableTrySetSocketBuffer(pconn, logf) } -func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { - return false, false -} - -func getGSOSizeFromControl(control []byte) (int, error) { - return 0, nil -} - -func setGSOSizeInControl(control *[]byte, gso uint16) {} - const ( controlMessageSize = 0 ) diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go index c484f77c0..69074fd72 100644 --- a/wgengine/magicsock/magicsock_linux.go +++ b/wgengine/magicsock/magicsock_linux.go @@ -318,70 +318,6 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { } } -// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn, -// and returns two booleans indicating TX and RX UDP offload support. 
-func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { - if c, ok := pconn.(*net.UDPConn); ok { - rc, err := c.SyscallConn() - if err != nil { - return - } - err = rc.Control(func(fd uintptr) { - _, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) - hasTX = errSyscall == nil - errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) - hasRX = errSyscall == nil - }) - if err != nil { - return false, false - } - } - return hasTX, hasRX -} - -// getGSOSizeFromControl returns the GSO size found in control. If no GSO size -// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0. -// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but -// its contents cannot be parsed as a socket control message. -func getGSOSizeFromControl(control []byte) (int, error) { - var ( - hdr unix.Cmsghdr - data []byte - rem = control - err error - ) - - for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) - if err != nil { - return 0, fmt.Errorf("error parsing socket control message: %w", err) - } - if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 { - return int(binary.NativeEndian.Uint16(data[:2])), nil - } - } - return 0, nil -} - -// setGSOSizeInControl sets a socket control message in control containing -// gsoSize. If len(control) < controlMessageSize control's len will be set to 0. -func setGSOSizeInControl(control *[]byte, gsoSize uint16) { - *control = (*control)[:0] - if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { - return - } - if cap(*control) < controlMessageSize { - return - } - *control = (*control)[:cap(*control)] - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) - hdr.Level = unix.SOL_UDP - hdr.Type = unix.UDP_SEGMENT - hdr.SetLen(unix.CmsgLen(2)) - binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize) - *control = (*control)[:unix.CmsgSpace(2)] -} - var controlMessageSize = -1 // bomb if used for allocation before init func init() { diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index a721c24e4..be1b43f56 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -35,7 +35,6 @@ import ( xmaps "golang.org/x/exp/maps" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/control/controlknobs" "tailscale.com/derp" @@ -2038,238 +2037,6 @@ func TestBufferedDerpWritesBeforeDrop(t *testing.T) { t.Logf("bufferedDerpWritesBeforeDrop = %d", vv) } -func setGSOSize(control *[]byte, gsoSize uint16) { - *control = (*control)[:cap(*control)] - binary.LittleEndian.PutUint16(*control, gsoSize) -} - -func getGSOSize(control []byte) (int, error) { - if len(control) < 2 { - return 0, nil - } - return int(binary.LittleEndian.Uint16(control)), nil -} - -func Test_batchingUDPConn_splitCoalescedMessages(t *testing.T) { - c := &batchingUDPConn{ - setGSOSizeInControl: setGSOSize, - getGSOSizeFromControl: getGSOSize, - } - - newMsg := func(n, gso int) ipv6.Message { - msg := ipv6.Message{ - Buffers: [][]byte{make([]byte, 1024)}, - N: n, - OOB: make([]byte, 2), - } - binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) - if gso > 0 { - msg.NN = 2 - } - return msg - } - - cases := []struct { - name string - msgs []ipv6.Message - firstMsgAt int - wantNumEval int - wantMsgLens []int - wantErr bool - }{ - { - name: "second last split last empty", - msgs: 
[]ipv6.Message{ - newMsg(0, 0), - newMsg(0, 0), - newMsg(3, 1), - newMsg(0, 0), - }, - firstMsgAt: 2, - wantNumEval: 3, - wantMsgLens: []int{1, 1, 1, 0}, - wantErr: false, - }, - { - name: "second last no split last empty", - msgs: []ipv6.Message{ - newMsg(0, 0), - newMsg(0, 0), - newMsg(1, 0), - newMsg(0, 0), - }, - firstMsgAt: 2, - wantNumEval: 1, - wantMsgLens: []int{1, 0, 0, 0}, - wantErr: false, - }, - { - name: "second last no split last no split", - msgs: []ipv6.Message{ - newMsg(0, 0), - newMsg(0, 0), - newMsg(1, 0), - newMsg(1, 0), - }, - firstMsgAt: 2, - wantNumEval: 2, - wantMsgLens: []int{1, 1, 0, 0}, - wantErr: false, - }, - { - name: "second last no split last split", - msgs: []ipv6.Message{ - newMsg(0, 0), - newMsg(0, 0), - newMsg(1, 0), - newMsg(3, 1), - }, - firstMsgAt: 2, - wantNumEval: 4, - wantMsgLens: []int{1, 1, 1, 1}, - wantErr: false, - }, - { - name: "second last split last split", - msgs: []ipv6.Message{ - newMsg(0, 0), - newMsg(0, 0), - newMsg(2, 1), - newMsg(2, 1), - }, - firstMsgAt: 2, - wantNumEval: 4, - wantMsgLens: []int{1, 1, 1, 1}, - wantErr: false, - }, - { - name: "second last no split last split overflow", - msgs: []ipv6.Message{ - newMsg(0, 0), - newMsg(0, 0), - newMsg(1, 0), - newMsg(4, 1), - }, - firstMsgAt: 2, - wantNumEval: 4, - wantMsgLens: []int{1, 1, 1, 1}, - wantErr: true, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - got, err := c.splitCoalescedMessages(tt.msgs, 2) - if err != nil && !tt.wantErr { - t.Fatalf("err: %v", err) - } - if got != tt.wantNumEval { - t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) - } - for i, msg := range tt.msgs { - if msg.N != tt.wantMsgLens[i] { - t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) - } - } - }) - } -} - -func Test_batchingUDPConn_coalesceMessages(t *testing.T) { - c := &batchingUDPConn{ - setGSOSizeInControl: setGSOSize, - getGSOSizeFromControl: getGSOSize, - } - - cases := []struct { - name string - buffs [][]byte - wantLens []int - wantGSO []int - }{ - { - name: "one message no coalesce", - buffs: [][]byte{ - make([]byte, 1, 1), - }, - wantLens: []int{1}, - wantGSO: []int{0}, - }, - { - name: "two messages equal len coalesce", - buffs: [][]byte{ - make([]byte, 1, 2), - make([]byte, 1, 1), - }, - wantLens: []int{2}, - wantGSO: []int{1}, - }, - { - name: "two messages unequal len coalesce", - buffs: [][]byte{ - make([]byte, 2, 3), - make([]byte, 1, 1), - }, - wantLens: []int{3}, - wantGSO: []int{2}, - }, - { - name: "three messages second unequal len coalesce", - buffs: [][]byte{ - make([]byte, 2, 3), - make([]byte, 1, 1), - make([]byte, 2, 2), - }, - wantLens: []int{3, 2}, - wantGSO: []int{2, 0}, - }, - { - name: "three messages limited cap coalesce", - buffs: [][]byte{ - make([]byte, 2, 4), - make([]byte, 2, 2), - make([]byte, 2, 2), - }, - wantLens: []int{4, 2}, - wantGSO: []int{2, 0}, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - addr := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 1, - } - msgs := make([]ipv6.Message, len(tt.buffs)) - for i := range msgs { - msgs[i].Buffers = make([][]byte, 1) - msgs[i].OOB = make([]byte, 0, 2) - } - got := c.coalesceMessages(addr, tt.buffs, msgs) - if got != len(tt.wantLens) { - t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) - } - for i := range got { - if msgs[i].Addr != addr { - t.Errorf("msgs[%d].Addr != passed addr", i) - } - gotLen := len(msgs[i].Buffers[0]) - if gotLen != tt.wantLens[i] { - t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, 
tt.wantLens[i]) - } - gotGSO, err := getGSOSize(msgs[i].OOB) - if err != nil { - t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) - } - if gotGSO != tt.wantGSO[i] { - t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) - } - } - }) - } -} - // newWireguard starts up a new wireguard-go device attached to a test tun, and // returns the device, tun and endpoint port. To add peers call device.IpcSet with UAPI instructions. func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, uint16) { diff --git a/wgengine/magicsock/rebinding_conn.go b/wgengine/magicsock/rebinding_conn.go index f1e47f3a8..c27abbadc 100644 --- a/wgengine/magicsock/rebinding_conn.go +++ b/wgengine/magicsock/rebinding_conn.go @@ -35,12 +35,12 @@ type RebindingUDPConn struct { // setConnLocked sets the provided nettype.PacketConn. It should be called only // after acquiring RebindingUDPConn.mu. It upgrades the provided -// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade -// is intentionally pushed closest to where read/write ops occur in order to -// avoid disrupting surrounding code that assumes nettype.PacketConn is a +// nettype.PacketConn to a batchingConn when appropriate. This upgrade is +// intentionally pushed closest to where read/write ops occur in order to avoid +// disrupting surrounding code that assumes nettype.PacketConn is a // *net.UDPConn. func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { - upc := tryUpgradeToBatchingUDPConn(p, network, batchSize) + upc := tryUpgradeToBatchingConn(p, network, batchSize) c.pconn = upc c.pconnAtomic.Store(&upc) c.port = uint16(c.localAddrLocked().Port) @@ -74,7 +74,7 @@ func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, e func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { for { pconn := *c.pconnAtomic.Load() - b, ok := pconn.(*batchingUDPConn) + b, ok := pconn.(batchingConn) if !ok { for _, buf := range buffs { _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr) @@ -101,7 +101,7 @@ func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) err func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) { for { pconn := *c.pconnAtomic.Load() - b, ok := pconn.(*batchingUDPConn) + b, ok := pconn.(batchingConn) if !ok { n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0]) if err == nil {
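
Note on usage (an illustration only; nothing below is part of the patch): after this change, callers no longer depend on a concrete *batchingUDPConn. A caller holding a nettype.PacketConn type-asserts to the batchingConn interface and falls back to single-datagram writes when the conn was not upgraded, mirroring RebindingUDPConn.WriteBatchTo in the final hunk above. The sendBuffs helper name is hypothetical:

func sendBuffs(pconn nettype.PacketConn, buffs [][]byte, addr netip.AddrPort) error {
	if b, ok := pconn.(batchingConn); ok {
		// Batched path: sendmmsg on Linux, plus UDP GSO coalescing when
		// the socket supports it.
		return b.WriteBatchTo(buffs, addr)
	}
	// Fallback path: one write syscall per datagram.
	for _, buf := range buffs {
		if _, err := pconn.WriteToUDPAddrPort(buf, addr); err != nil {
			return err
		}
	}
	return nil
}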
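For reference, the offload probe in tryEnableUDPOffload boils down to two socket options: UDP_SEGMENT support is detected with a getsockopt (only the option's existence matters, not its value), while UDP_GRO is opt-in, so enabling it doubles as the probe. A minimal standalone sketch of the same probe, assuming Linux and golang.org/x/sys/unix:

package main

import (
	"fmt"
	"net"
	"syscall"

	"golang.org/x/sys/unix"
)

func main() {
	c, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		panic(err)
	}
	defer c.Close()

	rc, err := c.SyscallConn()
	if err != nil {
		panic(err)
	}
	var hasTX, hasRX bool
	err = rc.Control(func(fd uintptr) {
		// TX: UDP_SEGMENT (GSO) exists if the getsockopt succeeds.
		_, e := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
		hasTX = e == nil
		// RX: UDP_GRO must be enabled explicitly; success implies support.
		e = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
		hasRX = e == nil
	})
	if err != nil {
		panic(err)
	}
	fmt.Printf("UDP GSO (tx): %v, UDP GRO (rx): %v\n", hasTX, hasRX)
}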
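The coalescing rules in coalesceMessages are easiest to see with concrete numbers. A sketch, reusing the stub control helpers (setGSOSize/getGSOSize) from the test file above, so it only compiles inside the package under the linux build tag; the sketchCoalesce name is hypothetical:

func sketchCoalesce() {
	c := &linuxBatchingConn{
		setGSOSizeInControl:   setGSOSize, // test stub: writes gsoSize into OOB
		getGSOSizeFromControl: getGSOSize,
	}
	buffs := [][]byte{
		make([]byte, 1200, 3200), // base; spare cap lets followers append in place
		make([]byte, 1200),       // equal to gsoSize: coalesces into base
		make([]byte, 800),        // shorter tail: legal, but ends the batch
	}
	msgs := make([]ipv6.Message, len(buffs))
	for i := range msgs {
		msgs[i].Buffers = make([][]byte, 1)
		msgs[i].OOB = make([]byte, 0, 2)
	}
	n := c.coalesceMessages(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1}, buffs, msgs)
	// n == 1: a single 3200-byte message whose control data carries
	// gsoSize=1200, so the kernel segments it back into 1200+1200+800
	// datagrams on the wire.
	_ = n
}

Had a fourth 1200-byte buffer followed the 800-byte one, the endBatch flag set by the short tail would force it into a new message, preserving datagram order.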