diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index df6a62414..368a33283 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -23,6 +23,7 @@ import ( "path/filepath" "regexp" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -1204,13 +1205,157 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { d1.MustCleanShutdown(t) } +// TestNetstackTCPLoopback tests netstack loopback of a TCP stream, in both +// directions. +// TODO(jwhited): do the same for UDP +func TestNetstackTCPLoopback(t *testing.T) { + tstest.Shard(t) + if os.Getuid() != 0 { + t.Skip("skipping when not root") + } + + env := newTestEnv(t) + env.tunMode = true + loopbackPort := uint16(5201) + env.loopbackPort = &loopbackPort + loopbackPortStr := strconv.Itoa(int(loopbackPort)) + n1 := newTestNode(t, env) + d1 := n1.StartDaemon() + + n1.AwaitResponding() + n1.MustUp() + + n1.AwaitIP4() + n1.AwaitRunning() + + cases := []struct { + lisAddr string + network string + dialAddr string + }{ + { + lisAddr: net.JoinHostPort("127.0.0.1", loopbackPortStr), + network: "tcp4", + dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPString, loopbackPortStr), + }, + { + lisAddr: net.JoinHostPort("::1", loopbackPortStr), + network: "tcp6", + dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPv6String, loopbackPortStr), + }, + } + + writeBufSize := 128 << 10 // 128KiB, exercise GSO if enabled + writeBufIterations := 100 // allow TCP send window to open up + wantTotal := writeBufSize * writeBufIterations + + for _, c := range cases { + lis, err := net.Listen(c.network, c.lisAddr) + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + writeFn := func(conn net.Conn) error { + for i := 0; i < writeBufIterations; i++ { + toWrite := make([]byte, writeBufSize) + var wrote int + for { + n, err := conn.Write(toWrite) + if err != nil { + return err + } + wrote += n + if wrote == len(toWrite) { + break + } + } + } + return nil + } + + readFn := func(conn net.Conn) error { + var read int + for { + b := make([]byte, writeBufSize) + n, err := conn.Read(b) + if err != nil { + return err + } + read += n + if read == wantTotal { + return nil + } + } + } + + lisStepCh := make(chan error) + go func() { + conn, err := lis.Accept() + if err != nil { + lisStepCh <- err + return + } + lisStepCh <- readFn(conn) + lisStepCh <- writeFn(conn) + }() + + var conn net.Conn + err = tstest.WaitFor(time.Second*5, func() error { + conn, err = net.DialTimeout(c.network, c.dialAddr, time.Second*1) + if err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + dialerStepCh := make(chan error) + go func() { + dialerStepCh <- writeFn(conn) + dialerStepCh <- readFn(conn) + }() + + var ( + dialerSteps int + lisSteps int + ) + for { + select { + case lisErr := <-lisStepCh: + if lisErr != nil { + t.Fatal(err) + } + lisSteps++ + if dialerSteps == 2 && lisSteps == 2 { + return + } + case dialerErr := <-dialerStepCh: + if dialerErr != nil { + t.Fatal(err) + } + dialerSteps++ + if dialerSteps == 2 && lisSteps == 2 { + return + } + } + } + } + + d1.MustCleanShutdown(t) +} + // testEnv contains the test environment (set of servers) used by one // or more nodes. type testEnv struct { - t testing.TB - tunMode bool - cli string - daemon string + t testing.TB + tunMode bool + cli string + daemon string + loopbackPort *uint16 LogCatcher *LogCatcher LogCatcherServer *httptest.Server @@ -1511,6 +1656,9 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { "TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost "TS_DEBUG_LOG_RATE=all", ) + if n.env.loopbackPort != nil { + cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(int(*n.env.loopbackPort))) + } if version.IsRace() { cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") } diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 47fe23203..2ab40e810 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -187,6 +187,11 @@ type Impl struct { dns *dns.Manager driveForLocal drive.FileSystemForLocal // or nil + // loopbackPort, if non-nil, will enable Impl to loop back (dnat to + // :loopbackPort) TCP & UDP flows originally + // destined to serviceIP{v6}:loopbackPort. + loopbackPort *int + peerapiPort4Atomic atomic.Uint32 // uint16 port number for IPv4 peerapi peerapiPort6Atomic atomic.Uint32 // uint16 port number for IPv6 peerapi @@ -378,6 +383,10 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi dns: dns, driveForLocal: driveForLocal, } + loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT") + if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 { + ns.loopbackPort = &loopbackPort + } ns.ctx, ns.ctxCancel = context.WithCancel(context.Background()) ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc()) ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound @@ -706,6 +715,13 @@ func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) { } } +func (ns *Impl) isLoopbackPort(port uint16) bool { + if ns.loopbackPort != nil && int(port) == *ns.loopbackPort { + return true + } + return false +} + // handleLocalPackets is hooked into the tun datapath for packets leaving // the host and arriving at tailscaled. This method returns filter.DropSilently // to intercept a packet for handling, for instance traffic to quad-100. @@ -724,11 +740,11 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re // 80, and 8080. switch p.IPProto { case ipproto.TCP: - if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 { + if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 && !ns.isLoopbackPort(port) { return filter.Accept } case ipproto.UDP: - if port := p.Dst.Port(); port != 53 { + if port := p.Dst.Port(); port != 53 && !ns.isLoopbackPort(port) { return filter.Accept } } @@ -1169,6 +1185,11 @@ func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr { return netip.Addr{} } +var ( + ipv4Loopback = netip.MustParseAddr("127.0.0.1") + ipv6Loopback = netip.MustParseAddr("::1") +) + func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { reqDetails := r.ID() if debugNetstack() { @@ -1305,8 +1326,15 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { return } } - if isTailscaleIP { - dialIP = netaddr.IPv4(127, 0, 0, 1) + switch { + case hittingServiceIP && ns.isLoopbackPort(reqDetails.LocalPort): + if dialIP == serviceIPv6 { + dialIP = ipv6Loopback + } else { + dialIP = ipv4Loopback + } + case isTailscaleIP: + dialIP = ipv4Loopback } dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort)) @@ -1457,16 +1485,23 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { return } - // Handle magicDNS traffic (via UDP) here. + // Handle magicDNS and loopback traffic (via UDP) here. if dst := dstAddr.Addr(); dst == serviceIP || dst == serviceIPv6 { - if dstAddr.Port() != 53 { + switch { + case dstAddr.Port() == 53: + c := gonet.NewUDPConn(&wq, ep) + go ns.handleMagicDNSUDP(srcAddr, c) + return + case ns.isLoopbackPort(dstAddr.Port()): + if dst == serviceIPv6 { + dstAddr = netip.AddrPortFrom(ipv6Loopback, dstAddr.Port()) + } else { + dstAddr = netip.AddrPortFrom(ipv4Loopback, dstAddr.Port()) + } + default: ep.Close() - return // Only MagicDNS traffic runs on the service IPs for now. + return // Only MagicDNS and loopback traffic runs on the service IPs for now. } - - c := gonet.NewUDPConn(&wq, ep) - go ns.handleMagicDNSUDP(srcAddr, c) - return } if get := ns.GetUDPHandlerForFlow; get != nil { @@ -1545,9 +1580,17 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr var backendListenAddr *net.UDPAddr var backendRemoteAddr *net.UDPAddr isLocal := ns.isLocalIP(dstAddr.Addr()) + isLoopback := dstAddr.Addr() == ipv4Loopback || dstAddr.Addr() == ipv6Loopback if isLocal { backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)} backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)} + } else if isLoopback { + ip := net.IP(ipv4Loopback.AsSlice()) + if dstAddr.Addr() == ipv6Loopback { + ip = ipv6Loopback.AsSlice() + } + backendRemoteAddr = &net.UDPAddr{IP: ip, Port: int(port)} + backendListenAddr = &net.UDPAddr{IP: ip, Port: int(srcPort)} } else { if dstIP := dstAddr.Addr(); viaRange.Contains(dstIP) { dstAddr = netip.AddrPortFrom(tsaddr.UnmapVia(dstIP), dstAddr.Port())