net/tstun,wgengine{/netstack/gro}: refactor and re-enable gVisor GRO for Linux (#13172)

In 2f27319baf we disabled GRO due to a
data race around concurrent calls to tstun.Wrapper.Write(). This commit
refactors GRO to be thread-safe, and re-enables it on Linux.

This refactor now carries a GRO type across tstun and netstack APIs
with a lifetime that is scoped to a single tstun.Wrapper.Write() call.

In 25f0a3fc8f we used build tags to
prevent importing gVisor's GRO package on iOS, as at the time we
believed it was contributing to additional memory usage on that
platform. It wasn't, so this commit simplifies and removes those
build tags.

Updates tailscale/corp#22353
Updates tailscale/corp#22125
Updates #6816

Signed-off-by: Jordan Whited <jordan@tailscale.com>
Jordan Whited committed via GitHub
parent 93dc2ded6e
commit df6014f1d7
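Before the diff, a condensed view of the new flow: a *gro.GRO handle is scoped to a single tstun.Wrapper.Write() call. It starts nil, is threaded through each GROFilterFunc invocation for the packet vector, and is flushed exactly once at the end. A minimal sketch of that contract follows; runVector is a hypothetical helper for illustration only, the real loop lives in Wrapper.Write() below.

package example // illustrative sketch only, not code from this change

import (
	"tailscale.com/net/packet"
	"tailscale.com/net/tstun"
	"tailscale.com/wgengine/filter"
	"tailscale.com/wgengine/netstack/gro"
)

// runVector drives a GROFilterFunc over one packet vector, following the
// contract documented on tstun.GROFilterFunc: pass nil for the first
// packet, keep passing the returned handle, and Flush() once at the end.
func runVector(f tstun.GROFilterFunc, w *tstun.Wrapper, pkts []*packet.Parsed) {
	var g *gro.GRO // nil at the start of every vector
	for _, p := range pkts {
		var res filter.Response
		res, g = f(p, w, g)
		_ = res // a real caller branches on drop vs. accept here
	}
	if g != nil {
		g.Flush() // delivers anything still queued; g must not be reused
	}
}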

@@ -314,7 +314,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
     gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+
     gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+
 💣  gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+
-    gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack
+    gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack/gro
     gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
     gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack
     gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
@@ -828,6 +828,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
 💣  tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+
     tailscale.com/wgengine/netlog from tailscale.com/wgengine
     tailscale.com/wgengine/netstack from tailscale.com/tsnet
+    tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+
     tailscale.com/wgengine/router from tailscale.com/ipn/ipnlocal+
     tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+
     tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal

@@ -225,7 +225,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
     gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+
     gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+
 💣  gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+
-    gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack
+    gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack/gro
     gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
     gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack
     gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
@@ -420,6 +420,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
 💣  tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+
     tailscale.com/wgengine/netlog from tailscale.com/wgengine
     tailscale.com/wgengine/netstack from tailscale.com/cmd/tailscaled
+    tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+
     tailscale.com/wgengine/router from tailscale.com/cmd/tailscaled+
     tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+
     tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal

@@ -36,6 +36,7 @@ import (
 	"tailscale.com/util/clientmetric"
 	"tailscale.com/wgengine/capture"
 	"tailscale.com/wgengine/filter"
+	"tailscale.com/wgengine/netstack/gro"
 	"tailscale.com/wgengine/wgcfg"
 )
 
@@ -74,6 +75,15 @@ var parsedPacketPool = sync.Pool{New: func() any { return new(packet.Parsed) }}
 // It must not hold onto the packet struct, as its backing storage will be reused.
 type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response
 
+// GROFilterFunc is a FilterFunc extended with a *gro.GRO, enabling increased
+// throughput where GRO is supported by a packet.Parsed interceptor, e.g.
+// netstack/gVisor, and we are handling a vector of packets. Callers must pass a
+// nil g for the first packet in a given vector, and continue passing the
+// returned *gro.GRO for all remaining packets in said vector. If g is non-nil
+// after the last packet for a given vector is passed through the GROFilterFunc,
+// the caller must also call g.Flush().
+type GROFilterFunc func(p *packet.Parsed, w *Wrapper, g *gro.GRO) (filter.Response, *gro.GRO)
+
 // Wrapper augments a tun.Device with packet filtering and injection.
 //
 // A Wrapper starts in a "corked" mode where Read calls are blocked
@@ -161,11 +171,7 @@ type Wrapper struct {
 	// and therefore sees the packets that may be later dropped by it.
 	PreFilterPacketInboundFromWireGuard FilterFunc
 	// PostFilterPacketInboundFromWireGuard is the inbound filter function that runs after the main filter.
-	PostFilterPacketInboundFromWireGuard FilterFunc
-	// EndPacketVectorInboundFromWireGuardFlush is a function that runs after all packets in a given vector
-	// have been handled by all filters. Filters may queue packets for the purposes of GRO, requiring an
-	// explicit flush.
-	EndPacketVectorInboundFromWireGuardFlush func()
+	PostFilterPacketInboundFromWireGuard GROFilterFunc
 	// PreFilterPacketOutboundToWireGuardNetstackIntercept is a filter function that runs before the main filter
 	// for packets from the local system. This filter is populated by netstack to hook
 	// packets that should be handled by netstack. If set, this filter runs before
@@ -1061,7 +1067,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
 	return n, err
 }
 
-func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable) filter.Response {
+func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) {
 	if captHook != nil {
 		captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
 	}
@@ -1070,7 +1076,7 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
 	if pingReq, ok := p.AsTSMPPing(); ok {
 		t.noteActivity()
 		t.injectOutboundPong(p, pingReq)
-		return filter.DropSilently
+		return filter.DropSilently, gro
 	} else if data, ok := p.AsTSMPPong(); ok {
 		if f := t.OnTSMPPongReceived; f != nil {
 			f(data)
@@ -1082,7 +1088,7 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
 		if f := t.OnICMPEchoResponseReceived; f != nil && f(p) {
 			// Note: this looks dropped in metrics, even though it was
 			// handled internally.
-			return filter.DropSilently
+			return filter.DropSilently, gro
 		}
 	}
 
@@ -1094,12 +1100,12 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
 		t.isSelfDisco(p) {
 		t.limitedLogf("[unexpected] received self disco in packet over tstun; dropping")
 		metricPacketInDropSelfDisco.Add(1)
-		return filter.DropSilently
+		return filter.DropSilently, gro
 	}
 
 	if t.PreFilterPacketInboundFromWireGuard != nil {
 		if res := t.PreFilterPacketInboundFromWireGuard(p, t); res.IsDrop() {
-			return res
+			return res, gro
 		}
 	}
 
@@ -1110,7 +1116,7 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
 		filt = t.filter.Load()
 	}
 	if filt == nil {
-		return filter.Drop
+		return filter.Drop, gro
 	}
 
 	outcome := filt.RunIn(p, t.filterFlags)
@@ -1150,20 +1156,24 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
 			// TODO(bradfitz): also send a TCP RST, after the TSMP message.
 		}
 
-		return filter.Drop
+		return filter.Drop, gro
 	}
 
 	if t.PostFilterPacketInboundFromWireGuard != nil {
-		if res := t.PostFilterPacketInboundFromWireGuard(p, t); res.IsDrop() {
-			return res
+		var res filter.Response
+		res, gro = t.PostFilterPacketInboundFromWireGuard(p, t, gro)
+		if res.IsDrop() {
+			return res, gro
 		}
 	}
 
-	return filter.Accept
+	return filter.Accept, gro
 }
 
-// Write accepts incoming packets. The packets begins at buffs[:][offset:],
-// like wireguard-go/tun.Device.Write.
+// Write accepts incoming packets. The packets begin at buffs[:][offset:],
+// like wireguard-go/tun.Device.Write. Write is called per-peer via
+// wireguard-go/device.Peer.RoutineSequentialReceiver, so it MUST be
+// thread-safe.
 func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
 	metricPacketIn.Add(int64(len(buffs)))
 	i := 0
@@ -1171,11 +1181,17 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
 	defer parsedPacketPool.Put(p)
 	captHook := t.captureHook.Load()
 	pc := t.peerConfig.Load()
+	var buffsGRO *gro.GRO
 	for _, buff := range buffs {
 		p.Decode(buff[offset:])
 		pc.dnat(p)
 		if !t.disableFilter {
-			if t.filterPacketInboundFromWireGuard(p, captHook, pc) != filter.Accept {
+			var res filter.Response
+			// TODO(jwhited): name and document this filter code path
+			// appropriately. It is not only responsible for filtering, it
+			// also routes packets towards gVisor/netstack.
+			res, buffsGRO = t.filterPacketInboundFromWireGuard(p, captHook, pc, buffsGRO)
+			if res != filter.Accept {
 				metricPacketInDrop.Add(1)
 			} else {
 				buffs[i] = buff
@@ -1183,8 +1199,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
 			}
 		}
 	}
-	if t.EndPacketVectorInboundFromWireGuardFlush != nil {
-		t.EndPacketVectorInboundFromWireGuardFlush()
+	if buffsGRO != nil {
+		buffsGRO.Flush()
 	}
 	if t.disableFilter {
 		i = len(buffs)

@@ -552,7 +552,7 @@ func TestPeerAPIBypass(t *testing.T) {
 			tt.w.SetFilter(tt.filter)
 			tt.w.disableTSMPRejected = true
 			tt.w.logf = t.Logf
-			if got := tt.w.filterPacketInboundFromWireGuard(p, nil, nil); got != tt.want {
+			if got, _ := tt.w.filterPacketInboundFromWireGuard(p, nil, nil, nil); got != tt.want {
 				t.Errorf("got = %v; want %v", got, tt.want)
 			}
 		})
@@ -582,7 +582,7 @@ func TestFilterDiscoLoop(t *testing.T) {
 	p := new(packet.Parsed)
 	p.Decode(pkt)
 
-	got := tw.filterPacketInboundFromWireGuard(p, nil, nil)
+	got, _ := tw.filterPacketInboundFromWireGuard(p, nil, nil, nil)
 	if got != filter.DropSilently {
 		t.Errorf("got %v; want DropSilently", got)
 	}

@@ -0,0 +1,169 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package gro implements GRO for the receive (write) path into gVisor.
+package gro
+
+import (
+	"bytes"
+	"sync"
+
+	"github.com/tailscale/wireguard-go/tun"
+	"gvisor.dev/gvisor/pkg/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/header/parse"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	nsgro "gvisor.dev/gvisor/pkg/tcpip/stack/gro"
+	"tailscale.com/net/packet"
+	"tailscale.com/types/ipproto"
+)
+
+// RXChecksumOffload validates IPv4, TCP, and UDP header checksums in p,
+// returning an equivalent *stack.PacketBuffer if they are valid, otherwise nil.
+// The set of headers validated covers where gVisor would perform validation if
+// !stack.PacketBuffer.RXChecksumValidated, i.e. it satisfies
+// stack.CapabilityRXChecksumOffload. Other protocols with checksum fields,
+// e.g. ICMP{v6}, are still validated by gVisor regardless of rx checksum
+// offloading capabilities.
+func RXChecksumOffload(p *packet.Parsed) *stack.PacketBuffer {
+	var (
+		pn        tcpip.NetworkProtocolNumber
+		csumStart int
+	)
+	buf := p.Buffer()
+
+	switch p.IPVersion {
+	case 4:
+		if len(buf) < header.IPv4MinimumSize {
+			return nil
+		}
+		csumStart = int((buf[0] & 0x0F) * 4)
+		if csumStart < header.IPv4MinimumSize || csumStart > header.IPv4MaximumHeaderSize || len(buf) < csumStart {
+			return nil
+		}
+		if ^tun.Checksum(buf[:csumStart], 0) != 0 {
+			return nil
+		}
+		pn = header.IPv4ProtocolNumber
+	case 6:
+		if len(buf) < header.IPv6FixedHeaderSize {
+			return nil
+		}
+		csumStart = header.IPv6FixedHeaderSize
+		pn = header.IPv6ProtocolNumber
+		if p.IPProto != ipproto.ICMPv6 && p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP {
+			// buf could have extension headers before a UDP or TCP header, but
+			// packet.Parsed.IPProto will be set to the ext header type, so we
+			// have to look deeper. We are still responsible for validating the
+			// L4 checksum in this case. So, make use of gVisor's existing
+			// extension header parsing via parse.IPv6() in order to unpack the
+			// L4 csumStart index. This is not particularly efficient as we have
+			// to allocate a short-lived stack.PacketBuffer that cannot be
+			// re-used. parse.IPv6() "consumes" the IPv6 headers, so we can't
+			// inject this stack.PacketBuffer into the stack at a later point.
+			packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+				Payload: buffer.MakeWithData(bytes.Clone(buf)),
+			})
+			defer packetBuf.DecRef()
+			// The rightmost bool returns false only if packetBuf is too short,
+			// which we've already accounted for above.
+			transportProto, _, _, _, _ := parse.IPv6(packetBuf)
+			if transportProto == header.TCPProtocolNumber || transportProto == header.UDPProtocolNumber {
+				csumLen := packetBuf.Data().Size()
+				if len(buf) < csumLen {
+					return nil
+				}
+				csumStart = len(buf) - csumLen
+				p.IPProto = ipproto.Proto(transportProto)
+			}
+		}
+	}
+
+	if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
+		lenForPseudo := len(buf) - csumStart
+		csum := tun.PseudoHeaderChecksum(
+			uint8(p.IPProto),
+			p.Src.Addr().AsSlice(),
+			p.Dst.Addr().AsSlice(),
+			uint16(lenForPseudo))
+		csum = tun.Checksum(buf[csumStart:], csum)
+		if ^csum != 0 {
+			return nil
+		}
+	}
+
+	packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		Payload: buffer.MakeWithData(bytes.Clone(buf)),
+	})
+	packetBuf.NetworkProtocolNumber = pn
+	// Setting this is not technically required. gVisor overrides where
+	// stack.CapabilityRXChecksumOffload is advertised from Capabilities().
+	// https://github.com/google/gvisor/blob/64c016c92987cc04dfd4c7b091ddd21bdad875f8/pkg/tcpip/stack/nic.go#L763
+	// This is also why we offload for all packets since we cannot signal this
+	// per-packet.
+	packetBuf.RXChecksumValidated = true
+	return packetBuf
+}
+
+var (
+	groPool sync.Pool
+)
+
+func init() {
+	groPool.New = func() any {
+		g := &GRO{}
+		g.gro.Init(true)
+		return g
+	}
+}
+
+// GRO coalesces incoming packets to increase throughput. It is NOT thread-safe.
+type GRO struct {
+	gro           nsgro.GRO
+	maybeEnqueued bool
+}
+
+// NewGRO returns a new instance of *GRO from a sync.Pool. It can be returned to
+// the pool with GRO.Flush().
+func NewGRO() *GRO {
+	return groPool.Get().(*GRO)
+}
+
+// SetDispatcher sets the underlying stack.NetworkDispatcher where packets are
+// delivered.
+func (g *GRO) SetDispatcher(d stack.NetworkDispatcher) {
+	g.gro.Dispatcher = d
+}
+
+// Enqueue enqueues the provided packet for GRO. It may immediately deliver
+// it to the underlying stack.NetworkDispatcher depending on its contents. To
+// explicitly flush previously enqueued packets see Flush().
+func (g *GRO) Enqueue(p *packet.Parsed) {
+	if g.gro.Dispatcher == nil {
+		return
+	}
+	pkt := RXChecksumOffload(p)
+	if pkt == nil {
+		return
+	}
+	// TODO(jwhited): g.gro.Enqueue() duplicates a lot of p.Decode().
+	// We may want to push stack.PacketBuffer further up as a
+	// replacement for packet.Parsed, or inversely push packet.Parsed
+	// down into refactored GRO logic.
+	g.gro.Enqueue(pkt)
+	g.maybeEnqueued = true
+	pkt.DecRef()
+}
+
+// Flush flushes previously enqueued packets to the underlying
+// stack.NetworkDispatcher, and returns GRO to a pool for later re-use. Callers
+// MUST NOT use GRO once it has been Flush()'d.
+func (g *GRO) Flush() {
+	if g.gro.Dispatcher != nil && g.maybeEnqueued {
+		g.gro.Flush()
+	}
+	g.gro.Dispatcher = nil
+	g.maybeEnqueued = false
+	groPool.Put(g)
+}

@@ -1,7 +1,7 @@
 // Copyright (c) Tailscale Inc & AUTHORS
 // SPDX-License-Identifier: BSD-3-Clause
 
-package netstack
+package gro
 
 import (
 	"bytes"
@@ -13,7 +13,7 @@ import (
 	"tailscale.com/net/packet"
 )
 
-func Test_rxChecksumOffload(t *testing.T) {
+func Test_RXChecksumOffload(t *testing.T) {
 	payloadLen := 100
 
 	tcpFields := &header.TCPFields{
@@ -97,7 +97,7 @@ func Test_rxChecksumOffload(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			p := &packet.Parsed{}
 			p.Decode(tt.input)
-			got := rxChecksumOffload(p)
+			got := RXChecksumOffload(p)
 			if tt.wantPB != (got != nil) {
 				t.Fatalf("wantPB = %v != (got != nil): %v", tt.wantPB, got != nil)
 			}

@@ -1,16 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !ios
-
-package netstack
-
-import (
-	nsgro "gvisor.dev/gvisor/pkg/tcpip/stack/gro"
-)
-
-// gro wraps a gVisor GRO implementation. It exists solely to prevent iOS from
-// importing said package (see _ios.go).
-type gro struct {
-	nsgro.GRO
-}

@@ -1,30 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build ios
-
-package netstack
-
-import (
-	"gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-// gro on iOS delivers packets to its Dispatcher, immediately. This type exists
-// to prevent importation of the gVisor GRO implementation as said package
-// increases binary size. This is a penalty we do not wish to pay since we
-// currently do not leverage GRO on iOS.
-type gro struct {
-	Dispatcher stack.NetworkDispatcher
-}
-
-func (g *gro) Init(v bool) {
-	if v {
-		panic("GRO is not supported on this platform")
-	}
-}
-
-func (g *gro) Flush() {}
-
-func (g *gro) Enqueue(pkt *stack.PacketBuffer) {
-	g.Dispatcher.DeliverNetworkPacket(pkt.NetworkProtocolNumber, pkt)
-}

@@ -4,18 +4,15 @@
 package netstack
 
 import (
-	"bytes"
 	"context"
 	"sync"
 
-	"github.com/tailscale/wireguard-go/tun"
-	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
-	"gvisor.dev/gvisor/pkg/tcpip/header/parse"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"tailscale.com/net/packet"
 	"tailscale.com/types/ipproto"
+	"tailscale.com/wgengine/netstack/gro"
 )
 
 type queue struct {
@@ -83,54 +80,72 @@ func (q *queue) Num() int {
 var _ stack.LinkEndpoint = (*linkEndpoint)(nil)
 var _ stack.GSOEndpoint = (*linkEndpoint)(nil)
 
+type supportedGRO int
+
+const (
+	groNotSupported supportedGRO = iota
+	tcpGROSupported
+)
+
 // linkEndpoint implements stack.LinkEndpoint and stack.GSOEndpoint. Outbound
 // packets written by gVisor towards Tailscale are stored in a channel.
-// Inbound is fed to gVisor via injectInbound or enqueueGRO. This is loosely
+// Inbound is fed to gVisor via injectInbound or gro. This is loosely
 // modeled after gvisor.dev/pkg/tcpip/link/channel.Endpoint.
 type linkEndpoint struct {
 	SupportedGSOKind stack.SupportedGSO
-	initGRO          initGRO
+	supportedGRO     supportedGRO
 
 	mu         sync.RWMutex // mu guards the following fields
 	dispatcher stack.NetworkDispatcher
 	linkAddr   tcpip.LinkAddress
 	mtu        uint32
-	gro        gro // mu only guards access to gro.Dispatcher
 
 	q *queue // outbound
 }
 
-// TODO(jwhited): move to linkEndpointOpts struct or similar.
-type initGRO bool
-
-const (
-	disableGRO initGRO = false
-	enableGRO  initGRO = true
-)
-
-func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, gro initGRO) *linkEndpoint {
+func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, supportedGRO supportedGRO) *linkEndpoint {
 	le := &linkEndpoint{
+		supportedGRO: supportedGRO,
 		q: &queue{
 			c: make(chan *stack.PacketBuffer, size),
 		},
 		mtu:      mtu,
 		linkAddr: linkAddr,
 	}
-	le.initGRO = gro
-	le.gro.Init(bool(gro))
 	return le
 }
 
+// gro attempts to enqueue p on g if l supports a GRO kind matching the
+// transport protocol carried in p. gro may allocate g if it is nil. gro can
+// either return the existing g, a newly allocated one, or nil. Callers are
+// responsible for calling Flush() on the returned value if it is non-nil once
+// they have finished iterating through all GRO candidates for a given vector.
+// If gro allocates a *gro.GRO it will have l's stack.NetworkDispatcher set via
+// SetDispatcher().
+func (l *linkEndpoint) gro(p *packet.Parsed, g *gro.GRO) *gro.GRO {
+	if l.supportedGRO == groNotSupported || p.IPProto != ipproto.TCP {
+		// IPv6 may have extension headers preceding a TCP header, but we trade
+		// for a fast path and assume p cannot be coalesced in such a case.
+		l.injectInbound(p)
+		return g
+	}
+	if g == nil {
+		l.mu.RLock()
+		d := l.dispatcher
+		l.mu.RUnlock()
+		g = gro.NewGRO()
+		g.SetDispatcher(d)
+	}
+	g.Enqueue(p)
+	return g
+}
+
 // Close closes l. Further packet injections will return an error, and all
 // pending packets are discarded. Close may be called concurrently with
 // WritePackets.
 func (l *linkEndpoint) Close() {
 	l.mu.Lock()
-	if l.gro.Dispatcher != nil {
-		l.gro.Flush()
-	}
 	l.dispatcher = nil
-	l.gro.Dispatcher = nil
 	l.mu.Unlock()
 	l.q.Close()
 	l.Drain()
@@ -162,93 +177,6 @@ func (l *linkEndpoint) NumQueued() int {
 	return l.q.Num()
 }
 
-// rxChecksumOffload validates IPv4, TCP, and UDP header checksums in p,
-// returning an equivalent *stack.PacketBuffer if they are valid, otherwise nil.
-// The set of headers validated covers where gVisor would perform validation if
-// !stack.PacketBuffer.RXChecksumValidated, i.e. it satisfies
-// stack.CapabilityRXChecksumOffload. Other protocols with checksum fields,
-// e.g. ICMP{v6}, are still validated by gVisor regardless of rx checksum
-// offloading capabilities.
-func rxChecksumOffload(p *packet.Parsed) *stack.PacketBuffer {
-	var (
-		pn        tcpip.NetworkProtocolNumber
-		csumStart int
-	)
-	buf := p.Buffer()
-
-	switch p.IPVersion {
-	case 4:
-		if len(buf) < header.IPv4MinimumSize {
-			return nil
-		}
-		csumStart = int((buf[0] & 0x0F) * 4)
-		if csumStart < header.IPv4MinimumSize || csumStart > header.IPv4MaximumHeaderSize || len(buf) < csumStart {
-			return nil
-		}
-		if ^tun.Checksum(buf[:csumStart], 0) != 0 {
-			return nil
-		}
-		pn = header.IPv4ProtocolNumber
-	case 6:
-		if len(buf) < header.IPv6FixedHeaderSize {
-			return nil
-		}
-		csumStart = header.IPv6FixedHeaderSize
-		pn = header.IPv6ProtocolNumber
-		if p.IPProto != ipproto.ICMPv6 && p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP {
-			// buf could have extension headers before a UDP or TCP header, but
-			// packet.Parsed.IPProto will be set to the ext header type, so we
-			// have to look deeper. We are still responsible for validating the
-			// L4 checksum in this case. So, make use of gVisor's existing
-			// extension header parsing via parse.IPv6() in order to unpack the
-			// L4 csumStart index. This is not particularly efficient as we have
-			// to allocate a short-lived stack.PacketBuffer that cannot be
-			// re-used. parse.IPv6() "consumes" the IPv6 headers, so we can't
-			// inject this stack.PacketBuffer into the stack at a later point.
-			packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
-				Payload: buffer.MakeWithData(bytes.Clone(buf)),
-			})
-			defer packetBuf.DecRef()
-			// The rightmost bool returns false only if packetBuf is too short,
-			// which we've already accounted for above.
-			transportProto, _, _, _, _ := parse.IPv6(packetBuf)
-			if transportProto == header.TCPProtocolNumber || transportProto == header.UDPProtocolNumber {
-				csumLen := packetBuf.Data().Size()
-				if len(buf) < csumLen {
-					return nil
-				}
-				csumStart = len(buf) - csumLen
-				p.IPProto = ipproto.Proto(transportProto)
-			}
-		}
-	}
-
-	if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
-		lenForPseudo := len(buf) - csumStart
-		csum := tun.PseudoHeaderChecksum(
-			uint8(p.IPProto),
-			p.Src.Addr().AsSlice(),
-			p.Dst.Addr().AsSlice(),
-			uint16(lenForPseudo))
-		csum = tun.Checksum(buf[csumStart:], csum)
-		if ^csum != 0 {
-			return nil
-		}
-	}
-
-	packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		Payload: buffer.MakeWithData(bytes.Clone(buf)),
-	})
-	packetBuf.NetworkProtocolNumber = pn
-	// Setting this is not technically required. gVisor overrides where
-	// stack.CapabilityRXChecksumOffload is advertised from Capabilities().
-	// https://github.com/google/gvisor/blob/64c016c92987cc04dfd4c7b091ddd21bdad875f8/pkg/tcpip/stack/nic.go#L763
-	// This is also why we offload for all packets since we cannot signal this
-	// per-packet.
-	packetBuf.RXChecksumValidated = true
-	return packetBuf
-}
-
 func (l *linkEndpoint) injectInbound(p *packet.Parsed) {
 	l.mu.RLock()
 	d := l.dispatcher
@@ -256,7 +184,7 @@ func (l *linkEndpoint) injectInbound(p *packet.Parsed) {
 	if d == nil {
 		return
 	}
-	pkt := rxChecksumOffload(p)
+	pkt := gro.RXChecksumOffload(p)
 	if pkt == nil {
 		return
 	}
@@ -264,52 +192,12 @@ func (l *linkEndpoint) injectInbound(p *packet.Parsed) {
 	pkt.DecRef()
 }
 
-// enqueueGRO enqueues the provided packet for GRO. It may immediately deliver
-// it to the underlying stack.NetworkDispatcher depending on its contents and if
-// GRO was initialized via newLinkEndpoint. To explicitly flush previously
-// enqueued packets see flushGRO. enqueueGRO is not thread-safe and must not
-// be called concurrently with flushGRO.
-func (l *linkEndpoint) enqueueGRO(p *packet.Parsed) {
-	l.mu.RLock()
-	defer l.mu.RUnlock()
-	if l.gro.Dispatcher == nil {
-		return
-	}
-	pkt := rxChecksumOffload(p)
-	if pkt == nil {
-		return
-	}
-	// TODO(jwhited): gro.Enqueue() duplicates a lot of p.Decode().
-	// We may want to push stack.PacketBuffer further up as a
-	// replacement for packet.Parsed, or inversely push packet.Parsed
-	// down into refactored GRO logic.
-	l.gro.Enqueue(pkt)
-	pkt.DecRef()
-}
-
-// flushGRO flushes previously enqueueGRO'd packets to the underlying
-// stack.NetworkDispatcher. flushGRO is not thread-safe, and must not be
-// called concurrently with enqueueGRO.
-func (l *linkEndpoint) flushGRO() {
-	if !l.initGRO {
-		// If GRO was not initialized fast path return to avoid scanning GRO
-		// buckets (see l.gro.Flush()) that will always be empty.
-		return
-	}
-	l.mu.RLock()
-	defer l.mu.RUnlock()
-	if l.gro.Dispatcher != nil {
-		l.gro.Flush()
-	}
-}
-
 // Attach saves the stack network-layer dispatcher for use later when packets
 // are injected.
 func (l *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
 	l.mu.Lock()
 	defer l.mu.Unlock()
 	l.dispatcher = dispatcher
-	l.gro.Dispatcher = dispatcher
 }
 
 // IsAttached implements stack.LinkEndpoint.IsAttached.

@@ -54,6 +54,7 @@ import (
 	"tailscale.com/wgengine"
 	"tailscale.com/wgengine/filter"
 	"tailscale.com/wgengine/magicsock"
+	"tailscale.com/wgengine/netstack/gro"
 )
 
 const debugPackets = false
@@ -324,15 +325,15 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
 	if err != nil {
 		return nil, err
 	}
-	var linkEP *linkEndpoint
+	var supportedGSOKind stack.SupportedGSO
+	var supportedGROKind supportedGRO
 	if runtime.GOOS == "linux" {
-		// TODO(jwhited): add Windows GSO support https://github.com/tailscale/corp/issues/21874
-		// TODO(jwhited): exercise enableGRO in relation to https://github.com/tailscale/corp/issues/22353
-		linkEP = newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "", disableGRO)
-		linkEP.SupportedGSOKind = stack.HostGSOSupported
-	} else {
-		linkEP = newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "", disableGRO)
+		// TODO(jwhited): add Windows support https://github.com/tailscale/corp/issues/21874
+		supportedGSOKind = stack.HostGSOSupported
+		supportedGROKind = tcpGROSupported
 	}
+	linkEP := newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "", supportedGROKind)
+	linkEP.SupportedGSOKind = supportedGSOKind
 	if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
 		return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
 	}
@@ -380,7 +381,6 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
 	ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
 	ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc())
 	ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound
-	ns.tundev.EndPacketVectorInboundFromWireGuardFlush = linkEP.flushGRO
 	ns.tundev.PreFilterPacketOutboundToWireGuardNetstackIntercept = ns.handleLocalPackets
 	stacksForMetrics.Store(ns, struct{}{})
 	return ns, nil
@@ -1039,14 +1039,14 @@ func (ns *Impl) userPing(dstIP netip.Addr, pingResPkt []byte, direction userPing
 // continue normally (typically being delivered to the host networking stack),
 // whereas returning filter.DropSilently is done when netstack intercepts the
 // packet and no further processing towards to host should be done.
-func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
+func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper, gro *gro.GRO) (filter.Response, *gro.GRO) {
 	if ns.ctx.Err() != nil {
-		return filter.DropSilently
+		return filter.DropSilently, gro
 	}
 
 	if !ns.shouldProcessInbound(p, t) {
 		// Let the host network stack (if any) deal with it.
-		return filter.Accept
+		return filter.Accept, gro
 	}
 
 	destIP := p.Dst.Addr()
@@ -1066,13 +1066,13 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons
 			pong = packet.Generate(&h, p.Payload())
 		}
 		go ns.userPing(pingIP, pong, userPingDirectionOutbound)
-		return filter.DropSilently
+		return filter.DropSilently, gro
 	}
 
 	if debugPackets {
 		ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer())
 	}
-	ns.linkEP.enqueueGRO(p)
+	gro = ns.linkEP.gro(p, gro)
 
 	// We've now delivered this to netstack, so we're done.
 	// Instead of returning a filter.Accept here (which would also
@@ -1080,7 +1080,7 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons
 	// filter.Drop (which would log about rejected traffic),
 	// instead return filter.DropSilently which just quietly stops
 	// processing it in the tstun TUN wrapper.
-	return filter.DropSilently
+	return filter.DropSilently, gro
 }
 
 // shouldHandlePing returns whether or not netstack should handle an incoming

@@ -79,7 +79,7 @@ func TestInjectInboundLeak(t *testing.T) {
 	const N = 10_000
 	ms0 := getMemStats()
 	for range N {
-		outcome := ns.injectInbound(pkt, tunWrap)
+		outcome, _ := ns.injectInbound(pkt, tunWrap, nil)
 		if outcome != filter.DropSilently {
 			t.Fatalf("got outcome %v; want DropSilently", outcome)
 		}
@@ -569,7 +569,7 @@ func TestTCPForwardLimits(t *testing.T) {
 	// When injecting this packet, we want the outcome to be "drop
 	// silently", which indicates that netstack is processing the
 	// packet and not delivering it to the host system.
-	if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
+	if resp, _ := impl.injectInbound(&parsed, impl.tundev, nil); resp != filter.DropSilently {
 		t.Errorf("got filter outcome %v, want filter.DropSilently", resp)
 	}
 
@@ -587,7 +587,7 @@ func TestTCPForwardLimits(t *testing.T) {
 	// Inject another packet, which will be deduplicated and thus not
 	// increment our counter.
 	parsed.Decode(pkt)
-	if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
+	if resp, _ := impl.injectInbound(&parsed, impl.tundev, nil); resp != filter.DropSilently {
 		t.Errorf("got filter outcome %v, want filter.DropSilently", resp)
 	}
 
@@ -655,7 +655,7 @@ func TestTCPForwardLimits_PerClient(t *testing.T) {
 		// When injecting this packet, we want the outcome to be "drop
 		// silently", which indicates that netstack is processing the
 		// packet and not delivering it to the host system.
-		if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
+		if resp, _ := impl.injectInbound(&parsed, impl.tundev, nil); resp != filter.DropSilently {
 			t.Fatalf("got filter outcome %v, want filter.DropSilently", resp)
 		}
 	}

@@ -54,6 +54,7 @@ import (
 	"tailscale.com/wgengine/filter"
 	"tailscale.com/wgengine/magicsock"
 	"tailscale.com/wgengine/netlog"
+	"tailscale.com/wgengine/netstack/gro"
 	"tailscale.com/wgengine/router"
 	"tailscale.com/wgengine/wgcfg"
 	"tailscale.com/wgengine/wgint"
@@ -519,7 +520,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
 }
 
 // echoRespondToAll is an inbound post-filter responding to all echo requests.
-func echoRespondToAll(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
+func echoRespondToAll(p *packet.Parsed, t *tstun.Wrapper, gro *gro.GRO) (filter.Response, *gro.GRO) {
 	if p.IsEchoRequest() {
 		header := p.ICMP4Header()
 		header.ToResponse()
@@ -531,9 +532,9 @@ func echoRespondToAll(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
 		// it away. If this ever gets run in non-fake mode, you'll
 		// get double responses to pings, which is an indicator you
 		// shouldn't be doing that I guess.)
-		return filter.Accept
+		return filter.Accept, gro
 	}
-	return filter.Accept
+	return filter.Accept, gro
 }
 
 // handleLocalPackets inspects packets coming from the local network
