tsnet: add support for a user-supplied tun.Device

tsnet users can now provide a tun.Device, including any custom
implementation that conforms to the interface.

netstack has a new option CheckLocalTransportEndpoints that when used
alongside a TUN enables netstack listens and dials to correctly capture
traffic associated with those sockets. tsnet with a TUN sets this
option, while all other builds leave this at false to preserve existing
performance.

Updates #18423

Signed-off-by: James Tucker <james@tailscale.com>
main
James Tucker 1 week ago committed by James Tucker
parent c062230cce
commit 63d563e734

@ -26,6 +26,7 @@ import (
"sync"
"time"
"github.com/tailscale/wireguard-go/tun"
"tailscale.com/client/local"
"tailscale.com/control/controlclient"
"tailscale.com/envknob"
@ -167,6 +168,11 @@ type Server struct {
// that the control server will allow the node to adopt that tag.
AdvertiseTags []string
// Tun, if non-nil, specifies a custom tun.Device to use for packet I/O.
//
// This field must be set before calling Start.
Tun tun.Device
initOnce sync.Once
initErr error
lb *ipnlocal.LocalBackend
@ -659,6 +665,7 @@ func (s *Server) start() (reterr error) {
s.dialer = &tsdial.Dialer{Logf: tsLogf} // mutated below (before used)
s.dialer.SetBus(sys.Bus.Get())
eng, err := wgengine.NewUserspaceEngine(tsLogf, wgengine.Config{
Tun: s.Tun,
EventBus: sys.Bus.Get(),
ListenPort: s.Port,
NetMon: s.netMon,
@ -682,8 +689,16 @@ func (s *Server) start() (reterr error) {
}
sys.Tun.Get().Start()
sys.Set(ns)
ns.ProcessLocalIPs = true
ns.ProcessSubnets = true
if s.Tun == nil {
// Only process packets in netstack when using the default fake TUN.
// When a TUN is provided, let packets flow through it instead.
ns.ProcessLocalIPs = true
ns.ProcessSubnets = true
} else {
// When using a TUN, check gVisor for registered endpoints to handle
// packets for tsnet listeners and outbound connection replies.
ns.CheckLocalTransportEndpoints = true
}
ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow
ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow
s.netstack = ns
@ -1072,10 +1087,34 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
network = "udp6"
}
}
if err := s.Start(); err != nil {
netLn, err := s.listen(network, addr, listenOnTailnet)
if err != nil {
return nil, err
}
return s.netstack.ListenPacket(network, ap.String())
ln := netLn.(*listener)
pc, err := s.netstack.ListenPacket(network, ap.String())
if err != nil {
ln.Close()
return nil, err
}
return &udpPacketConn{
PacketConn: pc,
ln: ln,
}, nil
}
// udpPacketConn wraps a net.PacketConn to unregister from s.listeners on Close.
type udpPacketConn struct {
net.PacketConn
ln *listener
}
func (c *udpPacketConn) Close() error {
c.ln.Close()
return c.PacketConn.Close()
}
// ListenTLS announces only on the Tailscale network.
@ -1611,10 +1650,37 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro
closedc: make(chan struct{}),
conn: make(chan net.Conn),
}
// When using a TUN with TCP, create a gVisor TCP listener.
if s.Tun != nil && (network == "" || network == "tcp" || network == "tcp4" || network == "tcp6") {
var nsNetwork string
nsAddr := host
switch {
case network == "tcp4" || network == "tcp6":
nsNetwork = network
case host.Addr().Is4():
nsNetwork = "tcp4"
case host.Addr().Is6():
nsNetwork = "tcp6"
default:
// Wildcard address: use tcp6 for dual-stack (accepts both v4 and v6).
nsNetwork = "tcp6"
nsAddr = netip.AddrPortFrom(netip.IPv6Unspecified(), host.Port())
}
gonetLn, err := s.netstack.ListenTCP(nsNetwork, nsAddr.String())
if err != nil {
return nil, fmt.Errorf("tsnet: %w", err)
}
ln.gonetLn = gonetLn
}
s.mu.Lock()
for _, key := range keys {
if _, ok := s.listeners[key]; ok {
s.mu.Unlock()
if ln.gonetLn != nil {
ln.gonetLn.Close()
}
return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr)
}
}
@ -1684,9 +1750,17 @@ type listener struct {
conn chan net.Conn // unbuffered, never closed
closedc chan struct{} // closed on [listener.Close]
closed bool // guarded by s.mu
// gonetLn, if set, is the gonet.Listener that handles new connections.
// gonetLn is set by [listen] when a TUN is in use and terminates the listener.
// gonetLn is nil when TUN is nil.
gonetLn net.Listener
}
func (ln *listener) Accept() (net.Conn, error) {
if ln.gonetLn != nil {
return ln.gonetLn.Accept()
}
select {
case c := <-ln.conn:
return c, nil
@ -1696,6 +1770,9 @@ func (ln *listener) Accept() (net.Conn, error) {
}
func (ln *listener) Addr() net.Addr {
if ln.gonetLn != nil {
return ln.gonetLn.Addr()
}
return addr{
network: ln.keys[0].network,
addr: ln.addr,
@ -1721,6 +1798,9 @@ func (ln *listener) closeLocked() error {
}
close(ln.closedc)
ln.closed = true
if ln.gonetLn != nil {
ln.gonetLn.Close()
}
return nil
}

@ -39,6 +39,7 @@ import (
"github.com/google/go-cmp/cmp"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/net/proxy"
"tailscale.com/client/local"
@ -48,11 +49,13 @@ import (
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/store/mem"
"tailscale.com/net/netns"
"tailscale.com/net/packet"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/tstest/deptest"
"tailscale.com/tstest/integration"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/views"
@ -1860,6 +1863,676 @@ func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) {
t.Error("magicsock did not find a direct path from lc1 to lc2")
}
// chanTUN is a tun.Device for testing that uses channels for packet I/O.
// Inbound receives packets written to the TUN (from the perspective of the network stack).
// Outbound is for injecting packets to be read from the TUN.
type chanTUN struct {
Inbound chan []byte // packets written to TUN
Outbound chan []byte // packets to read from TUN
closed chan struct{}
events chan tun.Event
}
func newChanTUN() *chanTUN {
t := &chanTUN{
Inbound: make(chan []byte, 10),
Outbound: make(chan []byte, 10),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
t.events <- tun.EventUp
return t
}
func (t *chanTUN) File() *os.File { panic("not implemented") }
func (t *chanTUN) Close() error {
select {
case <-t.closed:
default:
close(t.closed)
close(t.Inbound)
}
return nil
}
func (t *chanTUN) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case <-t.closed:
return 0, io.EOF
case pkt := <-t.Outbound:
sizes[0] = copy(bufs[0][offset:], pkt)
return 1, nil
}
}
func (t *chanTUN) Write(bufs [][]byte, offset int) (int, error) {
for _, buf := range bufs {
pkt := buf[offset:]
if len(pkt) == 0 {
continue
}
select {
case <-t.closed:
return 0, errors.New("closed")
case t.Inbound <- slices.Clone(pkt):
}
}
return len(bufs), nil
}
func (t *chanTUN) MTU() (int, error) { return 1280, nil }
func (t *chanTUN) Name() (string, error) { return "chantun", nil }
func (t *chanTUN) Events() <-chan tun.Event { return t.events }
func (t *chanTUN) BatchSize() int { return 1 }
// listenTest provides common setup for listener and TUN tests.
type listenTest struct {
s1, s2 *Server
s1ip4, s1ip6 netip.Addr
s2ip4, s2ip6 netip.Addr
tun *chanTUN // nil for netstack mode
}
// setupListenTest creates two tsnet servers for testing.
// If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only.
func setupListenTest(t *testing.T, useTUN bool) *listenTest {
t.Helper()
tstest.Shard(t)
tstest.ResourceCheck(t)
ctx := t.Context()
controlURL, _ := startControl(t)
s1, _, _ := startServer(t, ctx, controlURL, "s1")
tmp := filepath.Join(t.TempDir(), "s2")
must.Do(os.MkdirAll(tmp, 0755))
s2 := &Server{
Dir: tmp,
ControlURL: controlURL,
Hostname: "s2",
Store: new(mem.Store),
Ephemeral: true,
}
var tun *chanTUN
if useTUN {
tun = newChanTUN()
s2.Tun = tun
}
if *verboseNodes {
s2.Logf = t.Logf
}
t.Cleanup(func() { s2.Close() })
s2status, err := s2.Up(ctx)
if err != nil {
t.Fatal(err)
}
s1ip4, s1ip6 := s1.TailscaleIPs()
s2ip4 := s2status.TailscaleIPs[0]
var s2ip6 netip.Addr
if len(s2status.TailscaleIPs) > 1 {
s2ip6 = s2status.TailscaleIPs[1]
}
lc1 := must.Get(s1.LocalClient())
must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP))
return &listenTest{
s1: s1,
s2: s2,
s1ip4: s1ip4,
s1ip6: s1ip6,
s2ip4: s2ip4,
s2ip6: s2ip6,
tun: tun,
}
}
// echoUDP returns an IP packet with src/dst and ports swapped, with checksums recomputed.
func echoUDP(pkt []byte) []byte {
var p packet.Parsed
p.Decode(pkt)
if p.IPProto != ipproto.UDP {
return nil
}
switch p.IPVersion {
case 4:
h := p.UDP4Header()
h.ToResponse()
return packet.Generate(h, p.Payload())
case 6:
h := packet.UDP6Header{
IP6Header: p.IP6Header(),
SrcPort: p.Src.Port(),
DstPort: p.Dst.Port(),
}
h.ToResponse()
return packet.Generate(h, p.Payload())
}
return nil
}
func TestTUN(t *testing.T) {
tt := setupListenTest(t, true)
go func() {
for pkt := range tt.tun.Inbound {
var p packet.Parsed
p.Decode(pkt)
if p.Dst.Port() == 9999 {
tt.tun.Outbound <- echoUDP(pkt)
}
}
}()
test := func(t *testing.T, s2ip netip.Addr) {
conn, err := tt.s1.Dial(t.Context(), "udp", netip.AddrPortFrom(s2ip, 9999).String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
want := "hello from s1"
if _, err := conn.Write([]byte(want)); err != nil {
t.Fatal(err)
}
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
got := make([]byte, 1024)
n, err := conn.Read(got)
if err != nil {
t.Fatalf("reading echo response: %v", err)
}
if string(got[:n]) != want {
t.Errorf("got %q, want %q", got[:n], want)
}
}
t.Run("IPv4", func(t *testing.T) { test(t, tt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { test(t, tt.s2ip6) })
}
// TestTUNDNS tests that a TUN can send DNS queries to quad-100 and receive
// responses. This verifies that handleLocalPackets intercepts outbound traffic
// to the service IP.
func TestTUNDNS(t *testing.T) {
tt := setupListenTest(t, true)
test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) {
tt.tun.Outbound <- buildDNSQuery("s2", srcIP)
ipVersion := uint8(4)
if srcIP.Is6() {
ipVersion = 6
}
for {
select {
case pkt := <-tt.tun.Inbound:
var p packet.Parsed
p.Decode(pkt)
if p.IPVersion != ipVersion || p.IPProto != ipproto.UDP {
continue
}
if p.Src.Addr() == serviceIP && p.Src.Port() == 53 {
if len(p.Payload()) < 12 {
t.Fatalf("DNS response too short: %d bytes", len(p.Payload()))
}
return // success
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for DNS response")
}
}
}
t.Run("IPv4", func(t *testing.T) {
test(t, tt.s2ip4, netip.MustParseAddr("100.100.100.100"))
})
t.Run("IPv6", func(t *testing.T) {
test(t, tt.s2ip6, netip.MustParseAddr("fd7a:115c:a1e0::53"))
})
}
// TestListenPacket tests UDP listeners (ListenPacket) in both netstack and TUN modes.
func TestListenPacket(t *testing.T) {
testListenPacket := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
pc, err := lt.s2.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String())
if err != nil {
t.Fatal(err)
}
defer pc.Close()
echoErr := make(chan error, 1)
go func() {
buf := make([]byte, 1500)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
echoErr <- err
return
}
_, err = pc.WriteTo(buf[:n], addr)
if err != nil {
echoErr <- err
return
}
}()
conn, err := lt.s1.Dial(t.Context(), "udp", pc.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
want := "hello udp"
if _, err := conn.Write([]byte(want)); err != nil {
t.Fatal(err)
}
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
got := make([]byte, 1024)
n, err := conn.Read(got)
if err != nil {
select {
case e := <-echoErr:
t.Fatalf("echo error: %v; read error: %v", e, err)
default:
t.Fatalf("Read failed: %v", err)
}
}
if string(got[:n]) != want {
t.Errorf("got %q, want %q", got[:n], want)
}
}
t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false)
t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
})
t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true)
t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
})
}
// TestListenTCP tests TCP listeners with concrete addresses in both netstack
// and TUN modes.
func TestListenTCP(t *testing.T) {
testListenTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
ln, err := lt.s2.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String())
if err != nil {
t.Fatal(err)
}
defer ln.Close()
echoErr := make(chan error, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
echoErr <- err
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
echoErr <- err
return
}
_, err = conn.Write(buf[:n])
if err != nil {
echoErr <- err
return
}
}()
conn, err := lt.s1.Dial(t.Context(), "tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
defer conn.Close()
want := "hello tcp"
if _, err := conn.Write([]byte(want)); err != nil {
t.Fatalf("Write failed: %v", err)
}
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
got := make([]byte, 1024)
n, err := conn.Read(got)
if err != nil {
select {
case e := <-echoErr:
t.Fatalf("echo error: %v; read error: %v", e, err)
default:
t.Fatalf("Read failed: %v", err)
}
}
if string(got[:n]) != want {
t.Errorf("got %q, want %q", got[:n], want)
}
}
t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false)
t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
})
t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true)
t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
})
}
// TestListenTCPDualStack tests TCP listeners with wildcard addresses (dual-stack)
// in both netstack and TUN modes.
func TestListenTCPDualStack(t *testing.T) {
testListenTCPDualStack := func(t *testing.T, lt *listenTest, dialIP netip.Addr) {
ln, err := lt.s2.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
_, portStr, err := net.SplitHostPort(ln.Addr().String())
if err != nil {
t.Fatalf("parsing listener address %q: %v", ln.Addr().String(), err)
}
echoErr := make(chan error, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
echoErr <- err
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
echoErr <- err
return
}
_, err = conn.Write(buf[:n])
if err != nil {
echoErr <- err
return
}
}()
dialAddr := net.JoinHostPort(dialIP.String(), portStr)
conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr)
if err != nil {
t.Fatalf("Dial(%q) failed: %v", dialAddr, err)
}
defer conn.Close()
want := "hello tcp dualstack"
if _, err := conn.Write([]byte(want)); err != nil {
t.Fatalf("Write failed: %v", err)
}
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
got := make([]byte, 1024)
n, err := conn.Read(got)
if err != nil {
select {
case e := <-echoErr:
t.Fatalf("echo error: %v; read error: %v", e, err)
default:
t.Fatalf("Read failed: %v", err)
}
}
if string(got[:n]) != want {
t.Errorf("got %q, want %q", got[:n], want)
}
}
t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false)
t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
})
t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true)
t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
})
}
// TestDialTCP tests TCP dialing from s2 to s1 in both netstack and TUN modes.
// In TUN mode, this verifies that outbound TCP connections and their replies
// are handled by netstack without packets escaping to the TUN.
func TestDialTCP(t *testing.T) {
testDialTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
ln, err := lt.s1.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String())
if err != nil {
t.Fatal(err)
}
defer ln.Close()
echoErr := make(chan error, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
echoErr <- err
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
echoErr <- err
return
}
_, err = conn.Write(buf[:n])
if err != nil {
echoErr <- err
return
}
}()
conn, err := lt.s2.Dial(t.Context(), "tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
defer conn.Close()
want := "hello tcp dial"
if _, err := conn.Write([]byte(want)); err != nil {
t.Fatalf("Write failed: %v", err)
}
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
got := make([]byte, 1024)
n, err := conn.Read(got)
if err != nil {
select {
case e := <-echoErr:
t.Fatalf("echo error: %v; read error: %v", e, err)
default:
t.Fatalf("Read failed: %v", err)
}
}
if string(got[:n]) != want {
t.Errorf("got %q, want %q", got[:n], want)
}
}
t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false)
t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
})
t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true)
var escapedTCPPackets atomic.Int32
var wg sync.WaitGroup
wg.Go(func() {
for pkt := range lt.tun.Inbound {
var p packet.Parsed
p.Decode(pkt)
if p.IPProto == ipproto.TCP {
escapedTCPPackets.Add(1)
t.Logf("TCP packet escaped to TUN: %v -> %v", p.Src, p.Dst)
}
}
})
t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
lt.tun.Close()
wg.Wait()
if escaped := escapedTCPPackets.Load(); escaped > 0 {
t.Errorf("%d TCP packets escaped to TUN", escaped)
}
})
}
// TestDialUDP tests UDP dialing from s2 to s1 in both netstack and TUN modes.
// In TUN mode, this verifies that outbound UDP connections register endpoints
// with gVisor, allowing reply packets to be routed through netstack instead of
// escaping to the TUN.
func TestDialUDP(t *testing.T) {
testDialUDP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
pc, err := lt.s1.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String())
if err != nil {
t.Fatal(err)
}
defer pc.Close()
echoErr := make(chan error, 1)
go func() {
buf := make([]byte, 1500)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
echoErr <- err
return
}
_, err = pc.WriteTo(buf[:n], addr)
if err != nil {
echoErr <- err
return
}
}()
conn, err := lt.s2.Dial(t.Context(), "udp", pc.LocalAddr().String())
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
defer conn.Close()
want := "hello udp dial"
if _, err := conn.Write([]byte(want)); err != nil {
t.Fatalf("Write failed: %v", err)
}
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
got := make([]byte, 1024)
n, err := conn.Read(got)
if err != nil {
select {
case e := <-echoErr:
t.Fatalf("echo error: %v; read error: %v", e, err)
default:
t.Fatalf("Read failed: %v", err)
}
}
if string(got[:n]) != want {
t.Errorf("got %q, want %q", got[:n], want)
}
}
t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false)
t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
})
t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true)
var escapedUDPPackets atomic.Int32
var wg sync.WaitGroup
wg.Go(func() {
for pkt := range lt.tun.Inbound {
var p packet.Parsed
p.Decode(pkt)
if p.IPProto == ipproto.UDP {
escapedUDPPackets.Add(1)
t.Logf("UDP packet escaped to TUN: %v -> %v", p.Src, p.Dst)
}
}
})
t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
lt.tun.Close()
wg.Wait()
if escaped := escapedUDPPackets.Load(); escaped > 0 {
t.Errorf("%d UDP packets escaped to TUN", escaped)
}
})
}
// buildDNSQuery builds a UDP/IP packet containing a DNS query for name to the
// Tailscale service IP (100.100.100.100 for IPv4, fd7a:115c:a1e0::53 for IPv6).
func buildDNSQuery(name string, srcIP netip.Addr) []byte {
qtype := byte(0x01) // Type A for IPv4
if srcIP.Is6() {
qtype = 0x1c // Type AAAA for IPv6
}
dns := []byte{
0x12, 0x34, // ID
0x01, 0x00, // Flags: standard query, recursion desired
0x00, 0x01, // QDCOUNT: 1
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ANCOUNT, NSCOUNT, ARCOUNT
}
for _, label := range strings.Split(name, ".") {
dns = append(dns, byte(len(label)))
dns = append(dns, label...)
}
dns = append(dns, 0x00, 0x00, qtype, 0x00, 0x01) // null, Type A/AAAA, Class IN
if srcIP.Is4() {
h := packet.UDP4Header{
IP4Header: packet.IP4Header{
Src: srcIP,
Dst: netip.MustParseAddr("100.100.100.100"),
},
SrcPort: 12345,
DstPort: 53,
}
return packet.Generate(h, dns)
}
h := packet.UDP6Header{
IP6Header: packet.IP6Header{
Src: srcIP,
Dst: netip.MustParseAddr("fd7a:115c:a1e0::53"),
},
SrcPort: 12345,
DstPort: 53,
}
return packet.Generate(h, dns)
}
func TestDeps(t *testing.T) {
tstest.Shard(t)
deptest.DepChecker{

@ -165,6 +165,17 @@ type Impl struct {
// over the UDP flow.
GetUDPHandlerForFlow func(src, dst netip.AddrPort) (handler func(nettype.ConnPacketConn), intercept bool)
// CheckLocalTransportEndpoints, if true, causes netstack to check if gVisor
// has a registered endpoint for incoming packets to local IPs. This is used
// by tsnet to intercept packets for registered listeners and outbound
// connections when ProcessLocalIPs is false (i.e., when using a TUN).
// It can only be set before calling Start.
// TODO(raggi): refactor the way we handle both CheckLocalTransportEndpoints
// and the earlier netstack registrations for serve, funnel, peerAPI and so
// on. Currently this optimizes away cost for tailscaled in TUN mode, while
// enabling extension support when using tsnet in TUN mode. See #18423.
CheckLocalTransportEndpoints bool
// ProcessLocalIPs is whether netstack should handle incoming
// traffic directed at the Node.Addresses (local IPs).
// It can only be set before calling Start.
@ -1109,6 +1120,45 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool {
if ns.ProcessSubnets && !isLocal {
return true
}
if isLocal && ns.CheckLocalTransportEndpoints {
// Handle packets to registered listeners and replies to outbound
// connections by checking if gVisor has a registered endpoint.
// This covers TCP listeners, UDP listeners, and outbound TCP replies.
if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
var netProto tcpip.NetworkProtocolNumber
var id stack.TransportEndpointID
if p.Dst.Addr().Is4() {
netProto = ipv4.ProtocolNumber
id = stack.TransportEndpointID{
LocalAddress: tcpip.AddrFrom4(p.Dst.Addr().As4()),
LocalPort: p.Dst.Port(),
RemoteAddress: tcpip.AddrFrom4(p.Src.Addr().As4()),
RemotePort: p.Src.Port(),
}
} else {
netProto = ipv6.ProtocolNumber
id = stack.TransportEndpointID{
LocalAddress: tcpip.AddrFrom16(p.Dst.Addr().As16()),
LocalPort: p.Dst.Port(),
RemoteAddress: tcpip.AddrFrom16(p.Src.Addr().As16()),
RemotePort: p.Src.Port(),
}
}
var transProto tcpip.TransportProtocolNumber
if p.IPProto == ipproto.TCP {
transProto = tcp.ProtocolNumber
} else {
transProto = udp.ProtocolNumber
}
ep := ns.ipstack.FindTransportEndpoint(netProto, transProto, id, nicID)
if debugNetstack() {
ns.logf("[v2] FindTransportEndpoint: id=%+v found=%v", id, ep != nil)
}
if ep != nil {
return true
}
}
}
return false
}
@ -1575,7 +1625,7 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) {
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %v", address, err)
return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err)
}
var networkProto tcpip.NetworkProtocolNumber
@ -1612,6 +1662,40 @@ func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) {
return gonet.NewUDPConn(&wq, ep), nil
}
// ListenTCP listens for TCP connections on the given address.
func (ns *Impl) ListenTCP(network, address string) (*gonet.TCPListener, error) {
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err)
}
var networkProto tcpip.NetworkProtocolNumber
switch network {
case "tcp4":
networkProto = ipv4.ProtocolNumber
if ap.Addr().IsValid() && !ap.Addr().Is4() {
return nil, fmt.Errorf("netstack: tcp4 requires an IPv4 address")
}
case "tcp6":
networkProto = ipv6.ProtocolNumber
if ap.Addr().IsValid() && !ap.Addr().Is6() {
return nil, fmt.Errorf("netstack: tcp6 requires an IPv6 address")
}
default:
return nil, fmt.Errorf("netstack: unsupported network %q", network)
}
localAddress := tcpip.FullAddress{
NIC: nicID,
Port: ap.Port(),
}
if ap.Addr().IsValid() && !ap.Addr().IsUnspecified() {
localAddress.Addr = tcpip.AddrFromSlice(ap.Addr().AsSlice())
}
return gonet.ListenTCP(ns.ipstack, localAddress, networkProto)
}
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
sess := r.ID()
if debugNetstack() {

Loading…
Cancel
Save