From 8594292aa435e917c99c71de7e70ffe6df479096 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 6 Aug 2024 17:33:45 -0700 Subject: [PATCH] vnet: add control/derps to test, stateful firewall Updates #13038 Change-Id: Icd65b34c5f03498b5a7109785bb44692bce8911a Signed-off-by: Brad Fitzpatrick --- cmd/tta/tta.go | 104 ++++++++++- cmd/vnet/vnet-main.go | 4 +- gokrazy/Makefile | 3 + tstest/integration/integration.go | 1 + tstest/integration/nat/nat_test.go | 277 +++++++++++++++++++++++++++++ tstest/natlab/vnet/conf.go | 9 + tstest/natlab/vnet/nat.go | 30 +++- tstest/natlab/vnet/vnet.go | 219 +++++++++++++++++++---- 8 files changed, 599 insertions(+), 48 deletions(-) create mode 100644 tstest/integration/nat/nat_test.go diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index c7f587c4b..ed5892e76 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -11,6 +11,7 @@ package main import ( + "bufio" "bytes" "errors" "flag" @@ -19,12 +20,18 @@ import ( "log" "net" "net/http" + "net/http/httputil" + "net/url" "os" "os/exec" + "regexp" "strings" "sync" "time" + "github.com/mitchellh/go-ps" + "tailscale.com/client/tailscale" + "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/version/distro" ) @@ -33,21 +40,35 @@ var ( driverAddr = flag.String("driver", "test-driver.tailscale:8008", "address of the test driver; by default we use the DNS name test-driver.tailscale which is special cased in the emulated network's DNS server") ) -type chanListener <-chan net.Conn - -func serveCmd(w http.ResponseWriter, cmd string, args ...string) { +func absify(cmd string) string { if distro.Get() == distro.Gokrazy && !strings.Contains(cmd, "/") { - cmd = "/user/" + cmd + return "/user/" + cmd } - out, err := exec.Command(cmd, args...).CombinedOutput() + return cmd +} + +func serveCmd(w http.ResponseWriter, cmd string, args ...string) { + log.Printf("Got serveCmd for %q %v", cmd, args) + out, err := exec.Command(absify(cmd), args...).CombinedOutput() w.Header().Set("Content-Type", "text/plain; charset=utf-8") if err != nil { w.Header().Set("Exec-Err", err.Error()) w.WriteHeader(500) + log.Printf("Err on serveCmd for %q %v, %d bytes of output: %v", cmd, args, len(out), err) + } else { + log.Printf("Did serveCmd for %q %v, %d bytes of output", cmd, args, len(out)) } w.Write(out) } +type localClientRoundTripper struct { + lc *tailscale.LocalClient +} + +func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return rt.lc.DoLocalRequest(req) +} + func main() { if distro.Get() == distro.Gokrazy { cmdLine, _ := os.ReadFile("/proc/cmdline") @@ -59,8 +80,52 @@ func main() { } } flag.Parse() + + if distro.Get() == distro.Gokrazy { + nsRx := regexp.MustCompile(`(?m)^nameserver (.*)`) + for t := time.Now(); time.Since(t) < 10*time.Second; time.Sleep(10 * time.Millisecond) { + all, _ := os.ReadFile("/etc/resolv.conf") + if nsRx.Match(all) { + break + } + } + } + + logc, err := net.Dial("tcp", "9.9.9.9:124") + if err == nil { + log.SetOutput(logc) + } + log.Printf("Tailscale Test Agent running.") + if distro.Get() == distro.Gokrazy { + procs, err := ps.Processes() + if err != nil { + log.Fatalf("ps.Processes: %v", err) + } + killed := false + for _, p := range procs { + if p.Executable() == "tailscaled" { + if op, err := os.FindProcess(p.Pid()); err == nil { + op.Signal(os.Interrupt) + killed = true + } + } + } + log.Printf("killed = %v", killed) + if killed { + for { + _, err := exec.Command(absify("tailscale"), "status", "--json").CombinedOutput() + if err == nil { + log.Printf("tailscaled back up") + break + } + log.Printf("tailscale status error; sleeping before trying again...") + time.Sleep(50 * time.Millisecond) + } + } + } + var mux http.ServeMux var hs http.Server hs.Handler = &mux @@ -75,7 +140,7 @@ func main() { switch s { case http.StateNew: newSet.Add(c) - case http.StateClosed: + default: newSet.Delete(c) } if len(newSet) == 0 { @@ -86,20 +151,41 @@ func main() { } } conns := make(chan net.Conn, 1) + var lc tailscale.LocalClient + rp := httputil.NewSingleHostReverseProxy(must.Get(url.Parse("http://local-tailscaled.sock"))) + rp.Transport = localClientRoundTripper{&lc} + + mux.Handle("/localapi/", rp) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "TTA\n") return }) mux.HandleFunc("/up", func(w http.ResponseWriter, r *http.Request) { - serveCmd(w, "tailscale", "up", "--auth-key=test") + cmd := exec.Command(absify("tailscale"), "debug", "daemon-logs") + out, err := cmd.StdoutPipe() + if err != nil { + http.Error(w, err.Error(), 500) + return + } + defer out.Close() + cmd.Start() + defer cmd.Process.Kill() + go func() { + bs := bufio.NewScanner(out) + for bs.Scan() { + log.Printf("Daemon: %s", bs.Text()) + } + }() + + serveCmd(w, "tailscale", "up", "--login-server=http://control.tailscale") }) mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { serveCmd(w, "tailscale", "status", "--json") }) mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { target := r.FormValue("target") - cmd := exec.Command("tailscale", "ping", target) + cmd := exec.Command(absify("tailscale"), "ping", target) w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.(http.Flusher).Flush() cmd.Stdout = w @@ -139,6 +225,8 @@ func connect() (net.Conn, error) { return c, nil } +type chanListener <-chan net.Conn + func (cl chanListener) Accept() (net.Conn, error) { c, ok := <-cl if !ok { diff --git a/cmd/vnet/vnet-main.go b/cmd/vnet/vnet-main.go index 31e11f89f..3bc512995 100644 --- a/cmd/vnet/vnet-main.go +++ b/cmd/vnet/vnet-main.go @@ -19,6 +19,7 @@ import ( var ( listen = flag.String("listen", "/tmp/qemu.sock", "path to listen on") nat = flag.String("nat", "easy", "type of NAT to use") + nat2 = flag.String("nat2", "hard", "type of NAT to use for second network") portmap = flag.Bool("portmap", false, "enable portmapping") dgram = flag.Bool("dgram", false, "enable datagram mode; for use with macOS Hypervisor.Framework and VZFileHandleNetworkDeviceAttachment") ) @@ -52,7 +53,7 @@ func main() { var c vnet.Config node1 := c.AddNode(c.AddNetwork("2.1.1.1", "192.168.1.1/24", vnet.NAT(*nat))) - c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", vnet.NAT(*nat))) + c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", vnet.NAT(*nat2))) if *portmap { node1.Network().AddService(vnet.NATPMP) } @@ -81,6 +82,7 @@ func main() { } for { time.Sleep(5 * time.Second) + //continue getStatus() } }() diff --git a/gokrazy/Makefile b/gokrazy/Makefile index f086dd26b..a0807abe5 100644 --- a/gokrazy/Makefile +++ b/gokrazy/Makefile @@ -6,3 +6,6 @@ image: qemu: image qemu-system-x86_64 -m 1G -drive file=tsapp.img,format=raw -boot d -netdev user,id=user.0 -device virtio-net-pci,netdev=user.0 -serial mon:stdio -audio none + +qcow2: image + qemu-img convert -O qcow2 tsapp.img tsapp.qcow2 diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index d6fcdca27..36a92759f 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -190,6 +190,7 @@ func RunDERPAndSTUN(t testing.TB, logf logger.Logf, ipAddress string) (derpMap * } httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d)) + httpsrv.Listener.Close() httpsrv.Listener = ln httpsrv.Config.ErrorLog = logger.StdLogger(logf) httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go new file mode 100644 index 000000000..6f9dd781a --- /dev/null +++ b/tstest/integration/nat/nat_test.go @@ -0,0 +1,277 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package nat + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "golang.org/x/sync/errgroup" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tstest/natlab/vnet" +) + +type natTest struct { + tb testing.TB + base string // base image + tempDir string // for qcow2 images + vnet *vnet.Server +} + +func newNatTest(tb testing.TB) *natTest { + nt := &natTest{ + tb: tb, + tempDir: tb.TempDir(), + base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2", + } + + if _, err := os.Stat(nt.base); err != nil { + tb.Skipf("skipping test; base image %q not found", nt.base) + } + return nt +} + +type addNodeFunc func(c *vnet.Config) *vnet.Node + +func easy(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT)) +} + +func hard(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT)) +} + +func (nt *natTest) runTest(node1, node2 addNodeFunc) { + t := nt.tb + + var c vnet.Config + nodes := []*vnet.Node{ + node1(&c), + node2(&c), + } + + var err error + nt.vnet, err = vnet.New(&c) + if err != nil { + t.Fatalf("newServer: %v", err) + } + nt.tb.Cleanup(func() { + nt.vnet.Close() + }) + + var wg sync.WaitGroup // waiting for srv.Accept goroutine + defer wg.Wait() + + sockAddr := filepath.Join(nt.tempDir, "qemu.sock") + srv, err := net.Listen("unix", sockAddr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer srv.Close() + + wg.Add(1) + go func() { + defer wg.Done() + for { + c, err := srv.Accept() + if err != nil { + return + } + go nt.vnet.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) + } + }() + + for i, node := range nodes { + disk := fmt.Sprintf("%s/node-%d.qcow2", nt.tempDir, i) + out, err := exec.Command("qemu-img", "create", + "-f", "qcow2", + "-F", "qcow2", + "-b", nt.base, + disk).CombinedOutput() + if err != nil { + t.Fatalf("qemu-img create: %v, %s", err, out) + } + + cmd := exec.Command("qemu-system-x86_64", + "-M", "microvm,isa-serial=off", + "-m", "1G", + "-nodefaults", "-no-user-config", "-nographic", + "-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz", + "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1", + "-drive", "id=blk0,file="+disk+",format=qcow2", + "-device", "virtio-blk-device,drive=blk0", + "-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr, + "-device", "virtio-serial-device", + "-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(), + "-chardev", "stdio,id=virtiocon0,mux=on", + "-device", "virtconsole,chardev=virtiocon0", + "-mon", "chardev=virtiocon0,mode=readline", + "-audio", "none", + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + t.Fatalf("qemu: %v", err) + } + nt.tb.Cleanup(func() { + cmd.Process.Kill() + cmd.Wait() + }) + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])} + c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])} + + var eg errgroup.Group + var sts [2]*ipnstate.Status + for i, c := range []*http.Client{c1, c2} { + i, c := i, c + eg.Go(func() error { + st, err := status(ctx, c) + if err != nil { + return fmt.Errorf("node%d status: %w", i, err) + } + t.Logf("node%d status: %v", i, st) + if err := up(ctx, c); err != nil { + return fmt.Errorf("node%d up: %w", i, err) + } + t.Logf("node%d up!", i) + st, err = status(ctx, c) + if err != nil { + return fmt.Errorf("node%d status: %w", i, err) + } + sts[i] = st + + if st.BackendState != "Running" { + return fmt.Errorf("node%d state = %q", i, st.BackendState) + } + t.Logf("node%d up with %v", i, sts[i].Self.TailscaleIPs) + return nil + }) + } + if err := eg.Wait(); err != nil { + t.Fatalf("initial setup: %v", err) + } + + route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String()) + t.Logf("ping route: %v, %v", route, err) +} + +func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) { + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil) + if err != nil { + return nil, err + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + all, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("ReadAll: %w", err) + } + var st ipnstate.Status + if err := json.Unmarshal(all, &st); err != nil { + return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all) + } + return &st, nil +} + +type routeType string + +const ( + routeDirect routeType = "direct" + routeDERP routeType = "derp" + routeLAN routeType = "lan" +) + +func ping(ctx context.Context, c *http.Client, target string) (routeType, error) { + req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil) + if err != nil { + return "", err + } + res, err := c.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + if res.StatusCode != 200 { + return "", fmt.Errorf("unexpected status code %v", res.Status) + } + all, _ := io.ReadAll(res.Body) + var route routeType + for _, line := range strings.Split(string(all), "\n") { + if strings.Contains(line, " via DERP") { + route = routeDERP + continue + } + // pong from foo (100.82.3.4) via ADDR:PORT in 69ms + if _, rest, ok := strings.Cut(line, " via "); ok { + ipPorStr, _, _ := strings.Cut(rest, " in ") + ipPort, err := netip.ParseAddrPort(ipPorStr) + if err == nil { + if ipPort.Addr().IsPrivate() { + route = routeLAN + } else { + route = routeDirect + } + continue + } + } + } + if route == "" { + return routeType(all), nil + } + return route, nil +} + +func up(ctx context.Context, c *http.Client) error { + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) + if err != nil { + return err + } + res, err := c.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + all, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + return fmt.Errorf("unexpected status code %v: %s", res.Status, all) + } + return nil +} + +func TestEasyEasy(t *testing.T) { + nt := newNatTest(t) + nt.runTest(easy, easy) +} + +func TestEasyHard(t *testing.T) { + nt := newNatTest(t) + nt.runTest(easy, hard) +} diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index 89dfc9570..8cd91f4cd 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -27,6 +27,10 @@ type Config struct { networks []*Network } +func (c *Config) NumNodes() int { + return len(c.nodes) +} + // AddNode creates a new node in the world. // // The opts may be of the following types: @@ -110,6 +114,11 @@ type Node struct { nets []*Network } +// MAC returns the MAC address of the node. +func (n *Node) MAC() MAC { + return n.mac +} + // Network returns the first network this node is connected to, // or nil if none. func (n *Node) Network() *Network { diff --git a/tstest/natlab/vnet/nat.go b/tstest/natlab/vnet/nat.go index 9ce04a23a..179feb733 100644 --- a/tstest/natlab/vnet/nat.go +++ b/tstest/natlab/vnet/nat.go @@ -5,6 +5,7 @@ package vnet import ( "errors" + "log" "math/rand/v2" "net/netip" "time" @@ -111,9 +112,9 @@ func (n *oneToOneNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (la return netip.AddrPortFrom(n.lanIP, dst.Port()) } -type hardKeyOut struct { - lanIP netip.Addr - dst netip.AddrPort +type srcDstTuple struct { + src netip.AddrPort + dst netip.AddrPort } type hardKeyIn struct { @@ -137,7 +138,7 @@ type lanAddrAndTime struct { type hardNAT struct { wanIP netip.Addr - out map[hardKeyOut]portMappingAndTime + out map[srcDstTuple]portMappingAndTime in map[hardKeyIn]lanAddrAndTime } @@ -148,7 +149,7 @@ func init() { } func (n *hardNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) { - ko := hardKeyOut{src.Addr(), dst} + ko := srcDstTuple{src, dst} if pm, ok := n.out[ko]; ok { // Existing flow. // TODO: bump timestamp @@ -196,9 +197,10 @@ func (n *hardNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst // Unlike Linux, this implementation is capped at 32k entries and doesn't resort // to other allocation strategies when all 32k WAN ports are taken. type easyNAT struct { - wanIP netip.Addr - out map[netip.AddrPort]portMappingAndTime - in map[uint16]lanAddrAndTime + wanIP netip.Addr + out map[netip.AddrPort]portMappingAndTime + in map[uint16]lanAddrAndTime + lastOut map[srcDstTuple]time.Time // (lan:port, wan:port) => last packet out time } func init() { @@ -208,6 +210,7 @@ func init() { } func (n *easyNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) { + mak.Set(&n.lastOut, srcDstTuple{src, dst}, at) if pm, ok := n.out[src]; ok { // Existing flow. // TODO: bump timestamp @@ -235,5 +238,14 @@ func (n *easyNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst if dst.Addr() != n.wanIP { return netip.AddrPort{} // drop; not for us. shouldn't happen if natlabd routing isn't broken. } - return n.in[dst.Port()].lanAddr + lanDst = n.in[dst.Port()].lanAddr + + // Stateful firewall: drop incoming packets that don't have traffic out. + // TODO(bradfitz): verify Linux does this in the router code, not in the NAT code. + if t, ok := n.lastOut[srcDstTuple{lanDst, src}]; !ok || at.Sub(t) > 300*time.Second { + log.Printf("Drop incoming packet from %v to %v; no recent outgoing packet", src, dst) + return netip.AddrPort{} + } + + return lanDst } diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index 7ce86d512..e3332ae3f 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -16,6 +16,7 @@ package vnet import ( "bufio" "context" + "crypto/tls" "encoding/binary" "encoding/json" "errors" @@ -24,6 +25,7 @@ import ( "log" "net" "net/http" + "net/http/httptest" "net/netip" "os/exec" "strconv" @@ -44,9 +46,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/net/netutil" "tailscale.com/net/stun" "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" + "tailscale.com/types/logger" "tailscale.com/util/mak" "tailscale.com/util/set" ) @@ -240,6 +248,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails)) clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) + destPort := reqDetails.LocalPort if !clientRemoteIP.IsValid() { r.Complete(true) // sends a RST return @@ -254,7 +263,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { } ep.SocketOptions().SetKeepAlive(true) - if reqDetails.LocalPort == 123 { + if destPort == 123 { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) io.WriteString(tc, "Hello from Go\nGoodbye.\n") @@ -262,7 +271,21 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { return } - if reqDetails.LocalPort == 8008 && destIP == fakeTestAgentIP { + if destPort == 124 { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + go func() { + defer tc.Close() + bs := bufio.NewScanner(tc) + for bs.Scan() { + line := bs.Text() + log.Printf("LOG from guest: %s", line) + } + }() + return + } + + if destPort == 8008 && destIP == fakeTestAgentIP { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) node := n.nodesByIP[clientRemoteIP] @@ -271,11 +294,40 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { return } + if destPort == 80 && destIP == fakeControlIP { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + hs := &http.Server{Handler: n.s.control} + go hs.Serve(netutil.NewOneConnListener(tc, nil)) + return + } + + if destPort == 443 && (destIP == fakeDERP1IP || destIP == fakeDERP2IP) { + ds := n.s.derps[0] + if destIP == fakeDERP2IP { + ds = n.s.derps[1] + } + + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + tlsConn := tls.Server(tc, ds.tlsConfig) + hs := &http.Server{Handler: ds.handler} + go hs.Serve(netutil.NewOneConnListener(tlsConn, nil)) + return + } + if destPort == 80 && (destIP == fakeDERP1IP || destIP == fakeDERP2IP) { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + hs := &http.Server{Handler: n.s.derps[0].handler} + go hs.Serve(netutil.NewOneConnListener(tc, nil)) + return + } + var targetDial string if n.s.derpIPs.Contains(destIP) { - targetDial = destIP.String() + ":" + strconv.Itoa(int(reqDetails.LocalPort)) - } else if destIP == fakeControlplaneIP { - targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(reqDetails.LocalPort)) + targetDial = destIP.String() + ":" + strconv.Itoa(int(destPort)) + } else if destIP == fakeProxyControlplaneIP { + targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(destPort)) } if targetDial != "" { c, err := net.Dial("tcp", targetDial) @@ -298,9 +350,12 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { } var ( - fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11}) - fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) - fakeTestAgentIP = netip.AddrFrom4([4]byte{52, 52, 0, 2}) + fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11}) + fakeProxyControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) // real controlplane.tailscale.com proxy + fakeTestAgentIP = netip.AddrFrom4([4]byte{52, 52, 0, 2}) + fakeControlIP = netip.AddrFrom4([4]byte{52, 52, 0, 3}) // 3=C for "Control" + fakeDERP1IP = netip.AddrFrom4([4]byte{33, 4, 0, 1}) // 3340=DERP; 1=derp 1 + fakeDERP2IP = netip.AddrFrom4([4]byte{33, 4, 0, 2}) // 3340=DERP; 1=derp 1 ) type EthernetPacket struct { @@ -381,9 +436,33 @@ type node struct { lanIP netip.Addr // must be in net.lanIP prefix + unique in net } +type derpServer struct { + srv *derp.Server + handler http.Handler + tlsConfig *tls.Config +} + +func newDERPServer() *derpServer { + // Just to get a self-signed TLS cert: + ts := httptest.NewTLSServer(nil) + ts.Close() + + ds := &derpServer{ + srv: derp.NewServer(key.NewNode(), logger.Discard), + tlsConfig: ts.TLS, // self-signed; test client configure to not check + } + var mux http.ServeMux + mux.Handle("/derp", derphttp.Handler(ds.srv)) + mux.HandleFunc("/generate_204", derphttp.ServeNoContent) + + ds.handler = &mux + return ds +} + type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc + blendReality bool derpIPs set.Set[netip.Addr] @@ -392,10 +471,50 @@ type Server struct { networks set.Set[*network] networkByWAN map[netip.Addr]*network - mu sync.Mutex - agentConnWaiter map[*node]chan<- struct{} // signaled after added to set - agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all - agentRoundTripper map[*node]*http.Transport + control *testcontrol.Server + derps []*derpServer + + mu sync.Mutex + agentConnWaiter map[*node]chan<- struct{} // signaled after added to set + agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all + agentDialer map[*node]DialFunc +} + +type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +var derpMap = &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "atlantis", + RegionName: "Atlantis", + Nodes: []*tailcfg.DERPNode{ + { + Name: "1a", + RegionID: 1, + HostName: "derp1.tailscale", + IPv4: fakeDERP1IP.String(), + InsecureForTests: true, + CanPort80: true, + }, + }, + }, + 2: { + RegionID: 2, + RegionCode: "northpole", + RegionName: "North Pole", + Nodes: []*tailcfg.DERPNode{ + { + Name: "2a", + RegionID: 2, + HostName: "derp2.tailscale", + IPv4: fakeDERP2IP.String(), + InsecureForTests: true, + CanPort80: true, + }, + }, + }, + }, } func New(c *Config) (*Server, error) { @@ -404,12 +523,20 @@ func New(c *Config) (*Server, error) { shutdownCtx: ctx, shutdownCancel: cancel, + control: &testcontrol.Server{ + DERPMap: derpMap, + ExplicitBaseURL: "http://control.tailscale", + }, + derpIPs: set.Of[netip.Addr](), nodeByMAC: map[MAC]*node{}, networkByWAN: map[netip.Addr]*network{}, networks: set.Of[*network](), } + for range 2 { + s.derps = append(s.derps, newDERPServer()) + } if err := s.initFromConfig(c); err != nil { return nil, err } @@ -418,9 +545,14 @@ func New(c *Config) (*Server, error) { return nil, fmt.Errorf("newServer: initStack: %v", err) } } + return s, nil } +func (s *Server) Close() { + s.shutdownCancel() +} + func (s *Server) HWAddr(mac MAC) net.HardwareAddr { // TODO: cache return net.HardwareAddr(mac[:]) @@ -435,7 +567,13 @@ func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) { case "test-driver.tailscale": return fakeTestAgentIP, true case "controlplane.tailscale.com": - return fakeControlplaneIP, true + return fakeProxyControlplaneIP, true + case "control.tailscale": + return fakeControlIP, true + case "derp1.tailscale": + return fakeDERP1IP, true + case "derp2.tailscale": + return fakeDERP2IP, true } return netip.Addr{}, false } @@ -538,7 +676,10 @@ func (s *Server) routeUDPPacket(up UDPPacket) { if up.Dst.Port() == stunPort { // TODO(bradfitz): fake latency; time.AfterFunc the response if res, ok := makeSTUNReply(up); ok { + //log.Printf("STUN reply: %+v", res) s.routeUDPPacket(res) + } else { + log.Printf("weird: STUN packet not handled") } return } @@ -622,6 +763,7 @@ func (n *network) HandleEthernetPacket(ep EthernetPacket) { func (n *network) HandleUDPPacket(p UDPPacket) { dst := n.doNATIn(p.Src, p.Dst) if !dst.IsValid() { + log.Printf("Warning: NAT dropped packet; no mapping for %v=>%v", p.Src, p.Dst) return } p.Dst = dst @@ -726,7 +868,10 @@ func (n *network) HandleEthernetIPv4PacketForRouter(ep EthernetPacket) { if toForward && isUDP { src := netip.AddrPortFrom(srcIP, uint16(udp.SrcPort)) dst := netip.AddrPortFrom(dstIP, uint16(udp.DstPort)) + src0 := src src = n.doNATOut(src, dst) + _ = src0 + //log.Printf("XXX UDP out %v=>%v to %v", src0, src, dst) n.s.routeUDPPacket(UDPPacket{ Src: src, @@ -891,12 +1036,19 @@ func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool { if !ok { return false } - if tcp.DstPort == 123 { + if tcp.DstPort == 123 || tcp.DstPort == 124 { return true } dstIP, _ := netip.AddrFromSlice(ipv4.DstIP.To4()) if tcp.DstPort == 80 || tcp.DstPort == 443 { - if dstIP == fakeControlplaneIP || s.derpIPs.Contains(dstIP) { + switch dstIP { + case fakeControlIP, fakeDERP1IP, fakeDERP2IP: + return true + } + if dstIP == fakeProxyControlplaneIP { + return s.blendReality + } + if s.derpIPs.Contains(dstIP) { return true } } @@ -1166,12 +1318,15 @@ func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok b for { ac, ok := s.takeAgentConnOne(n) if ok { + log.Printf("got agent conn for %v", n.mac) return ac, true } s.mu.Lock() ready := make(chan struct{}) mak.Set(&s.agentConnWaiter, n, ready) s.mu.Unlock() + + log.Printf("waiting for agent conn for %v", n.mac) select { case <-ctx.Done(): return nil, false @@ -1190,36 +1345,40 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { for ac := range s.agentConns { if ac.node == n { s.agentConns.Delete(ac) + log.Printf("XXX takeAgentConnOne HIT for %v", n.mac) return ac, true } } + log.Printf("XXX takeAgentConnOne MISS for %v", n.mac) return nil, false } -func (s *Server) NodeAgentRoundTripper(ctx context.Context, n *Node) http.RoundTripper { +func (s *Server) NodeAgentDialer(n *Node) DialFunc { s.mu.Lock() defer s.mu.Unlock() - if rt, ok := s.agentRoundTripper[n.n]; ok { - return rt + if d, ok := s.agentDialer[n.n]; ok { + return d } - - var rt = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - ac, ok := s.takeAgentConn(ctx, n.n) - if !ok { - return nil, ctx.Err() - } - return ac.tc, nil - }, + d := func(ctx context.Context, network, addr string) (net.Conn, error) { + ac, ok := s.takeAgentConn(ctx, n.n) + if !ok { + return nil, ctx.Err() + } + return ac.tc, nil } + mak.Set(&s.agentDialer, n.n, d) + return d +} - mak.Set(&s.agentRoundTripper, n.n, rt) - return rt +func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper { + return &http.Transport{ + DialContext: s.NodeAgentDialer(n), + } } func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) { - rt := s.NodeAgentRoundTripper(ctx, n) + rt := s.NodeAgentRoundTripper(n) req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil) if err != nil { return nil, err