From 63d563e7340b4712b9f2933f663057ce2dcfa4a4 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Thu, 15 Jan 2026 20:35:41 -0800 Subject: [PATCH] tsnet: add support for a user-supplied tun.Device tsnet users can now provide a tun.Device, including any custom implementation that conforms to the interface. netstack has a new option CheckLocalTransportEndpoints that when used alongside a TUN enables netstack listens and dials to correctly capture traffic associated with those sockets. tsnet with a TUN sets this option, while all other builds leave this at false to preserve existing performance. Updates #18423 Signed-off-by: James Tucker --- tsnet/tsnet.go | 88 ++++- tsnet/tsnet_test.go | 673 ++++++++++++++++++++++++++++++++++ wgengine/netstack/netstack.go | 86 ++++- 3 files changed, 842 insertions(+), 5 deletions(-) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index bf7e694df..d627d55b3 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -26,6 +26,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/tun" "tailscale.com/client/local" "tailscale.com/control/controlclient" "tailscale.com/envknob" @@ -167,6 +168,11 @@ type Server struct { // that the control server will allow the node to adopt that tag. AdvertiseTags []string + // Tun, if non-nil, specifies a custom tun.Device to use for packet I/O. + // + // This field must be set before calling Start. + Tun tun.Device + initOnce sync.Once initErr error lb *ipnlocal.LocalBackend @@ -659,6 +665,7 @@ func (s *Server) start() (reterr error) { s.dialer = &tsdial.Dialer{Logf: tsLogf} // mutated below (before used) s.dialer.SetBus(sys.Bus.Get()) eng, err := wgengine.NewUserspaceEngine(tsLogf, wgengine.Config{ + Tun: s.Tun, EventBus: sys.Bus.Get(), ListenPort: s.Port, NetMon: s.netMon, @@ -682,8 +689,16 @@ func (s *Server) start() (reterr error) { } sys.Tun.Get().Start() sys.Set(ns) - ns.ProcessLocalIPs = true - ns.ProcessSubnets = true + if s.Tun == nil { + // Only process packets in netstack when using the default fake TUN. + // When a TUN is provided, let packets flow through it instead. + ns.ProcessLocalIPs = true + ns.ProcessSubnets = true + } else { + // When using a TUN, check gVisor for registered endpoints to handle + // packets for tsnet listeners and outbound connection replies. + ns.CheckLocalTransportEndpoints = true + } ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow s.netstack = ns @@ -1072,10 +1087,34 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) { network = "udp6" } } - if err := s.Start(); err != nil { + + netLn, err := s.listen(network, addr, listenOnTailnet) + if err != nil { return nil, err } - return s.netstack.ListenPacket(network, ap.String()) + ln := netLn.(*listener) + + pc, err := s.netstack.ListenPacket(network, ap.String()) + if err != nil { + ln.Close() + return nil, err + } + + return &udpPacketConn{ + PacketConn: pc, + ln: ln, + }, nil +} + +// udpPacketConn wraps a net.PacketConn to unregister from s.listeners on Close. +type udpPacketConn struct { + net.PacketConn + ln *listener +} + +func (c *udpPacketConn) Close() error { + c.ln.Close() + return c.PacketConn.Close() } // ListenTLS announces only on the Tailscale network. @@ -1611,10 +1650,37 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro closedc: make(chan struct{}), conn: make(chan net.Conn), } + + // When using a TUN with TCP, create a gVisor TCP listener. + if s.Tun != nil && (network == "" || network == "tcp" || network == "tcp4" || network == "tcp6") { + var nsNetwork string + nsAddr := host + switch { + case network == "tcp4" || network == "tcp6": + nsNetwork = network + case host.Addr().Is4(): + nsNetwork = "tcp4" + case host.Addr().Is6(): + nsNetwork = "tcp6" + default: + // Wildcard address: use tcp6 for dual-stack (accepts both v4 and v6). + nsNetwork = "tcp6" + nsAddr = netip.AddrPortFrom(netip.IPv6Unspecified(), host.Port()) + } + gonetLn, err := s.netstack.ListenTCP(nsNetwork, nsAddr.String()) + if err != nil { + return nil, fmt.Errorf("tsnet: %w", err) + } + ln.gonetLn = gonetLn + } + s.mu.Lock() for _, key := range keys { if _, ok := s.listeners[key]; ok { s.mu.Unlock() + if ln.gonetLn != nil { + ln.gonetLn.Close() + } return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr) } } @@ -1684,9 +1750,17 @@ type listener struct { conn chan net.Conn // unbuffered, never closed closedc chan struct{} // closed on [listener.Close] closed bool // guarded by s.mu + + // gonetLn, if set, is the gonet.Listener that handles new connections. + // gonetLn is set by [listen] when a TUN is in use and terminates the listener. + // gonetLn is nil when TUN is nil. + gonetLn net.Listener } func (ln *listener) Accept() (net.Conn, error) { + if ln.gonetLn != nil { + return ln.gonetLn.Accept() + } select { case c := <-ln.conn: return c, nil @@ -1696,6 +1770,9 @@ func (ln *listener) Accept() (net.Conn, error) { } func (ln *listener) Addr() net.Addr { + if ln.gonetLn != nil { + return ln.gonetLn.Addr() + } return addr{ network: ln.keys[0].network, addr: ln.addr, @@ -1721,6 +1798,9 @@ func (ln *listener) closeLocked() error { } close(ln.closedc) ln.closed = true + if ln.gonetLn != nil { + ln.gonetLn.Close() + } return nil } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index f44bacab0..2c6970fa3 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -39,6 +39,7 @@ import ( "github.com/google/go-cmp/cmp" dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/net/proxy" "tailscale.com/client/local" @@ -48,11 +49,13 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" + "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/deptest" "tailscale.com/tstest/integration" "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/views" @@ -1860,6 +1863,676 @@ func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) { t.Error("magicsock did not find a direct path from lc1 to lc2") } +// chanTUN is a tun.Device for testing that uses channels for packet I/O. +// Inbound receives packets written to the TUN (from the perspective of the network stack). +// Outbound is for injecting packets to be read from the TUN. +type chanTUN struct { + Inbound chan []byte // packets written to TUN + Outbound chan []byte // packets to read from TUN + closed chan struct{} + events chan tun.Event +} + +func newChanTUN() *chanTUN { + t := &chanTUN{ + Inbound: make(chan []byte, 10), + Outbound: make(chan []byte, 10), + closed: make(chan struct{}), + events: make(chan tun.Event, 1), + } + t.events <- tun.EventUp + return t +} + +func (t *chanTUN) File() *os.File { panic("not implemented") } + +func (t *chanTUN) Close() error { + select { + case <-t.closed: + default: + close(t.closed) + close(t.Inbound) + } + return nil +} + +func (t *chanTUN) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + select { + case <-t.closed: + return 0, io.EOF + case pkt := <-t.Outbound: + sizes[0] = copy(bufs[0][offset:], pkt) + return 1, nil + } +} + +func (t *chanTUN) Write(bufs [][]byte, offset int) (int, error) { + for _, buf := range bufs { + pkt := buf[offset:] + if len(pkt) == 0 { + continue + } + select { + case <-t.closed: + return 0, errors.New("closed") + case t.Inbound <- slices.Clone(pkt): + } + } + return len(bufs), nil +} + +func (t *chanTUN) MTU() (int, error) { return 1280, nil } +func (t *chanTUN) Name() (string, error) { return "chantun", nil } +func (t *chanTUN) Events() <-chan tun.Event { return t.events } +func (t *chanTUN) BatchSize() int { return 1 } + +// listenTest provides common setup for listener and TUN tests. +type listenTest struct { + s1, s2 *Server + s1ip4, s1ip6 netip.Addr + s2ip4, s2ip6 netip.Addr + tun *chanTUN // nil for netstack mode +} + +// setupListenTest creates two tsnet servers for testing. +// If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only. +func setupListenTest(t *testing.T, useTUN bool) *listenTest { + t.Helper() + tstest.Shard(t) + tstest.ResourceCheck(t) + ctx := t.Context() + controlURL, _ := startControl(t) + s1, _, _ := startServer(t, ctx, controlURL, "s1") + + tmp := filepath.Join(t.TempDir(), "s2") + must.Do(os.MkdirAll(tmp, 0755)) + s2 := &Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: "s2", + Store: new(mem.Store), + Ephemeral: true, + } + + var tun *chanTUN + if useTUN { + tun = newChanTUN() + s2.Tun = tun + } + + if *verboseNodes { + s2.Logf = t.Logf + } + t.Cleanup(func() { s2.Close() }) + + s2status, err := s2.Up(ctx) + if err != nil { + t.Fatal(err) + } + + s1ip4, s1ip6 := s1.TailscaleIPs() + s2ip4 := s2status.TailscaleIPs[0] + var s2ip6 netip.Addr + if len(s2status.TailscaleIPs) > 1 { + s2ip6 = s2status.TailscaleIPs[1] + } + + lc1 := must.Get(s1.LocalClient()) + must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP)) + + return &listenTest{ + s1: s1, + s2: s2, + s1ip4: s1ip4, + s1ip6: s1ip6, + s2ip4: s2ip4, + s2ip6: s2ip6, + tun: tun, + } +} + +// echoUDP returns an IP packet with src/dst and ports swapped, with checksums recomputed. +func echoUDP(pkt []byte) []byte { + var p packet.Parsed + p.Decode(pkt) + if p.IPProto != ipproto.UDP { + return nil + } + switch p.IPVersion { + case 4: + h := p.UDP4Header() + h.ToResponse() + return packet.Generate(h, p.Payload()) + case 6: + h := packet.UDP6Header{ + IP6Header: p.IP6Header(), + SrcPort: p.Src.Port(), + DstPort: p.Dst.Port(), + } + h.ToResponse() + return packet.Generate(h, p.Payload()) + } + return nil +} + +func TestTUN(t *testing.T) { + tt := setupListenTest(t, true) + + go func() { + for pkt := range tt.tun.Inbound { + var p packet.Parsed + p.Decode(pkt) + if p.Dst.Port() == 9999 { + tt.tun.Outbound <- echoUDP(pkt) + } + } + }() + + test := func(t *testing.T, s2ip netip.Addr) { + conn, err := tt.s1.Dial(t.Context(), "udp", netip.AddrPortFrom(s2ip, 9999).String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + want := "hello from s1" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatal(err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + t.Fatalf("reading echo response: %v", err) + } + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("IPv4", func(t *testing.T) { test(t, tt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { test(t, tt.s2ip6) }) +} + +// TestTUNDNS tests that a TUN can send DNS queries to quad-100 and receive +// responses. This verifies that handleLocalPackets intercepts outbound traffic +// to the service IP. +func TestTUNDNS(t *testing.T) { + tt := setupListenTest(t, true) + + test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) { + tt.tun.Outbound <- buildDNSQuery("s2", srcIP) + + ipVersion := uint8(4) + if srcIP.Is6() { + ipVersion = 6 + } + for { + select { + case pkt := <-tt.tun.Inbound: + var p packet.Parsed + p.Decode(pkt) + if p.IPVersion != ipVersion || p.IPProto != ipproto.UDP { + continue + } + if p.Src.Addr() == serviceIP && p.Src.Port() == 53 { + if len(p.Payload()) < 12 { + t.Fatalf("DNS response too short: %d bytes", len(p.Payload())) + } + return // success + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for DNS response") + } + } + } + + t.Run("IPv4", func(t *testing.T) { + test(t, tt.s2ip4, netip.MustParseAddr("100.100.100.100")) + }) + t.Run("IPv6", func(t *testing.T) { + test(t, tt.s2ip6, netip.MustParseAddr("fd7a:115c:a1e0::53")) + }) +} + +// TestListenPacket tests UDP listeners (ListenPacket) in both netstack and TUN modes. +func TestListenPacket(t *testing.T) { + testListenPacket := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + pc, err := lt.s2.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer pc.Close() + + echoErr := make(chan error, 1) + go func() { + buf := make([]byte, 1500) + n, addr, err := pc.ReadFrom(buf) + if err != nil { + echoErr <- err + return + } + _, err = pc.WriteTo(buf[:n], addr) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s1.Dial(t.Context(), "udp", pc.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + want := "hello udp" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatal(err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupListenTest(t, false) + t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupListenTest(t, true) + t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) + }) +} + +// TestListenTCP tests TCP listeners with concrete addresses in both netstack +// and TUN modes. +func TestListenTCP(t *testing.T) { + testListenTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + ln, err := lt.s2.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s1.Dial(t.Context(), "tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + want := "hello tcp" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupListenTest(t, false) + t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupListenTest(t, true) + t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) + }) +} + +// TestListenTCPDualStack tests TCP listeners with wildcard addresses (dual-stack) +// in both netstack and TUN modes. +func TestListenTCPDualStack(t *testing.T) { + testListenTCPDualStack := func(t *testing.T, lt *listenTest, dialIP netip.Addr) { + ln, err := lt.s2.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + _, portStr, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + t.Fatalf("parsing listener address %q: %v", ln.Addr().String(), err) + } + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + if err != nil { + echoErr <- err + return + } + }() + + dialAddr := net.JoinHostPort(dialIP.String(), portStr) + conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", dialAddr, err) + } + defer conn.Close() + + want := "hello tcp dualstack" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupListenTest(t, false) + t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) + t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupListenTest(t, true) + t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) + t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) + }) +} + +// TestDialTCP tests TCP dialing from s2 to s1 in both netstack and TUN modes. +// In TUN mode, this verifies that outbound TCP connections and their replies +// are handled by netstack without packets escaping to the TUN. +func TestDialTCP(t *testing.T) { + testDialTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + ln, err := lt.s1.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s2.Dial(t.Context(), "tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + want := "hello tcp dial" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupListenTest(t, false) + t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupListenTest(t, true) + + var escapedTCPPackets atomic.Int32 + var wg sync.WaitGroup + wg.Go(func() { + for pkt := range lt.tun.Inbound { + var p packet.Parsed + p.Decode(pkt) + if p.IPProto == ipproto.TCP { + escapedTCPPackets.Add(1) + t.Logf("TCP packet escaped to TUN: %v -> %v", p.Src, p.Dst) + } + } + }) + + t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) }) + + lt.tun.Close() + wg.Wait() + if escaped := escapedTCPPackets.Load(); escaped > 0 { + t.Errorf("%d TCP packets escaped to TUN", escaped) + } + }) +} + +// TestDialUDP tests UDP dialing from s2 to s1 in both netstack and TUN modes. +// In TUN mode, this verifies that outbound UDP connections register endpoints +// with gVisor, allowing reply packets to be routed through netstack instead of +// escaping to the TUN. +func TestDialUDP(t *testing.T) { + testDialUDP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + pc, err := lt.s1.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer pc.Close() + + echoErr := make(chan error, 1) + go func() { + buf := make([]byte, 1500) + n, addr, err := pc.ReadFrom(buf) + if err != nil { + echoErr <- err + return + } + _, err = pc.WriteTo(buf[:n], addr) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s2.Dial(t.Context(), "udp", pc.LocalAddr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + want := "hello udp dial" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupListenTest(t, false) + t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupListenTest(t, true) + + var escapedUDPPackets atomic.Int32 + var wg sync.WaitGroup + wg.Go(func() { + for pkt := range lt.tun.Inbound { + var p packet.Parsed + p.Decode(pkt) + if p.IPProto == ipproto.UDP { + escapedUDPPackets.Add(1) + t.Logf("UDP packet escaped to TUN: %v -> %v", p.Src, p.Dst) + } + } + }) + + t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) }) + + lt.tun.Close() + wg.Wait() + if escaped := escapedUDPPackets.Load(); escaped > 0 { + t.Errorf("%d UDP packets escaped to TUN", escaped) + } + }) +} + +// buildDNSQuery builds a UDP/IP packet containing a DNS query for name to the +// Tailscale service IP (100.100.100.100 for IPv4, fd7a:115c:a1e0::53 for IPv6). +func buildDNSQuery(name string, srcIP netip.Addr) []byte { + qtype := byte(0x01) // Type A for IPv4 + if srcIP.Is6() { + qtype = 0x1c // Type AAAA for IPv6 + } + dns := []byte{ + 0x12, 0x34, // ID + 0x01, 0x00, // Flags: standard query, recursion desired + 0x00, 0x01, // QDCOUNT: 1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ANCOUNT, NSCOUNT, ARCOUNT + } + for _, label := range strings.Split(name, ".") { + dns = append(dns, byte(len(label))) + dns = append(dns, label...) + } + dns = append(dns, 0x00, 0x00, qtype, 0x00, 0x01) // null, Type A/AAAA, Class IN + + if srcIP.Is4() { + h := packet.UDP4Header{ + IP4Header: packet.IP4Header{ + Src: srcIP, + Dst: netip.MustParseAddr("100.100.100.100"), + }, + SrcPort: 12345, + DstPort: 53, + } + return packet.Generate(h, dns) + } + h := packet.UDP6Header{ + IP6Header: packet.IP6Header{ + Src: srcIP, + Dst: netip.MustParseAddr("fd7a:115c:a1e0::53"), + }, + SrcPort: 12345, + DstPort: 53, + } + return packet.Generate(h, dns) +} + func TestDeps(t *testing.T) { tstest.Shard(t) deptest.DepChecker{ diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index c2b5d8a32..e05846e15 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -165,6 +165,17 @@ type Impl struct { // over the UDP flow. GetUDPHandlerForFlow func(src, dst netip.AddrPort) (handler func(nettype.ConnPacketConn), intercept bool) + // CheckLocalTransportEndpoints, if true, causes netstack to check if gVisor + // has a registered endpoint for incoming packets to local IPs. This is used + // by tsnet to intercept packets for registered listeners and outbound + // connections when ProcessLocalIPs is false (i.e., when using a TUN). + // It can only be set before calling Start. + // TODO(raggi): refactor the way we handle both CheckLocalTransportEndpoints + // and the earlier netstack registrations for serve, funnel, peerAPI and so + // on. Currently this optimizes away cost for tailscaled in TUN mode, while + // enabling extension support when using tsnet in TUN mode. See #18423. + CheckLocalTransportEndpoints bool + // ProcessLocalIPs is whether netstack should handle incoming // traffic directed at the Node.Addresses (local IPs). // It can only be set before calling Start. @@ -1109,6 +1120,45 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { if ns.ProcessSubnets && !isLocal { return true } + if isLocal && ns.CheckLocalTransportEndpoints { + // Handle packets to registered listeners and replies to outbound + // connections by checking if gVisor has a registered endpoint. + // This covers TCP listeners, UDP listeners, and outbound TCP replies. + if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP { + var netProto tcpip.NetworkProtocolNumber + var id stack.TransportEndpointID + if p.Dst.Addr().Is4() { + netProto = ipv4.ProtocolNumber + id = stack.TransportEndpointID{ + LocalAddress: tcpip.AddrFrom4(p.Dst.Addr().As4()), + LocalPort: p.Dst.Port(), + RemoteAddress: tcpip.AddrFrom4(p.Src.Addr().As4()), + RemotePort: p.Src.Port(), + } + } else { + netProto = ipv6.ProtocolNumber + id = stack.TransportEndpointID{ + LocalAddress: tcpip.AddrFrom16(p.Dst.Addr().As16()), + LocalPort: p.Dst.Port(), + RemoteAddress: tcpip.AddrFrom16(p.Src.Addr().As16()), + RemotePort: p.Src.Port(), + } + } + var transProto tcpip.TransportProtocolNumber + if p.IPProto == ipproto.TCP { + transProto = tcp.ProtocolNumber + } else { + transProto = udp.ProtocolNumber + } + ep := ns.ipstack.FindTransportEndpoint(netProto, transProto, id, nicID) + if debugNetstack() { + ns.logf("[v2] FindTransportEndpoint: id=%+v found=%v", id, ep != nil) + } + if ep != nil { + return true + } + } + } return false } @@ -1575,7 +1625,7 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) { ap, err := netip.ParseAddrPort(address) if err != nil { - return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %v", address, err) + return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err) } var networkProto tcpip.NetworkProtocolNumber @@ -1612,6 +1662,40 @@ func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) { return gonet.NewUDPConn(&wq, ep), nil } +// ListenTCP listens for TCP connections on the given address. +func (ns *Impl) ListenTCP(network, address string) (*gonet.TCPListener, error) { + ap, err := netip.ParseAddrPort(address) + if err != nil { + return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err) + } + + var networkProto tcpip.NetworkProtocolNumber + switch network { + case "tcp4": + networkProto = ipv4.ProtocolNumber + if ap.Addr().IsValid() && !ap.Addr().Is4() { + return nil, fmt.Errorf("netstack: tcp4 requires an IPv4 address") + } + case "tcp6": + networkProto = ipv6.ProtocolNumber + if ap.Addr().IsValid() && !ap.Addr().Is6() { + return nil, fmt.Errorf("netstack: tcp6 requires an IPv6 address") + } + default: + return nil, fmt.Errorf("netstack: unsupported network %q", network) + } + + localAddress := tcpip.FullAddress{ + NIC: nicID, + Port: ap.Port(), + } + if ap.Addr().IsValid() && !ap.Addr().IsUnspecified() { + localAddress.Addr = tcpip.AddrFromSlice(ap.Addr().AsSlice()) + } + + return gonet.ListenTCP(ns.ipstack, localAddress, networkProto) +} + func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { sess := r.ID() if debugNetstack() {