tstest/integration: add UDP netstack loopback integration test (#13312)

Updates tailscale/corp#22713

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/13318/head
Jordan Whited 3 months ago committed by GitHub
parent e93c160a39
commit 71acf87830
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -40,6 +40,7 @@ import (
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/ipn/store" "tailscale.com/ipn/store"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/net/tstun"
"tailscale.com/safesocket" "tailscale.com/safesocket"
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -1207,7 +1208,6 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) {
// TestNetstackTCPLoopback tests netstack loopback of a TCP stream, in both // TestNetstackTCPLoopback tests netstack loopback of a TCP stream, in both
// directions. // directions.
// TODO(jwhited): do the same for UDP
func TestNetstackTCPLoopback(t *testing.T) { func TestNetstackTCPLoopback(t *testing.T) {
tstest.Shard(t) tstest.Shard(t)
if os.Getuid() != 0 { if os.Getuid() != 0 {
@ -1216,9 +1216,9 @@ func TestNetstackTCPLoopback(t *testing.T) {
env := newTestEnv(t) env := newTestEnv(t)
env.tunMode = true env.tunMode = true
loopbackPort := uint16(5201) loopbackPort := 5201
env.loopbackPort = &loopbackPort env.loopbackPort = &loopbackPort
loopbackPortStr := strconv.Itoa(int(loopbackPort)) loopbackPortStr := strconv.Itoa(loopbackPort)
n1 := newTestNode(t, env) n1 := newTestNode(t, env)
d1 := n1.StartDaemon() d1 := n1.StartDaemon()
@ -1348,6 +1348,153 @@ func TestNetstackTCPLoopback(t *testing.T) {
d1.MustCleanShutdown(t) d1.MustCleanShutdown(t)
} }
// TestNetstackUDPLoopback tests netstack loopback of UDP packets, in both
// directions.
func TestNetstackUDPLoopback(t *testing.T) {
tstest.Shard(t)
if os.Getuid() != 0 {
t.Skip("skipping when not root")
}
env := newTestEnv(t)
env.tunMode = true
loopbackPort := 5201
env.loopbackPort = &loopbackPort
n1 := newTestNode(t, env)
d1 := n1.StartDaemon()
n1.AwaitResponding()
n1.MustUp()
ip4 := n1.AwaitIP4()
ip6 := n1.AwaitIP6()
n1.AwaitRunning()
cases := []struct {
pingerLAddr *net.UDPAddr
pongerLAddr *net.UDPAddr
network string
dialAddr *net.UDPAddr
}{
{
pingerLAddr: &net.UDPAddr{IP: ip4.AsSlice(), Port: loopbackPort + 1},
pongerLAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: loopbackPort},
network: "udp4",
dialAddr: &net.UDPAddr{IP: tsaddr.TailscaleServiceIP().AsSlice(), Port: loopbackPort},
},
{
pingerLAddr: &net.UDPAddr{IP: ip6.AsSlice(), Port: loopbackPort + 1},
pongerLAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: loopbackPort},
network: "udp6",
dialAddr: &net.UDPAddr{IP: tsaddr.TailscaleServiceIPv6().AsSlice(), Port: loopbackPort},
},
}
writeBufSize := int(tstun.DefaultTUNMTU()) - 40 - 8 // mtu - ipv6 header - udp header
wantPongs := 100
for _, c := range cases {
pongerConn, err := net.ListenUDP(c.network, c.pongerLAddr)
if err != nil {
t.Fatal(err)
}
defer pongerConn.Close()
var pingerConn *net.UDPConn
err = tstest.WaitFor(time.Second*5, func() error {
pingerConn, err = net.DialUDP(c.network, c.pingerLAddr, c.dialAddr)
return err
})
if err != nil {
t.Fatal(err)
}
defer pingerConn.Close()
pingerFn := func(conn *net.UDPConn) error {
b := make([]byte, writeBufSize)
n, err := conn.Write(b)
if err != nil {
return err
}
if n != len(b) {
return fmt.Errorf("bad write size: %d", n)
}
err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * 500))
if err != nil {
return err
}
n, err = conn.Read(b)
if err != nil {
return err
}
if n != len(b) {
return fmt.Errorf("bad read size: %d", n)
}
return nil
}
pongerFn := func(conn *net.UDPConn) error {
for {
b := make([]byte, writeBufSize)
n, from, err := conn.ReadFromUDP(b)
if err != nil {
return err
}
if n != len(b) {
return fmt.Errorf("bad read size: %d", n)
}
n, err = conn.WriteToUDP(b, from)
if err != nil {
return err
}
if n != len(b) {
return fmt.Errorf("bad write size: %d", n)
}
}
}
pongerErrCh := make(chan error, 1)
go func() {
pongerErrCh <- pongerFn(pongerConn)
}()
err = tstest.WaitFor(time.Second*5, func() error {
err = pingerFn(pingerConn)
if err != nil {
return err
}
return nil
})
if err != nil {
t.Fatal(err)
}
var pongsRX int
for {
pingerErrCh := make(chan error)
go func() {
pingerErrCh <- pingerFn(pingerConn)
}()
select {
case err := <-pongerErrCh:
t.Fatal(err)
case err := <-pingerErrCh:
if err != nil {
t.Fatal(err)
}
}
pongsRX++
if pongsRX == wantPongs {
break
}
}
}
d1.MustCleanShutdown(t)
}
// testEnv contains the test environment (set of servers) used by one // testEnv contains the test environment (set of servers) used by one
// or more nodes. // or more nodes.
type testEnv struct { type testEnv struct {
@ -1355,7 +1502,7 @@ type testEnv struct {
tunMode bool tunMode bool
cli string cli string
daemon string daemon string
loopbackPort *uint16 loopbackPort *int
LogCatcher *LogCatcher LogCatcher *LogCatcher
LogCatcherServer *httptest.Server LogCatcherServer *httptest.Server
@ -1657,7 +1804,7 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon {
"TS_DEBUG_LOG_RATE=all", "TS_DEBUG_LOG_RATE=all",
) )
if n.env.loopbackPort != nil { if n.env.loopbackPort != nil {
cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(int(*n.env.loopbackPort))) cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort))
} }
if version.IsRace() { if version.IsRace() {
cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1")

Loading…
Cancel
Save