net/neterror, wgengine/magicsock: use UDP GSO and GRO on Linux (#7791)

This commit implements UDP offloading for Linux. GSO size is passed to
and from the kernel via socket control messages. Support is probed at
runtime.

UDP GSO is dependent on checksum offload support on the egress netdev.
UDP GSO will be disabled in the event sendmmsg() returns EIO, which is
a strong signal that the egress netdev does not support checksum
offload.

Updates tailscale/corp#8734

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/7205/merge
Jordan Whited 2 years ago committed by GitHub
parent 45138fcfba
commit f475e5550c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,6 +6,7 @@ package neterror
import (
"errors"
"fmt"
"runtime"
"syscall"
)
@ -57,3 +58,25 @@ func PacketWasTruncated(err error) bool {
}
return packetWasTruncated(err)
}
var shouldDisableUDPGSO func(error) bool // non-nil on Linux
func ShouldDisableUDPGSO(err error) bool {
if shouldDisableUDPGSO == nil {
return false
}
return shouldDisableUDPGSO(err)
}
type ErrUDPGSODisabled struct {
OnLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}

@ -0,0 +1,26 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package neterror
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func init() {
shouldDisableUDPGSO = func(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not
// have tx checksumming enabled, which is a hard requirement of
// UDP_SEGMENT. See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
return serr.Err == unix.EIO
}
return false
}
}

@ -44,7 +44,6 @@ import (
"tailscale.com/net/connstats"
"tailscale.com/net/dnscache"
"tailscale.com/net/interfaces"
"tailscale.com/net/netaddr"
"tailscale.com/net/netcheck"
"tailscale.com/net/neterror"
"tailscale.com/net/netns"
@ -281,7 +280,6 @@ type Conn struct {
pconn6 RebindingUDPConn
receiveBatchPool sync.Pool
sendBatchPool sync.Pool
// closeDisco4 and closeDisco6 are io.Closers to shut down the raw
// disco packet receivers. If nil, no raw disco receiver is
@ -597,26 +595,13 @@ func newConn() *Conn {
msgs := make([]ipv6.Message, c.bind.BatchSize())
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, controlMessageSize)
}
batch := &receiveBatch{
msgs: msgs,
}
return batch
}}
c.sendBatchPool = sync.Pool{New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 16),
}
msgs := make([]ipv6.Message, c.bind.BatchSize())
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
}
return &sendBatch{
ua: ua,
msgs: msgs,
}
}}
c.muCond = sync.NewCond(&c.mu)
c.networkUp.Store(true) // assume up until told otherwise
return c
@ -1301,19 +1286,11 @@ var errNoUDP = errors.New("no UDP available on platform")
var (
// This acts as a compile-time check for our usage of ipv6.Message in
// udpConnWithBatchOps for both IPv6 and IPv4 operations.
// batchingUDPConn for both IPv6 and IPv4 operations.
_ ipv6.Message = ipv4.Message{}
)
type sendBatch struct {
ua *net.UDPAddr
msgs []ipv6.Message // ipv4.Message and ipv6.Message are the same underlying type
}
func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) {
batch := c.sendBatchPool.Get().(*sendBatch)
defer c.sendBatchPool.Put(batch)
isIPv6 := false
switch {
case addr.Addr().Is4():
@ -1322,19 +1299,17 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err
default:
panic("bogus sendUDPBatch addr type")
}
as16 := addr.Addr().As16()
copy(batch.ua.IP, as16[:])
batch.ua.Port = int(addr.Port())
for i, buff := range buffs {
batch.msgs[i].Buffers[0] = buff
batch.msgs[i].Addr = batch.ua
}
if isIPv6 {
_, err = c.pconn6.WriteBatch(batch.msgs[:len(buffs)], 0)
err = c.pconn6.WriteBatchTo(buffs, addr)
} else {
_, err = c.pconn4.WriteBatch(batch.msgs[:len(buffs)], 0)
err = c.pconn4.WriteBatchTo(buffs, addr)
}
if err != nil {
var errGSO neterror.ErrUDPGSODisabled
if errors.As(err, &errGSO) {
c.logf("magicsock: %s", errGSO.Error())
err = errGSO.RetryErr
}
}
return err == nil, err
}
@ -1844,14 +1819,18 @@ type receiveBatch struct {
msgs []ipv6.Message
}
func (c *Conn) getReceiveBatch() *receiveBatch {
func (c *Conn) getReceiveBatchForBuffs(buffs [][]byte) *receiveBatch {
batch := c.receiveBatchPool.Get().(*receiveBatch)
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
batch.msgs[i].OOB = batch.msgs[i].OOB[:cap(batch.msgs[i].OOB)]
}
return batch
}
func (c *Conn) putReceiveBatch(batch *receiveBatch) {
for i := range batch.msgs {
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers}
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
}
c.receiveBatchPool.Put(batch)
}
@ -1860,13 +1839,10 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
health.ReceiveIPv6.Enter()
defer health.ReceiveIPv6.Exit()
batch := c.getReceiveBatch()
batch := c.getReceiveBatchForBuffs(buffs)
defer c.putReceiveBatch(batch)
for {
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
}
numMsgs, err := c.pconn6.ReadBatch(batch.msgs, 0)
numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0)
if err != nil {
if neterror.PacketWasTruncated(err) {
// TODO(raggi): discuss whether to log?
@ -1877,6 +1853,10 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
reportToCaller := false
for i, msg := range batch.msgs[:numMsgs] {
if msg.N == 0 {
sizes[i] = 0
continue
}
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok {
metricRecvDataIPv6.Add(1)
@ -1898,13 +1878,10 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
health.ReceiveIPv4.Enter()
defer health.ReceiveIPv4.Exit()
batch := c.getReceiveBatch()
batch := c.getReceiveBatchForBuffs(buffs)
defer c.putReceiveBatch(batch)
for {
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
}
numMsgs, err := c.pconn4.ReadBatch(batch.msgs, 0)
numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0)
if err != nil {
if neterror.PacketWasTruncated(err) {
// TODO(raggi): discuss whether to log?
@ -1915,6 +1892,10 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
reportToCaller := false
for i, msg := range batch.msgs[:numMsgs] {
if msg.N == 0 {
sizes[i] = 0
continue
}
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok {
metricRecvDataIPv4.Add(1)
@ -1940,7 +1921,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache)
c.stunReceiveFunc.Load()(b, ipp)
return nil, false
}
if c.handleDiscoMessage(b, ipp, key.NodePublic{}) {
if c.handleDiscoMessage(b, ipp, key.NodePublic{}, discoRXPathUDP) {
return nil, false
}
if !c.havePrivateKey.Load() {
@ -2005,7 +1986,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
}
ipp := netip.AddrPortFrom(derpMagicIPAddr, uint16(regionID))
if c.handleDiscoMessage(b[:n], ipp, dm.src) {
if c.handleDiscoMessage(b[:n], ipp, dm.src, discoRXPathDERP) {
return 0, nil
}
@ -2139,6 +2120,14 @@ func discoPcapFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []by
return b.Bytes()
}
type discoRXPath string
const (
discoRXPathUDP discoRXPath = "UDP socket"
discoRXPathDERP discoRXPath = "DERP"
discoRXPathRawSocket discoRXPath = "raw socket"
)
// handleDiscoMessage handles a discovery message and reports whether
// msg was a Tailscale inter-node discovery message.
//
@ -2153,7 +2142,7 @@ func discoPcapFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []by
// src.Port() being the region ID) and the derpNodeSrc will be the node key
// it was received from at the DERP layer. derpNodeSrc is zero when received
// over UDP.
func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic) (isDiscoMsg bool) {
func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) {
const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic {
return false
@ -2174,7 +2163,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
return
}
if debugDisco() {
c.logf("magicsock: disco: got disco-looking frame from %v", sender.ShortString())
c.logf("magicsock: disco: got disco-looking frame from %v via %s", sender.ShortString(), via)
}
if c.privateKey.IsZero() {
// Ignore disco messages when we're stopped.
@ -2210,12 +2199,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
// disco key. When we restart we get a new disco key
// and old packets might've still been in flight (or
// scheduled). This is particularly the case for LANs
// or non-NATed endpoints.
// Don't log in normal case. Pass on to wireguard, in case
// it's actually a wireguard packet (super unlikely,
// but).
// or non-NATed endpoints. UDP offloading on Linux
// can also cause this when a disco message is
// received via raw socket at the head of a coalesced
// group of messages. Don't log in normal case.
// Callers may choose to pass on to wireguard, in case
// it's actually a wireguard packet (super unlikely, but).
if debugDisco() {
c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?)", sender)
c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?) via %s", sender, via)
}
metricRecvDiscoBadKey.Add(1)
return
@ -3205,13 +3196,13 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur
defer ruc.mu.Unlock()
if runtime.GOOS == "js" {
ruc.setConnLocked(newBlockForeverConn(), "")
ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize())
return nil
}
if debugAlwaysDERP() {
c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network)
ruc.setConnLocked(newBlockForeverConn(), "")
ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize())
return nil
}
@ -3253,7 +3244,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur
if debugBindSocket() {
c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port)
}
ruc.setConnLocked(pconn, network)
ruc.setConnLocked(pconn, network, c.bind.BatchSize())
if network == "udp4" {
health.SetUDP4Unbound(false)
}
@ -3264,7 +3255,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur
// Set pconn to a dummy conn whose reads block until closed.
// This keeps the receive funcs alive for a future in which
// we get a link change and we can try binding again.
ruc.setConnLocked(newBlockForeverConn(), "")
ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize())
if network == "udp4" {
health.SetUDP4Unbound(true)
}
@ -3361,49 +3352,332 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
return ep, nil
}
type batchReaderWriter interface {
batchReader
batchWriter
// xnetBatchReaderWriter defines the batching i/o methods of
// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn).
// TODO(jwhited): This should eventually be replaced with the standard library
// implementation of https://github.com/golang/go/issues/45886
type xnetBatchReaderWriter interface {
xnetBatchReader
xnetBatchWriter
}
type xnetBatchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
type xnetBatchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
// batchingUDPConn is a UDP socket that provides batched i/o.
type batchingUDPConn struct {
pc nettype.PacketConn
xpc xnetBatchReaderWriter
rxOffload bool // supports UDP GRO or similar
txOffload atomic.Bool // supports UDP GSO or similar
setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing
getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing
sendBatchPool sync.Pool
}
// udpConnWithBatchOps wraps a *net.UDPConn in order to extend it to support
// batch operations.
//
// TODO(jwhited): This wrapping is temporary. https://github.com/golang/go/issues/45886
type udpConnWithBatchOps struct {
*net.UDPConn
xpc batchReaderWriter
func (c *batchingUDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.rxOffload {
// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
// receive a "monster datagram" from any read call. The ReadFrom() API
// does not support passing the GSO size and is unsafe to use in such a
// case. Other platforms may vary in behavior, but we go with the most
// conservative approach to prevent this from becoming a footgun in the
// future.
return 0, nil, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
}
return c.pc.ReadFrom(p)
}
func newUDPConnWithBatchOps(conn *net.UDPConn, network string) udpConnWithBatchOps {
ucbo := udpConnWithBatchOps{
UDPConn: conn,
func (c *batchingUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.pc.WriteTo(b, addr)
}
func (c *batchingUDPConn) SetDeadline(t time.Time) error {
return c.pc.SetDeadline(t)
}
func (c *batchingUDPConn) SetReadDeadline(t time.Time) error {
return c.pc.SetReadDeadline(t)
}
func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error {
return c.pc.SetWriteDeadline(t)
}
const (
// This was initially established for Linux, but may split out to
// GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the
// kernel's TX path, and UDP_GRO_CNT_MAX for RX.
udpSegmentMaxDatagrams = 64
)
const (
// Exceeding these values results in EMSGSIZE.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
)
// coalesceMessages iterates msgs, coalescing them where possible while
// maintaining datagram order. All msgs have their Addr field set to addr.
func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of buffs
)
maxPayloadLen := maxIPv4PayloadLen
if addr.IP.To4() == nil {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buff := range buffs {
if i > 0 {
msgLen := len(buff)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...)
copy(msgs[base].Buffers[0][baseLenBefore:], buff)
if i == len(buffs)-1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buff)
msgs[base].OOB = msgs[base].OOB[:0]
msgs[base].Buffers[0] = buff
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type sendBatch struct {
msgs []ipv6.Message
ua *net.UDPAddr
}
func (c *batchingUDPConn) getSendBatch() *sendBatch {
batch := c.sendBatchPool.Get().(*sendBatch)
return batch
}
func (c *batchingUDPConn) putSendBatch(batch *sendBatch) {
for i := range batch.msgs {
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
}
c.sendBatchPool.Put(batch)
}
func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
batch := c.getSendBatch()
defer c.putSendBatch(batch)
if addr.Addr().Is6() {
as16 := addr.Addr().As16()
copy(batch.ua.IP, as16[:])
batch.ua.IP = batch.ua.IP[:16]
} else {
as4 := addr.Addr().As4()
copy(batch.ua.IP, as4[:])
batch.ua.IP = batch.ua.IP[:4]
}
batch.ua.Port = int(addr.Port())
var (
n int
retried bool
)
retry:
if c.txOffload.Load() {
n = c.coalesceMessages(batch.ua, buffs, batch.msgs)
} else {
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
batch.msgs[i].Addr = batch.ua
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
}
n = len(buffs)
}
err := c.writeBatch(batch.msgs[:n])
if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) {
c.txOffload.Store(false)
retried = true
goto retry
}
if retried {
return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err}
}
return err
}
func (c *batchingUDPConn) writeBatch(msgs []ipv6.Message) error {
var head int
for {
n, err := c.xpc.WriteBatch(msgs[head:], 0)
if err != nil || n == len(msgs[head:]) {
// Returning the number of packets written would require
// unraveling individual msg len and gso size during a coalesced
// write. The top of the call stack disregards partial success,
// so keep this simple for now.
return err
}
head += n
}
}
// splitCoalescedMessages splits coalesced messages from the tail of dst
// beginning at index 'firstMsgAt' into the head of the same slice. It reports
// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An
// error is returned if a socket control message cannot be parsed or a split
// operation would overflow msgs.
func (c *batchingUDPConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}
func (c *batchingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
if !c.rxOffload || len(msgs) < 2 {
return c.xpc.ReadBatch(msgs, flags)
}
// Read into the tail of msgs, split into the head.
readAt := len(msgs) - 2
numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0)
if err != nil || numRead == 0 {
return 0, err
}
return c.splitCoalescedMessages(msgs, readAt)
}
func (c *batchingUDPConn) LocalAddr() net.Addr {
return c.pc.LocalAddr().(*net.UDPAddr)
}
func (c *batchingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
return c.pc.WriteToUDPAddrPort(b, addr)
}
func (c *batchingUDPConn) Close() error {
return c.pc.Close()
}
// tryUpgradeToBatchingUDPConn probes the capabilities of the OS and pconn, and
// upgrades pconn to a *batchingUDPConn if appropriate.
func tryUpgradeToBatchingUDPConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if network != "udp4" && network != "udp6" {
return pconn
}
if runtime.GOOS != "linux" {
return pconn
}
if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") {
// recvmmsg/sendmmsg were added in 2.6.33, but we support down to
// 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
// As a cheap heuristic: if the Linux kernel starts with "2", just
// consider it too old for mmsg. Nobody who cares about performance runs
// such ancient kernels. UDP offload was added much later, so no
// upgrades are available.
return pconn
}
uc, ok := pconn.(*net.UDPConn)
if !ok {
return pconn
}
b := &batchingUDPConn{
pc: pconn,
getGSOSizeFromControl: getGSOSizeFromControl,
setGSOSizeInControl: setGSOSizeInControl,
sendBatchPool: sync.Pool{
New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 16),
}
msgs := make([]ipv6.Message, batchSize)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
msgs[i].OOB = make([]byte, controlMessageSize)
}
return &sendBatch{
ua: ua,
msgs: msgs,
}
},
},
}
switch network {
case "udp4":
ucbo.xpc = ipv4.NewPacketConn(conn)
b.xpc = ipv4.NewPacketConn(uc)
case "udp6":
ucbo.xpc = ipv6.NewPacketConn(conn)
b.xpc = ipv6.NewPacketConn(uc)
default:
panic("bogus network")
}
return ucbo
}
func (u udpConnWithBatchOps) WriteBatch(ms []ipv6.Message, flags int) (int, error) {
return u.xpc.WriteBatch(ms, flags)
}
func (u udpConnWithBatchOps) ReadBatch(ms []ipv6.Message, flags int) (int, error) {
return u.xpc.ReadBatch(ms, flags)
var txOffload bool
txOffload, b.rxOffload = tryEnableUDPOffload(uc)
b.txOffload.Store(txOffload)
return b
}
// RebindingUDPConn is a UDP socket that can be re-bound.
@ -3423,34 +3697,14 @@ type RebindingUDPConn struct {
port uint16
}
// upgradePacketConn may upgrade a nettype.PacketConn to a udpConnWithBatchOps.
func upgradePacketConn(p nettype.PacketConn, network string) nettype.PacketConn {
uc, ok := p.(*net.UDPConn)
if ok && runtime.GOOS == "linux" && (network == "udp4" || network == "udp6") {
// recvmmsg/sendmmsg were added in 2.6.33 but we support down to 2.6.32
// for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
// As a cheap heuristic: if the Linux kernel starts with "2", just consider
// it too old for the fast paths. Nobody who cares about performance runs such
// ancient kernels.
if strings.HasPrefix(hostinfo.GetOSVersion(), "2") {
return p
}
// Non-Linux does not support batch operations. x/net will fall back to
// recv/sendmsg, but not all platforms have recv/sendmsg support. Keep
// this simple for now.
return newUDPConnWithBatchOps(uc, network)
}
return p
}
// setConnLocked sets the provided nettype.PacketConn. It should be called only
// after acquiring RebindingUDPConn.mu. It upgrades the provided
// nettype.PacketConn to a udpConnWithBatchOps when appropriate. This upgrade
// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade
// is intentionally pushed closest to where read/write ops occur in order to
// avoid disrupting surrounding code that assumes nettype.PacketConn is a
// *net.UDPConn.
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string) {
upc := upgradePacketConn(p, network)
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) {
upc := tryUpgradeToBatchingUDPConn(p, network, batchSize)
c.pconn = upc
c.pconnAtomic.Store(&upc)
c.port = uint16(c.localAddrLocked().Port)
@ -3480,83 +3734,38 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b)
}
// ReadFromNetaddr reads a packet from c into b.
// It returns the number of bytes copied and the return address.
// It is identical to c.ReadFrom, except that it returns a netip.AddrPort instead of a net.Addr.
// ReadFromNetaddr is designed to work with specific underlying connection types.
// If c's underlying connection returns a non-*net.UPDAddr return address, ReadFromNetaddr will return an error.
// ReadFromNetaddr exists because it removes an allocation per read,
// when c's underlying connection is a net.UDPConn.
func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) {
// WriteBatchTo writes buffs to addr.
func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
for {
pconn := *c.pconnAtomic.Load()
// Optimization: Treat *net.UDPConn specially.
// This lets us avoid allocations by calling ReadFromUDPAddrPort.
// The non-*net.UDPConn case works, but it allocates.
if udpConn, ok := pconn.(*net.UDPConn); ok {
n, ipp, err = udpConn.ReadFromUDPAddrPort(b)
} else {
var addr net.Addr
n, addr, err = pconn.ReadFrom(b)
pAddr, ok := addr.(*net.UDPAddr)
if addr != nil && !ok {
return 0, netip.AddrPort{}, fmt.Errorf("RebindingUDPConn.ReadFromNetaddr: underlying connection returned address of type %T, want *netaddr.UDPAddr", addr)
}
if pAddr != nil {
ipp = netaddr.Unmap(pAddr.AddrPort())
if !ipp.IsValid() {
return 0, netip.AddrPort{}, errors.New("netaddr.FromStdAddr failed")
}
}
}
if err != nil && pconn != c.currentConn() {
// The connection changed underfoot. Try again.
continue
}
return n, ipp, err
}
}
func (c *RebindingUDPConn) WriteBatch(msgs []ipv6.Message, flags int) (int, error) {
var (
n int
err error
start int
)
for {
pconn := *c.pconnAtomic.Load()
bw, ok := pconn.(batchWriter)
b, ok := pconn.(*batchingUDPConn)
if !ok {
for _, msg := range msgs {
_, err = c.writeToWithInitPconn(pconn, msg.Buffers[0], msg.Addr)
for _, buf := range buffs {
_, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr)
if err != nil {
return n, err
return err
}
n++
}
return n, nil
return nil
}
n, err = bw.WriteBatch(msgs[start:], flags)
err := b.WriteBatchTo(buffs, addr)
if err != nil {
if pconn != c.currentConn() {
continue
}
return n, err
} else if n == len(msgs[start:]) {
return len(msgs), nil
} else {
start += n
return err
}
return err
}
}
// ReadBatch reads messages from c into msgs. It returns the number of messages
// the caller should evaluate for nonzero len, as a zero len message may fall
// on either side of a nonzero.
func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) {
for {
pconn := *c.pconnAtomic.Load()
br, ok := pconn.(batchReader)
b, ok := pconn.(*batchingUDPConn)
if !ok {
var err error
msgs[0].N, msgs[0].Addr, err = c.readFromWithInitPconn(pconn, msgs[0].Buffers[0])
@ -3565,7 +3774,7 @@ func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error
}
return 0, err
}
n, err := br.ReadBatch(msgs, flags)
n, err := b.ReadBatch(msgs, flags)
if err != nil && pconn != c.currentConn() {
continue
}
@ -3607,9 +3816,9 @@ func (c *RebindingUDPConn) closeLocked() error {
return c.pconn.Close()
}
func (c *RebindingUDPConn) writeToWithInitPconn(pconn nettype.PacketConn, b []byte, addr net.Addr) (int, error) {
func (c *RebindingUDPConn) writeToUDPAddrPortWithInitPconn(pconn nettype.PacketConn, b []byte, addr netip.AddrPort) (int, error) {
for {
n, err := pconn.WriteTo(b, addr)
n, err := pconn.WriteToUDPAddrPort(b, addr)
if err != nil && pconn != c.currentConn() {
pconn = *c.pconnAtomic.Load()
continue
@ -3619,13 +3828,9 @@ func (c *RebindingUDPConn) writeToWithInitPconn(pconn nettype.PacketConn, b []by
}
func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.writeToWithInitPconn(*c.pconnAtomic.Load(), b, addr)
}
func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
for {
pconn := *c.pconnAtomic.Load()
n, err := pconn.WriteToUDPAddrPort(b, addr)
n, err := pconn.WriteTo(b, addr)
if err != nil && pconn != c.currentConn() {
continue
}
@ -3633,6 +3838,10 @@ func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (in
}
}
func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
return c.writeToUDPAddrPortWithInitPconn(*c.pconnAtomic.Load(), b, addr)
}
func newBlockForeverConn() *blockForeverConn {
c := new(blockForeverConn)
c.cond = sync.NewCond(&c.mu)
@ -3665,20 +3874,6 @@ func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (in
return len(p), nil
}
func (c *blockForeverConn) ReadBatch(p []ipv6.Message, flags int) (int, error) {
c.mu.Lock()
for !c.closed {
c.cond.Wait()
}
c.mu.Unlock()
return 0, net.ErrClosed
}
func (c *blockForeverConn) WriteBatch(p []ipv6.Message, flags int) (int, error) {
// Silently drop writes.
return len(p), nil
}
func (c *blockForeverConn) LocalAddr() net.Addr {
// Return a *net.UDPAddr because lots of code assumes that it will.
return new(net.UDPAddr)

@ -20,3 +20,21 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
portableTrySetSocketBuffer(pconn, logf)
}
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
return false, false
}
func getGSOSizeFromControl(control []byte) (int, error) {
return 0, nil
}
func setGSOSizeInControl(control *[]byte, gso uint16) {}
func errShouldDisableOffload(err error) bool {
return false
}
const (
controlMessageSize = 0
)

@ -258,7 +258,7 @@ func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) {
metricRecvDiscoPacketIPv6.Add(1)
}
c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{})
c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}, discoRXPathRawSocket)
}
}
@ -317,3 +317,90 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
}
}
}
const (
// TODO(jwhited): upstream to unix?
socketOptionLevelUDP = 17
socketOptionUDPSegment = 103
socketOptionUDPGRO = 104
)
// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn,
// and returns two booleans indicating TX and RX UDP offload support.
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
if c, ok := pconn.(*net.UDPConn); ok {
rc, err := c.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPSegment)
if errSyscall != nil {
// no point in checking RX, TX support was added first.
return
}
hasTX = true
errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO, 1)
hasRX = errSyscall == nil
})
if err != nil {
return false, false
}
}
return hasTX, hasRX
}
// getGSOSizeFromControl returns the GSO size found in control. If no GSO size
// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0.
// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but
// its contents cannot be parsed as a socket control message.
func getGSOSizeFromControl(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == socketOptionLevelUDP && hdr.Type == socketOptionUDPGRO && len(data) >= 2 {
var gso uint16
// TODO(jwhited): replace with encoding/binary.NativeEndian when it's available
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), 2), data[:2])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSizeInControl sets a socket control message in control containing
// gsoSize. If len(control) < controlMessageSize control's len will be set to 0.
func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
*control = (*control)[:0]
if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
return
}
if cap(*control) < controlMessageSize {
return
}
*control = (*control)[:cap(*control)]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
hdr.Level = socketOptionLevelUDP
hdr.Type = socketOptionUDPSegment
hdr.SetLen(unix.CmsgLen(2))
// TODO(jwhited): replace with encoding/binary.NativeEndian when it's available
copy((*control)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), 2))
*control = (*control)[:unix.CmsgSpace(2)]
}
var controlMessageSize = -1 // bomb if used for allocation before init
func init() {
// controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control
// message. These contain a single uint16 of data.
controlMessageSize = unix.CmsgSpace(2)
}

@ -31,6 +31,7 @@ import (
"github.com/tailscale/wireguard-go/tun/tuntest"
"go4.org/mem"
"golang.org/x/exp/maps"
"golang.org/x/net/ipv6"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
@ -1156,7 +1157,7 @@ func TestDiscoMessage(t *testing.T) {
box := peer1Priv.Shared(c.discoPrivate.Public()).Seal([]byte(payload))
pkt = append(pkt, box...)
got := c.handleDiscoMessage(pkt, netip.AddrPort{}, key.NodePublic{})
got := c.handleDiscoMessage(pkt, netip.AddrPort{}, key.NodePublic{}, discoRXPathUDP)
if !got {
t.Error("failed to open it")
}
@ -1832,8 +1833,8 @@ func TestRebindingUDPConn(t *testing.T) {
t.Fatal(err)
}
defer realConn.Close()
c.setConnLocked(realConn.(nettype.PacketConn), "udp4")
c.setConnLocked(newBlockForeverConn(), "")
c.setConnLocked(realConn.(nettype.PacketConn), "udp4", 1)
c.setConnLocked(newBlockForeverConn(), "", 1)
}
// https://github.com/tailscale/tailscale/issues/6680: don't ignore
@ -1861,3 +1862,235 @@ func TestBufferedDerpWritesBeforeDrop(t *testing.T) {
}
t.Logf("bufferedDerpWritesBeforeDrop = %d", vv)
}
func setGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func getGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_batchingUDPConn_splitCoalescedMessages(t *testing.T) {
c := &batchingUDPConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1024)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := c.splitCoalescedMessages(tt.msgs, 2)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}
func Test_batchingUDPConn_coalesceMessages(t *testing.T) {
c := &batchingUDPConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := c.coalesceMessages(addr, tt.buffs, msgs)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := 0; i < got; i++ {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := getGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}

Loading…
Cancel
Save