wgengine/magicsock, tstest/natlab: start hooking up natlab to magicsock

Also adds ephemeral port support to natlab.

Work in progress.

Pairing with @danderson.
pull/544/head
Brad Fitzpatrick 4 years ago
parent edcbb5394e
commit 6c74065053

@ -84,10 +84,6 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/tailscale/winipcfg-go v0.0.0-20200413171540-609dcf2df55f h1:uFj5bslHsMzxIM8UTjAhq4VXeo6GfNW91rpoh/WMJaY=
github.com/tailscale/winipcfg-go v0.0.0-20200413171540-609dcf2df55f/go.mod h1:x880GWw5fvrl2DVTQ04ttXQD4DuppTt1Yz6wLibbjNE=
github.com/tailscale/wireguard-go v0.0.0-20200615180905-687c10194779 h1:zg0rgvhBZGA4nvh17nDKcqkEXw6Nbc/Ma2VBvLaW7LU=
github.com/tailscale/wireguard-go v0.0.0-20200615180905-687c10194779/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4=
github.com/tailscale/wireguard-go v0.0.0-20200624060658-de1f1af1f35f h1:hmhdY4xqtJD2rdaKpoNeWf0xLFFAc8dVZXyKMXRWbEM=
github.com/tailscale/wireguard-go v0.0.0-20200624060658-de1f1af1f35f/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4=
github.com/tailscale/wireguard-go v0.0.0-20200710044538-9320f191f6b1 h1:zMEeWu/X0l+xFnsbri69miflb3HIKoLwedZHD5xx6Mk=
github.com/tailscale/wireguard-go v0.0.0-20200710044538-9320f191f6b1/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4=
github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0=

@ -1146,6 +1146,20 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP
if port < 0 || port > 1<<16-1 {
return nil
}
if n.STUNTestIP != "" {
ip, err := netaddr.ParseIP(n.STUNTestIP)
if err != nil {
return nil
}
if proto == probeIPv4 && ip.Is6() {
return nil
}
if proto == probeIPv6 && ip.Is4() {
return nil
}
return netaddr.IPPort{ip, uint16(port)}.UDPAddr()
}
switch proto {
case probeIPv4:
if n.IPv4 != "" {

@ -6,6 +6,7 @@
package stuntest
import (
"context"
"fmt"
"net"
"strconv"
@ -16,6 +17,7 @@ import (
"inet.af/netaddr"
"tailscale.com/net/stun"
"tailscale.com/tailcfg"
"tailscale.com/types/nettype"
)
type stunStats struct {
@ -25,18 +27,22 @@ type stunStats struct {
}
func Serve(t *testing.T) (addr *net.UDPAddr, cleanupFn func()) {
return ServeWithPacketListener(t, nettype.Std{})
}
func ServeWithPacketListener(t *testing.T, ln nettype.PacketListener) (addr *net.UDPAddr, cleanupFn func()) {
t.Helper()
// TODO(crawshaw): use stats to test re-STUN logic
var stats stunStats
pc, err := net.ListenPacket("udp4", ":0")
pc, err := ln.ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatalf("failed to open STUN listener: %v", err)
}
addr = &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: pc.LocalAddr().(*net.UDPAddr).Port,
addr = pc.LocalAddr().(*net.UDPAddr)
if len(addr.IP) == 0 || addr.IP.IsUnspecified() {
addr.IP = net.ParseIP("127.0.0.1")
}
doneCh := make(chan struct{})
go runSTUN(t, pc, &stats, doneCh)

@ -117,4 +117,8 @@ type DERPNode struct {
// of using the default port of 443. If non-zero, TLS
// verification is skipped.
DERPTestPort int `json:",omitempty"`
// STUNTestIP is used in tests to override the STUN server's IP.
// If empty, it's assumed to be the same as the DERP server.
STUNTestIP string `json:",omitempty"`
}

@ -15,7 +15,9 @@ import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"math/rand"
"net"
"os"
"sort"
@ -26,10 +28,10 @@ import (
"inet.af/netaddr"
)
var traceOn = os.Getenv("NATLAB_TRACE")
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
func trace(p []byte, msg string, args ...interface{}) {
if traceOn == "" {
if !traceOn {
return
}
id := packetShort(p)
@ -424,6 +426,32 @@ func (m *Machine) hasv6() bool {
return false
}
func (m *Machine) pickEphemPort() (port uint16, err error) {
m.mu.Lock()
defer m.mu.Unlock()
for tries := 0; tries < 500; tries++ {
port := uint16(rand.Intn(32<<10) + 32<<10)
if !m.portInUseLocked(port) {
return port, nil
}
}
return 0, errors.New("failed to find an ephemeral port")
}
func (m *Machine) portInUseLocked(port uint16) bool {
for ipp := range m.conns4 {
if ipp.Port == port {
return true
}
}
for ipp := range m.conns6 {
if ipp.Port == port {
return true
}
}
return false
}
func (m *Machine) registerConn4(c *conn) error {
m.mu.Lock()
defer m.mu.Unlock()
@ -467,7 +495,7 @@ func registerConn(conns *map[netaddr.IPPort]*conn, c *conn) error {
func (m *Machine) AddNetwork(n *Network) {}
func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) {
func (m *Machine) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
// if udp4, udp6, etc... look at address IP vs unspec
var (
fam uint8
@ -497,11 +525,18 @@ func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error)
return nil, err
}
}
port, err := strconv.ParseUint(portStr, 10, 16)
porti, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
ipp := netaddr.IPPort{IP: ip, Port: uint16(port)}
port := uint16(porti)
if port == 0 {
port, err = m.pickEphemPort()
if err != nil {
return nil, nil
}
}
ipp := netaddr.IPPort{IP: ip, Port: port}
c := &conn{
m: m,
@ -552,11 +587,17 @@ type activeRead struct {
cancel context.CancelFunc
}
// readDeadlineExceeded reports whether the read deadline is set and has already passed.
func (c *conn) readDeadlineExceeded() bool {
// canRead reports whether we can do a read.
func (c *conn) canRead() error {
c.mu.Lock()
defer c.mu.Unlock()
return !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now())
if c.closed {
return errors.New("closed network connection") // sadface: magic string used by other; don't change
}
if !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now()) {
return errors.New("read deadline exceeded")
}
return nil
}
func (c *conn) registerActiveRead(ar *activeRead, active bool) {
@ -609,8 +650,8 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
ar := &activeRead{cancel: cancel}
if c.readDeadlineExceeded() {
return 0, nil, context.DeadlineExceeded
if err := c.canRead(); err != nil {
return 0, nil, err
}
c.registerActiveRead(ar, true)

@ -5,6 +5,7 @@
package natlab
import (
"context"
"fmt"
"testing"
@ -49,11 +50,12 @@ func TestSendPacket(t *testing.T) {
fooAddr := netaddr.IPPort{IP: ifFoo.V4(), Port: 123}
barAddr := netaddr.IPPort{IP: ifBar.V4(), Port: 456}
fooPC, err := foo.ListenPacket("udp4", fooAddr.String())
ctx := context.Background()
fooPC, err := foo.ListenPacket(ctx, "udp4", fooAddr.String())
if err != nil {
t.Fatal(err)
}
barPC, err := bar.ListenPacket("udp4", barAddr.String())
barPC, err := bar.ListenPacket(ctx, "udp4", barAddr.String())
if err != nil {
t.Fatal(err)
}
@ -93,15 +95,16 @@ func TestMultiNetwork(t *testing.T) {
ifNATLAN := nat.Attach("ethlan", lan)
ifServer := server.Attach("eth0", internet)
clientPC, err := client.ListenPacket("udp", ":123")
ctx := context.Background()
clientPC, err := client.ListenPacket(ctx, "udp", ":123")
if err != nil {
t.Fatal(err)
}
natPC, err := nat.ListenPacket("udp", ":456")
natPC, err := nat.ListenPacket(ctx, "udp", ":456")
if err != nil {
t.Fatal(err)
}
serverPC, err := server.ListenPacket("udp", ":789")
serverPC, err := server.ListenPacket(ctx, "udp", ":789")
if err != nil {
t.Fatal(err)
}
@ -184,11 +187,12 @@ func TestPacketHandler(t *testing.T) {
}
}
clientPC, err := client.ListenPacket("udp4", ":123")
ctx := context.Background()
clientPC, err := client.ListenPacket(ctx, "udp4", ":123")
if err != nil {
t.Fatal(err)
}
serverPC, err := server.ListenPacket("udp4", ":456")
serverPC, err := server.ListenPacket(ctx, "udp4", ":456")
if err != nil {
t.Fatal(err)
}

@ -0,0 +1,25 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package nettype defines an interface that doesn't exist in the Go net package.
package nettype
import (
"context"
"net"
)
// PacketListener defines the ListenPacket method as implemented
// by net.ListenConfig, net.ListenPacket, and tstest/natlab.
type PacketListener interface {
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
}
// Std implements PacketListener using the Go net package's ListenPacket func.
type Std struct{}
func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
var conf net.ListenConfig
return conf.ListenPacket(ctx, network, address)
}

@ -49,6 +49,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
"tailscale.com/types/opt"
"tailscale.com/types/structs"
"tailscale.com/version"
@ -82,6 +83,9 @@ type Conn struct {
udpRecvCh chan udpReadResult
derpRecvCh chan derpReadResult
// packetListener optionally specifies a test hook to open a PacketConn.
packetListener nettype.PacketListener
// ============================================================
mu sync.Mutex // guards all following fields
@ -227,6 +231,10 @@ type Options struct {
// IdleFunc optionally provides a func to return how long
// it's been since a TUN packet was sent or received.
IdleFunc func() time.Duration
// PacketListener optionally specifies how to create PacketConns.
// It's meant for testing.
PacketListener nettype.PacketListener
}
func (o *Options) logf() logger.Logf {
@ -273,6 +281,7 @@ func NewConn(opts Options) (*Conn, error) {
c.logf = opts.logf()
c.epFunc = opts.endpointsFunc()
c.idleFunc = opts.IdleFunc
c.packetListener = opts.PacketListener
if err := c.initialBind(); err != nil {
return nil, err
@ -2002,6 +2011,13 @@ func (c *Conn) initialBind() error {
return nil
}
func (c *Conn) listenPacket(ctx context.Context, network, addr string) (net.PacketConn, error) {
if c.packetListener != nil {
return c.packetListener.ListenPacket(ctx, network, addr)
}
return netns.Listener().ListenPacket(ctx, network, addr)
}
func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
host := ""
if v, _ := strconv.ParseBool(os.Getenv("IN_TS_TEST")); v {
@ -2011,13 +2027,13 @@ func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
var err error
listenCtx := context.Background() // unused without DNS name to resolve
if c.pconnPort == 0 && DefaultPort != 0 {
pc, err = netns.Listener().ListenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, DefaultPort))
pc, err = c.listenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, DefaultPort))
if err != nil {
c.logf("magicsock: bind: default port %s/%v unavailable; picking random", which, DefaultPort)
}
}
if pc == nil {
pc, err = netns.Listener().ListenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, c.pconnPort))
pc, err = c.listenPacket(listenCtx, which, fmt.Sprintf("%s:%d", host, c.pconnPort))
}
if err != nil {
c.logf("magicsock: bind(%s/%v): %v", which, c.pconnPort, err)
@ -2026,7 +2042,7 @@ func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
if *ruc == nil {
*ruc = new(RebindingUDPConn)
}
(*ruc).Reset(pc.(*net.UDPConn))
(*ruc).Reset(pc)
return nil
}
@ -2043,7 +2059,7 @@ func (c *Conn) Rebind() {
if err := c.pconn4.pconn.Close(); err != nil {
c.logf("magicsock: link change close failed: %v", err)
}
packetConn, err := netns.Listener().ListenPacket(listenCtx, "udp4", fmt.Sprintf("%s:%d", host, c.pconnPort))
packetConn, err := c.listenPacket(listenCtx, "udp4", fmt.Sprintf("%s:%d", host, c.pconnPort))
if err == nil {
c.logf("magicsock: link change rebound port: %d", c.pconnPort)
c.pconn4.pconn = packetConn.(*net.UDPConn)
@ -2054,7 +2070,7 @@ func (c *Conn) Rebind() {
c.pconn4.mu.Unlock()
}
c.logf("magicsock: link change, binding new port")
packetConn, err := netns.Listener().ListenPacket(listenCtx, "udp4", host+":0")
packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0")
if err != nil {
c.logf("magicsock: link change failed to bind new port: %v", err)
return
@ -2481,10 +2497,10 @@ type RebindingUDPConn struct {
ippCache ippCache
mu sync.Mutex
pconn *net.UDPConn
pconn net.PacketConn
}
func (c *RebindingUDPConn) Reset(pconn *net.UDPConn) {
func (c *RebindingUDPConn) Reset(pconn net.PacketConn) {
c.mu.Lock()
old := c.pconn
c.pconn = pconn
@ -2539,7 +2555,7 @@ func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
pconn := c.pconn
c.mu.Unlock()
n, err := pconn.WriteToUDP(b, addr)
n, err := pconn.WriteTo(b, addr)
if err != nil {
c.mu.Lock()
pconn2 := c.pconn

@ -32,8 +32,10 @@ import (
"tailscale.com/net/stun/stuntest"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/tstest/natlab"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/tstun"
)
@ -334,6 +336,16 @@ func makeNestable(t *testing.T) (logf logger.Logf, setT func(t *testing.T)) {
}
func TestTwoDevicePing(t *testing.T) {
t.Run("real", func(t *testing.T) {
testTwoDevicePing(t, false)
})
t.Run("natlab", func(t *testing.T) {
t.Skip("TODO: finish")
testTwoDevicePing(t, true)
})
}
func testTwoDevicePing(t *testing.T, useNatlab bool) {
tstest.PanicOnLog()
rc := tstest.NewResourceCheck()
defer rc.Assert(t)
@ -344,7 +356,28 @@ func TestTwoDevicePing(t *testing.T) {
derpServer, derpAddr, derpCleanupFn := runDERP(t, logf)
defer derpCleanupFn()
stunAddr, stunCleanupFn := stuntest.Serve(t)
packetConn := func(m *natlab.Machine) nettype.PacketListener {
if m == nil {
return nettype.Std{}
}
return m
}
var stunTestIP = "127.0.0.1"
var stunMachine, machine1, machine2 *natlab.Machine
if useNatlab {
stunMachine = &natlab.Machine{Name: "stun"}
machine1 = &natlab.Machine{Name: "machine1"}
machine2 = &natlab.Machine{Name: "machine2"}
internet := natlab.NewInternet()
stunIf := stunMachine.Attach("eth0", internet)
machine1.Attach("eth0", internet)
machine2.Attach("eth0", internet)
stunTestIP = stunIf.V4().String()
}
stunAddr, stunCleanupFn := stuntest.ServeWithPacketListener(t, packetConn(stunMachine))
defer stunCleanupFn()
derpMap := &tailcfg.DERPMap{
@ -361,6 +394,7 @@ func TestTwoDevicePing(t *testing.T) {
IPv6: "none",
STUNPort: stunAddr.Port,
DERPTestPort: derpAddr.Port,
STUNTestIP: stunTestIP,
},
},
},
@ -370,6 +404,7 @@ func TestTwoDevicePing(t *testing.T) {
epCh1 := make(chan []string, 16)
conn1, err := NewConn(Options{
Logf: logger.WithPrefix(logf, "conn1: "),
PacketListener: packetConn(machine1),
EndpointsFunc: func(eps []string) {
epCh1 <- eps
},
@ -384,6 +419,7 @@ func TestTwoDevicePing(t *testing.T) {
epCh2 := make(chan []string, 16)
conn2, err := NewConn(Options{
Logf: logger.WithPrefix(logf, "conn2: "),
PacketListener: packetConn(machine2),
EndpointsFunc: func(eps []string) {
epCh2 <- eps
},
@ -396,6 +432,14 @@ func TestTwoDevicePing(t *testing.T) {
conn2.SetDERPMap(derpMap)
ports := []uint16{conn1.LocalPort(), conn2.LocalPort()}
if useNatlab {
// TODO: ...
} else {
addrs := []netaddr.IPPort{
// netaddr.IPPort
}
_ = addrs
}
cfgs := makeConfigs(t, ports)
if err := conn1.SetPrivateKey(cfgs[0].PrivateKey); err != nil {

Loading…
Cancel
Save