diff --git a/go.sum b/go.sum index 30922c6ff..10ccefa3c 100644 --- a/go.sum +++ b/go.sum @@ -84,10 +84,6 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/tailscale/winipcfg-go v0.0.0-20200413171540-609dcf2df55f h1:uFj5bslHsMzxIM8UTjAhq4VXeo6GfNW91rpoh/WMJaY= github.com/tailscale/winipcfg-go v0.0.0-20200413171540-609dcf2df55f/go.mod h1:x880GWw5fvrl2DVTQ04ttXQD4DuppTt1Yz6wLibbjNE= -github.com/tailscale/wireguard-go v0.0.0-20200615180905-687c10194779 h1:zg0rgvhBZGA4nvh17nDKcqkEXw6Nbc/Ma2VBvLaW7LU= -github.com/tailscale/wireguard-go v0.0.0-20200615180905-687c10194779/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4= -github.com/tailscale/wireguard-go v0.0.0-20200624060658-de1f1af1f35f h1:hmhdY4xqtJD2rdaKpoNeWf0xLFFAc8dVZXyKMXRWbEM= -github.com/tailscale/wireguard-go v0.0.0-20200624060658-de1f1af1f35f/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4= github.com/tailscale/wireguard-go v0.0.0-20200710044538-9320f191f6b1 h1:zMEeWu/X0l+xFnsbri69miflb3HIKoLwedZHD5xx6Mk= github.com/tailscale/wireguard-go v0.0.0-20200710044538-9320f191f6b1/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4= github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0= diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index c2990bd98..ab7abeda8 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -1146,6 +1146,20 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP if port < 0 || port > 1<<16-1 { return nil } + if n.STUNTestIP != "" { + ip, err := netaddr.ParseIP(n.STUNTestIP) + if err != nil { + return nil + } + if proto == probeIPv4 && ip.Is6() { + return nil + } + if proto == probeIPv6 && ip.Is4() { + return nil + } + return netaddr.IPPort{ip, uint16(port)}.UDPAddr() + } + switch proto { case probeIPv4: if n.IPv4 != "" { diff --git a/net/stun/stuntest/stuntest.go b/net/stun/stuntest/stuntest.go index c2e462f78..e97accbb8 100644 --- a/net/stun/stuntest/stuntest.go +++ b/net/stun/stuntest/stuntest.go @@ -6,6 +6,7 @@ package stuntest import ( + "context" "fmt" "net" "strconv" @@ -16,6 +17,7 @@ import ( "inet.af/netaddr" "tailscale.com/net/stun" "tailscale.com/tailcfg" + "tailscale.com/types/nettype" ) type stunStats struct { @@ -25,18 +27,22 @@ type stunStats struct { } func Serve(t *testing.T) (addr *net.UDPAddr, cleanupFn func()) { + return ServeWithPacketListener(t, nettype.Std{}) +} + +func ServeWithPacketListener(t *testing.T, ln nettype.PacketListener) (addr *net.UDPAddr, cleanupFn func()) { t.Helper() // TODO(crawshaw): use stats to test re-STUN logic var stats stunStats - pc, err := net.ListenPacket("udp4", ":0") + pc, err := ln.ListenPacket(context.Background(), "udp4", ":0") if err != nil { t.Fatalf("failed to open STUN listener: %v", err) } - addr = &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: pc.LocalAddr().(*net.UDPAddr).Port, + addr = pc.LocalAddr().(*net.UDPAddr) + if len(addr.IP) == 0 || addr.IP.IsUnspecified() { + addr.IP = net.ParseIP("127.0.0.1") } doneCh := make(chan struct{}) go runSTUN(t, pc, &stats, doneCh) diff --git a/tailcfg/derpmap.go b/tailcfg/derpmap.go index 5e6bc6271..fff8f50d4 100644 --- a/tailcfg/derpmap.go +++ b/tailcfg/derpmap.go @@ -117,4 +117,8 @@ type DERPNode struct { // of using the default port of 443. If non-zero, TLS // verification is skipped. DERPTestPort int `json:",omitempty"` + + // STUNTestIP is used in tests to override the STUN server's IP. + // If empty, it's assumed to be the same as the DERP server. + STUNTestIP string `json:",omitempty"` } diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index 70c627027..64cba3622 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -15,7 +15,9 @@ import ( "context" "crypto/sha256" "encoding/base64" + "errors" "fmt" + "math/rand" "net" "os" "sort" @@ -26,10 +28,10 @@ import ( "inet.af/netaddr" ) -var traceOn = os.Getenv("NATLAB_TRACE") +var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE")) func trace(p []byte, msg string, args ...interface{}) { - if traceOn == "" { + if !traceOn { return } id := packetShort(p) @@ -424,6 +426,32 @@ func (m *Machine) hasv6() bool { return false } +func (m *Machine) pickEphemPort() (port uint16, err error) { + m.mu.Lock() + defer m.mu.Unlock() + for tries := 0; tries < 500; tries++ { + port := uint16(rand.Intn(32<<10) + 32<<10) + if !m.portInUseLocked(port) { + return port, nil + } + } + return 0, errors.New("failed to find an ephemeral port") +} + +func (m *Machine) portInUseLocked(port uint16) bool { + for ipp := range m.conns4 { + if ipp.Port == port { + return true + } + } + for ipp := range m.conns6 { + if ipp.Port == port { + return true + } + } + return false +} + func (m *Machine) registerConn4(c *conn) error { m.mu.Lock() defer m.mu.Unlock() @@ -467,7 +495,7 @@ func registerConn(conns *map[netaddr.IPPort]*conn, c *conn) error { func (m *Machine) AddNetwork(n *Network) {} -func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) { +func (m *Machine) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { // if udp4, udp6, etc... look at address IP vs unspec var ( fam uint8 @@ -497,11 +525,18 @@ func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) return nil, err } } - port, err := strconv.ParseUint(portStr, 10, 16) + porti, err := strconv.ParseUint(portStr, 10, 16) if err != nil { return nil, err } - ipp := netaddr.IPPort{IP: ip, Port: uint16(port)} + port := uint16(porti) + if port == 0 { + port, err = m.pickEphemPort() + if err != nil { + return nil, nil + } + } + ipp := netaddr.IPPort{IP: ip, Port: port} c := &conn{ m: m, @@ -552,11 +587,17 @@ type activeRead struct { cancel context.CancelFunc } -// readDeadlineExceeded reports whether the read deadline is set and has already passed. -func (c *conn) readDeadlineExceeded() bool { +// canRead reports whether we can do a read. +func (c *conn) canRead() error { c.mu.Lock() defer c.mu.Unlock() - return !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now()) + if c.closed { + return errors.New("closed network connection") // sadface: magic string used by other; don't change + } + if !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now()) { + return errors.New("read deadline exceeded") + } + return nil } func (c *conn) registerActiveRead(ar *activeRead, active bool) { @@ -609,8 +650,8 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { ar := &activeRead{cancel: cancel} - if c.readDeadlineExceeded() { - return 0, nil, context.DeadlineExceeded + if err := c.canRead(); err != nil { + return 0, nil, err } c.registerActiveRead(ar, true) diff --git a/tstest/natlab/natlab_test.go b/tstest/natlab/natlab_test.go index 970e8131b..4972c7596 100644 --- a/tstest/natlab/natlab_test.go +++ b/tstest/natlab/natlab_test.go @@ -5,6 +5,7 @@ package natlab import ( + "context" "fmt" "testing" @@ -49,11 +50,12 @@ func TestSendPacket(t *testing.T) { fooAddr := netaddr.IPPort{IP: ifFoo.V4(), Port: 123} barAddr := netaddr.IPPort{IP: ifBar.V4(), Port: 456} - fooPC, err := foo.ListenPacket("udp4", fooAddr.String()) + ctx := context.Background() + fooPC, err := foo.ListenPacket(ctx, "udp4", fooAddr.String()) if err != nil { t.Fatal(err) } - barPC, err := bar.ListenPacket("udp4", barAddr.String()) + barPC, err := bar.ListenPacket(ctx, "udp4", barAddr.String()) if err != nil { t.Fatal(err) } @@ -93,15 +95,16 @@ func TestMultiNetwork(t *testing.T) { ifNATLAN := nat.Attach("ethlan", lan) ifServer := server.Attach("eth0", internet) - clientPC, err := client.ListenPacket("udp", ":123") + ctx := context.Background() + clientPC, err := client.ListenPacket(ctx, "udp", ":123") if err != nil { t.Fatal(err) } - natPC, err := nat.ListenPacket("udp", ":456") + natPC, err := nat.ListenPacket(ctx, "udp", ":456") if err != nil { t.Fatal(err) } - serverPC, err := server.ListenPacket("udp", ":789") + serverPC, err := server.ListenPacket(ctx, "udp", ":789") if err != nil { t.Fatal(err) } @@ -184,11 +187,12 @@ func TestPacketHandler(t *testing.T) { } } - clientPC, err := client.ListenPacket("udp4", ":123") + ctx := context.Background() + clientPC, err := client.ListenPacket(ctx, "udp4", ":123") if err != nil { t.Fatal(err) } - serverPC, err := server.ListenPacket("udp4", ":456") + serverPC, err := server.ListenPacket(ctx, "udp4", ":456") if err != nil { t.Fatal(err) } diff --git a/types/nettype/nettype.go b/types/nettype/nettype.go new file mode 100644 index 000000000..c7827ef6a --- /dev/null +++ b/types/nettype/nettype.go @@ -0,0 +1,25 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package nettype defines an interface that doesn't exist in the Go net package. +package nettype + +import ( + "context" + "net" +) + +// PacketListener defines the ListenPacket method as implemented +// by net.ListenConfig, net.ListenPacket, and tstest/natlab. +type PacketListener interface { + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + +// Std implements PacketListener using the Go net package's ListenPacket func. +type Std struct{} + +func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + var conf net.ListenConfig + return conf.ListenPacket(ctx, network, address) +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 22c49a7b1..f7678c163 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -49,6 +49,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/nettype" "tailscale.com/types/opt" "tailscale.com/types/structs" "tailscale.com/version" @@ -82,6 +83,9 @@ type Conn struct { udpRecvCh chan udpReadResult derpRecvCh chan derpReadResult + // packetListener optionally specifies a test hook to open a PacketConn. + packetListener nettype.PacketListener + // ============================================================ mu sync.Mutex // guards all following fields @@ -227,6 +231,10 @@ type Options struct { // IdleFunc optionally provides a func to return how long // it's been since a TUN packet was sent or received. IdleFunc func() time.Duration + + // PacketListener optionally specifies how to create PacketConns. + // It's meant for testing. + PacketListener nettype.PacketListener } func (o *Options) logf() logger.Logf { @@ -273,6 +281,7 @@ func NewConn(opts Options) (*Conn, error) { c.logf = opts.logf() c.epFunc = opts.endpointsFunc() c.idleFunc = opts.IdleFunc + c.packetListener = opts.PacketListener if err := c.initialBind(); err != nil { return nil, err @@ -2002,6 +2011,13 @@ func (c *Conn) initialBind() error { return nil } +func (c *Conn) listenPacket(ctx context.Context, network, addr string) (net.PacketConn, error) { + if c.packetListener != nil { + return c.packetListener.ListenPacket(ctx, network, addr) + } + return netns.Listener().ListenPacket(ctx, network, addr) +} + func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error { host := "" if v, _ := strconv.ParseBool(os.Getenv("IN_TS_TEST")); v { @@ -2011,13 +2027,13 @@ func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error { var err error listenCtx := context.Background() // unused without DNS name to resolve if c.pconnPort == 0 && DefaultPort != 0 { - pc, err = netns.Listener().ListenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, DefaultPort)) + pc, err = c.listenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, DefaultPort)) if err != nil { c.logf("magicsock: bind: default port %s/%v unavailable; picking random", which, DefaultPort) } } if pc == nil { - pc, err = netns.Listener().ListenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, c.pconnPort)) + pc, err = c.listenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, c.pconnPort)) } if err != nil { c.logf("magicsock: bind(%s/%v): %v", which, c.pconnPort, err) @@ -2026,7 +2042,7 @@ func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error { if *ruc == nil { *ruc = new(RebindingUDPConn) } - (*ruc).Reset(pc.(*net.UDPConn)) + (*ruc).Reset(pc) return nil } @@ -2043,7 +2059,7 @@ func (c *Conn) Rebind() { if err := c.pconn4.pconn.Close(); err != nil { c.logf("magicsock: link change close failed: %v", err) } - packetConn, err := netns.Listener().ListenPacket(listenCtx, "udp4", fmt.Sprintf("%s:%d", host, c.pconnPort)) + packetConn, err := c.listenPacket(listenCtx, "udp4", fmt.Sprintf("%s:%d", host, c.pconnPort)) if err == nil { c.logf("magicsock: link change rebound port: %d", c.pconnPort) c.pconn4.pconn = packetConn.(*net.UDPConn) @@ -2054,7 +2070,7 @@ func (c *Conn) Rebind() { c.pconn4.mu.Unlock() } c.logf("magicsock: link change, binding new port") - packetConn, err := netns.Listener().ListenPacket(listenCtx, "udp4", host+":0") + packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0") if err != nil { c.logf("magicsock: link change failed to bind new port: %v", err) return @@ -2481,10 +2497,10 @@ type RebindingUDPConn struct { ippCache ippCache mu sync.Mutex - pconn *net.UDPConn + pconn net.PacketConn } -func (c *RebindingUDPConn) Reset(pconn *net.UDPConn) { +func (c *RebindingUDPConn) Reset(pconn net.PacketConn) { c.mu.Lock() old := c.pconn c.pconn = pconn @@ -2539,7 +2555,7 @@ func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) pconn := c.pconn c.mu.Unlock() - n, err := pconn.WriteToUDP(b, addr) + n, err := pconn.WriteTo(b, addr) if err != nil { c.mu.Lock() pconn2 := c.pconn diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index a090f900b..50a537858 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -32,8 +32,10 @@ import ( "tailscale.com/net/stun/stuntest" "tailscale.com/tailcfg" "tailscale.com/tstest" + "tailscale.com/tstest/natlab" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/nettype" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/tstun" ) @@ -334,6 +336,16 @@ func makeNestable(t *testing.T) (logf logger.Logf, setT func(t *testing.T)) { } func TestTwoDevicePing(t *testing.T) { + t.Run("real", func(t *testing.T) { + testTwoDevicePing(t, false) + }) + t.Run("natlab", func(t *testing.T) { + t.Skip("TODO: finish") + testTwoDevicePing(t, true) + }) +} + +func testTwoDevicePing(t *testing.T, useNatlab bool) { tstest.PanicOnLog() rc := tstest.NewResourceCheck() defer rc.Assert(t) @@ -344,7 +356,28 @@ func TestTwoDevicePing(t *testing.T) { derpServer, derpAddr, derpCleanupFn := runDERP(t, logf) defer derpCleanupFn() - stunAddr, stunCleanupFn := stuntest.Serve(t) + + packetConn := func(m *natlab.Machine) nettype.PacketListener { + if m == nil { + return nettype.Std{} + } + return m + } + + var stunTestIP = "127.0.0.1" + var stunMachine, machine1, machine2 *natlab.Machine + if useNatlab { + stunMachine = &natlab.Machine{Name: "stun"} + machine1 = &natlab.Machine{Name: "machine1"} + machine2 = &natlab.Machine{Name: "machine2"} + internet := natlab.NewInternet() + stunIf := stunMachine.Attach("eth0", internet) + machine1.Attach("eth0", internet) + machine2.Attach("eth0", internet) + stunTestIP = stunIf.V4().String() + } + + stunAddr, stunCleanupFn := stuntest.ServeWithPacketListener(t, packetConn(stunMachine)) defer stunCleanupFn() derpMap := &tailcfg.DERPMap{ @@ -361,6 +394,7 @@ func TestTwoDevicePing(t *testing.T) { IPv6: "none", STUNPort: stunAddr.Port, DERPTestPort: derpAddr.Port, + STUNTestIP: stunTestIP, }, }, }, @@ -369,7 +403,8 @@ func TestTwoDevicePing(t *testing.T) { epCh1 := make(chan []string, 16) conn1, err := NewConn(Options{ - Logf: logger.WithPrefix(logf, "conn1: "), + Logf: logger.WithPrefix(logf, "conn1: "), + PacketListener: packetConn(machine1), EndpointsFunc: func(eps []string) { epCh1 <- eps }, @@ -383,7 +418,8 @@ func TestTwoDevicePing(t *testing.T) { epCh2 := make(chan []string, 16) conn2, err := NewConn(Options{ - Logf: logger.WithPrefix(logf, "conn2: "), + Logf: logger.WithPrefix(logf, "conn2: "), + PacketListener: packetConn(machine2), EndpointsFunc: func(eps []string) { epCh2 <- eps }, @@ -396,6 +432,14 @@ func TestTwoDevicePing(t *testing.T) { conn2.SetDERPMap(derpMap) ports := []uint16{conn1.LocalPort(), conn2.LocalPort()} + if useNatlab { + // TODO: ... + } else { + addrs := []netaddr.IPPort{ + // netaddr.IPPort + } + _ = addrs + } cfgs := makeConfigs(t, ports) if err := conn1.SetPrivateKey(cfgs[0].PrivateKey); err != nil {