wgengine/netstack: implement netstack loopback (#13301)

When the TS_DEBUG_NETSTACK_LOOPBACK_PORT environment variable is set,
netstack will loop back (dnat to addressFamilyLoopback:loopbackPort)
TCP & UDP flows originally destined to localServicesIP:loopbackPort.
localServicesIP is quad-100 or the IPv6 equivalent.

Updates tailscale/corp#22713

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/13302/head
Jordan Whited 3 months ago committed by GitHub
parent 80b2b45d60
commit d21ebc28af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -23,6 +23,7 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -1204,6 +1205,149 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) {
d1.MustCleanShutdown(t) d1.MustCleanShutdown(t)
} }
// TestNetstackTCPLoopback tests netstack loopback of a TCP stream, in both
// directions.
// TODO(jwhited): do the same for UDP
func TestNetstackTCPLoopback(t *testing.T) {
tstest.Shard(t)
if os.Getuid() != 0 {
t.Skip("skipping when not root")
}
env := newTestEnv(t)
env.tunMode = true
loopbackPort := uint16(5201)
env.loopbackPort = &loopbackPort
loopbackPortStr := strconv.Itoa(int(loopbackPort))
n1 := newTestNode(t, env)
d1 := n1.StartDaemon()
n1.AwaitResponding()
n1.MustUp()
n1.AwaitIP4()
n1.AwaitRunning()
cases := []struct {
lisAddr string
network string
dialAddr string
}{
{
lisAddr: net.JoinHostPort("127.0.0.1", loopbackPortStr),
network: "tcp4",
dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPString, loopbackPortStr),
},
{
lisAddr: net.JoinHostPort("::1", loopbackPortStr),
network: "tcp6",
dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPv6String, loopbackPortStr),
},
}
writeBufSize := 128 << 10 // 128KiB, exercise GSO if enabled
writeBufIterations := 100 // allow TCP send window to open up
wantTotal := writeBufSize * writeBufIterations
for _, c := range cases {
lis, err := net.Listen(c.network, c.lisAddr)
if err != nil {
t.Fatal(err)
}
defer lis.Close()
writeFn := func(conn net.Conn) error {
for i := 0; i < writeBufIterations; i++ {
toWrite := make([]byte, writeBufSize)
var wrote int
for {
n, err := conn.Write(toWrite)
if err != nil {
return err
}
wrote += n
if wrote == len(toWrite) {
break
}
}
}
return nil
}
readFn := func(conn net.Conn) error {
var read int
for {
b := make([]byte, writeBufSize)
n, err := conn.Read(b)
if err != nil {
return err
}
read += n
if read == wantTotal {
return nil
}
}
}
lisStepCh := make(chan error)
go func() {
conn, err := lis.Accept()
if err != nil {
lisStepCh <- err
return
}
lisStepCh <- readFn(conn)
lisStepCh <- writeFn(conn)
}()
var conn net.Conn
err = tstest.WaitFor(time.Second*5, func() error {
conn, err = net.DialTimeout(c.network, c.dialAddr, time.Second*1)
if err != nil {
return err
}
return nil
})
if err != nil {
t.Fatal(err)
}
defer conn.Close()
dialerStepCh := make(chan error)
go func() {
dialerStepCh <- writeFn(conn)
dialerStepCh <- readFn(conn)
}()
var (
dialerSteps int
lisSteps int
)
for {
select {
case lisErr := <-lisStepCh:
if lisErr != nil {
t.Fatal(err)
}
lisSteps++
if dialerSteps == 2 && lisSteps == 2 {
return
}
case dialerErr := <-dialerStepCh:
if dialerErr != nil {
t.Fatal(err)
}
dialerSteps++
if dialerSteps == 2 && lisSteps == 2 {
return
}
}
}
}
d1.MustCleanShutdown(t)
}
// testEnv contains the test environment (set of servers) used by one // testEnv contains the test environment (set of servers) used by one
// or more nodes. // or more nodes.
type testEnv struct { type testEnv struct {
@ -1211,6 +1355,7 @@ type testEnv struct {
tunMode bool tunMode bool
cli string cli string
daemon string daemon string
loopbackPort *uint16
LogCatcher *LogCatcher LogCatcher *LogCatcher
LogCatcherServer *httptest.Server LogCatcherServer *httptest.Server
@ -1511,6 +1656,9 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon {
"TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost "TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost
"TS_DEBUG_LOG_RATE=all", "TS_DEBUG_LOG_RATE=all",
) )
if n.env.loopbackPort != nil {
cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(int(*n.env.loopbackPort)))
}
if version.IsRace() { if version.IsRace() {
cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1")
} }

@ -187,6 +187,11 @@ type Impl struct {
dns *dns.Manager dns *dns.Manager
driveForLocal drive.FileSystemForLocal // or nil driveForLocal drive.FileSystemForLocal // or nil
// loopbackPort, if non-nil, will enable Impl to loop back (dnat to
// <address-family-loopback>:loopbackPort) TCP & UDP flows originally
// destined to serviceIP{v6}:loopbackPort.
loopbackPort *int
peerapiPort4Atomic atomic.Uint32 // uint16 port number for IPv4 peerapi peerapiPort4Atomic atomic.Uint32 // uint16 port number for IPv4 peerapi
peerapiPort6Atomic atomic.Uint32 // uint16 port number for IPv6 peerapi peerapiPort6Atomic atomic.Uint32 // uint16 port number for IPv6 peerapi
@ -378,6 +383,10 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
dns: dns, dns: dns,
driveForLocal: driveForLocal, driveForLocal: driveForLocal,
} }
loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT")
if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 {
ns.loopbackPort = &loopbackPort
}
ns.ctx, ns.ctxCancel = context.WithCancel(context.Background()) ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc()) ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc())
ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound
@ -706,6 +715,13 @@ func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) {
} }
} }
func (ns *Impl) isLoopbackPort(port uint16) bool {
if ns.loopbackPort != nil && int(port) == *ns.loopbackPort {
return true
}
return false
}
// handleLocalPackets is hooked into the tun datapath for packets leaving // handleLocalPackets is hooked into the tun datapath for packets leaving
// the host and arriving at tailscaled. This method returns filter.DropSilently // the host and arriving at tailscaled. This method returns filter.DropSilently
// to intercept a packet for handling, for instance traffic to quad-100. // to intercept a packet for handling, for instance traffic to quad-100.
@ -724,11 +740,11 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re
// 80, and 8080. // 80, and 8080.
switch p.IPProto { switch p.IPProto {
case ipproto.TCP: case ipproto.TCP:
if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 { if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 && !ns.isLoopbackPort(port) {
return filter.Accept return filter.Accept
} }
case ipproto.UDP: case ipproto.UDP:
if port := p.Dst.Port(); port != 53 { if port := p.Dst.Port(); port != 53 && !ns.isLoopbackPort(port) {
return filter.Accept return filter.Accept
} }
} }
@ -1169,6 +1185,11 @@ func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr {
return netip.Addr{} return netip.Addr{}
} }
var (
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
ipv6Loopback = netip.MustParseAddr("::1")
)
func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
reqDetails := r.ID() reqDetails := r.ID()
if debugNetstack() { if debugNetstack() {
@ -1305,8 +1326,15 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
return return
} }
} }
if isTailscaleIP { switch {
dialIP = netaddr.IPv4(127, 0, 0, 1) case hittingServiceIP && ns.isLoopbackPort(reqDetails.LocalPort):
if dialIP == serviceIPv6 {
dialIP = ipv6Loopback
} else {
dialIP = ipv4Loopback
}
case isTailscaleIP:
dialIP = ipv4Loopback
} }
dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort)) dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort))
@ -1457,16 +1485,23 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
return return
} }
// Handle magicDNS traffic (via UDP) here. // Handle magicDNS and loopback traffic (via UDP) here.
if dst := dstAddr.Addr(); dst == serviceIP || dst == serviceIPv6 { if dst := dstAddr.Addr(); dst == serviceIP || dst == serviceIPv6 {
if dstAddr.Port() != 53 { switch {
ep.Close() case dstAddr.Port() == 53:
return // Only MagicDNS traffic runs on the service IPs for now.
}
c := gonet.NewUDPConn(&wq, ep) c := gonet.NewUDPConn(&wq, ep)
go ns.handleMagicDNSUDP(srcAddr, c) go ns.handleMagicDNSUDP(srcAddr, c)
return return
case ns.isLoopbackPort(dstAddr.Port()):
if dst == serviceIPv6 {
dstAddr = netip.AddrPortFrom(ipv6Loopback, dstAddr.Port())
} else {
dstAddr = netip.AddrPortFrom(ipv4Loopback, dstAddr.Port())
}
default:
ep.Close()
return // Only MagicDNS and loopback traffic runs on the service IPs for now.
}
} }
if get := ns.GetUDPHandlerForFlow; get != nil { if get := ns.GetUDPHandlerForFlow; get != nil {
@ -1545,9 +1580,17 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr
var backendListenAddr *net.UDPAddr var backendListenAddr *net.UDPAddr
var backendRemoteAddr *net.UDPAddr var backendRemoteAddr *net.UDPAddr
isLocal := ns.isLocalIP(dstAddr.Addr()) isLocal := ns.isLocalIP(dstAddr.Addr())
isLoopback := dstAddr.Addr() == ipv4Loopback || dstAddr.Addr() == ipv6Loopback
if isLocal { if isLocal {
backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)} backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)} backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)}
} else if isLoopback {
ip := net.IP(ipv4Loopback.AsSlice())
if dstAddr.Addr() == ipv6Loopback {
ip = ipv6Loopback.AsSlice()
}
backendRemoteAddr = &net.UDPAddr{IP: ip, Port: int(port)}
backendListenAddr = &net.UDPAddr{IP: ip, Port: int(srcPort)}
} else { } else {
if dstIP := dstAddr.Addr(); viaRange.Contains(dstIP) { if dstIP := dstAddr.Addr(); viaRange.Contains(dstIP) {
dstAddr = netip.AddrPortFrom(tsaddr.UnmapVia(dstIP), dstAddr.Port()) dstAddr = netip.AddrPortFrom(tsaddr.UnmapVia(dstIP), dstAddr.Port())

Loading…
Cancel
Save