vnet: add control/derps to test, stateful firewall

Updates #13038

Change-Id: Icd65b34c5f03498b5a7109785bb44692bce8911a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
dependabot/go_modules/github.com/docker/docker-26.1.5incompatible
Brad Fitzpatrick 4 months ago committed by Maisem Ali
parent 20691894f5
commit 8594292aa4

@ -11,6 +11,7 @@
package main package main
import ( import (
"bufio"
"bytes" "bytes"
"errors" "errors"
"flag" "flag"
@ -19,12 +20,18 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"net/url"
"os" "os"
"os/exec" "os/exec"
"regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/mitchellh/go-ps"
"tailscale.com/client/tailscale"
"tailscale.com/util/must"
"tailscale.com/util/set" "tailscale.com/util/set"
"tailscale.com/version/distro" "tailscale.com/version/distro"
) )
@ -33,21 +40,35 @@ var (
driverAddr = flag.String("driver", "test-driver.tailscale:8008", "address of the test driver; by default we use the DNS name test-driver.tailscale which is special cased in the emulated network's DNS server") driverAddr = flag.String("driver", "test-driver.tailscale:8008", "address of the test driver; by default we use the DNS name test-driver.tailscale which is special cased in the emulated network's DNS server")
) )
type chanListener <-chan net.Conn func absify(cmd string) string {
func serveCmd(w http.ResponseWriter, cmd string, args ...string) {
if distro.Get() == distro.Gokrazy && !strings.Contains(cmd, "/") { if distro.Get() == distro.Gokrazy && !strings.Contains(cmd, "/") {
cmd = "/user/" + cmd return "/user/" + cmd
} }
out, err := exec.Command(cmd, args...).CombinedOutput() return cmd
}
func serveCmd(w http.ResponseWriter, cmd string, args ...string) {
log.Printf("Got serveCmd for %q %v", cmd, args)
out, err := exec.Command(absify(cmd), args...).CombinedOutput()
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
if err != nil { if err != nil {
w.Header().Set("Exec-Err", err.Error()) w.Header().Set("Exec-Err", err.Error())
w.WriteHeader(500) w.WriteHeader(500)
log.Printf("Err on serveCmd for %q %v, %d bytes of output: %v", cmd, args, len(out), err)
} else {
log.Printf("Did serveCmd for %q %v, %d bytes of output", cmd, args, len(out))
} }
w.Write(out) w.Write(out)
} }
type localClientRoundTripper struct {
lc *tailscale.LocalClient
}
func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt.lc.DoLocalRequest(req)
}
func main() { func main() {
if distro.Get() == distro.Gokrazy { if distro.Get() == distro.Gokrazy {
cmdLine, _ := os.ReadFile("/proc/cmdline") cmdLine, _ := os.ReadFile("/proc/cmdline")
@ -59,8 +80,52 @@ func main() {
} }
} }
flag.Parse() flag.Parse()
if distro.Get() == distro.Gokrazy {
nsRx := regexp.MustCompile(`(?m)^nameserver (.*)`)
for t := time.Now(); time.Since(t) < 10*time.Second; time.Sleep(10 * time.Millisecond) {
all, _ := os.ReadFile("/etc/resolv.conf")
if nsRx.Match(all) {
break
}
}
}
logc, err := net.Dial("tcp", "9.9.9.9:124")
if err == nil {
log.SetOutput(logc)
}
log.Printf("Tailscale Test Agent running.") log.Printf("Tailscale Test Agent running.")
if distro.Get() == distro.Gokrazy {
procs, err := ps.Processes()
if err != nil {
log.Fatalf("ps.Processes: %v", err)
}
killed := false
for _, p := range procs {
if p.Executable() == "tailscaled" {
if op, err := os.FindProcess(p.Pid()); err == nil {
op.Signal(os.Interrupt)
killed = true
}
}
}
log.Printf("killed = %v", killed)
if killed {
for {
_, err := exec.Command(absify("tailscale"), "status", "--json").CombinedOutput()
if err == nil {
log.Printf("tailscaled back up")
break
}
log.Printf("tailscale status error; sleeping before trying again...")
time.Sleep(50 * time.Millisecond)
}
}
}
var mux http.ServeMux var mux http.ServeMux
var hs http.Server var hs http.Server
hs.Handler = &mux hs.Handler = &mux
@ -75,7 +140,7 @@ func main() {
switch s { switch s {
case http.StateNew: case http.StateNew:
newSet.Add(c) newSet.Add(c)
case http.StateClosed: default:
newSet.Delete(c) newSet.Delete(c)
} }
if len(newSet) == 0 { if len(newSet) == 0 {
@ -86,20 +151,41 @@ func main() {
} }
} }
conns := make(chan net.Conn, 1) conns := make(chan net.Conn, 1)
var lc tailscale.LocalClient
rp := httputil.NewSingleHostReverseProxy(must.Get(url.Parse("http://local-tailscaled.sock")))
rp.Transport = localClientRoundTripper{&lc}
mux.Handle("/localapi/", rp)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "TTA\n") io.WriteString(w, "TTA\n")
return return
}) })
mux.HandleFunc("/up", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/up", func(w http.ResponseWriter, r *http.Request) {
serveCmd(w, "tailscale", "up", "--auth-key=test") cmd := exec.Command(absify("tailscale"), "debug", "daemon-logs")
out, err := cmd.StdoutPipe()
if err != nil {
http.Error(w, err.Error(), 500)
return
}
defer out.Close()
cmd.Start()
defer cmd.Process.Kill()
go func() {
bs := bufio.NewScanner(out)
for bs.Scan() {
log.Printf("Daemon: %s", bs.Text())
}
}()
serveCmd(w, "tailscale", "up", "--login-server=http://control.tailscale")
}) })
mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
serveCmd(w, "tailscale", "status", "--json") serveCmd(w, "tailscale", "status", "--json")
}) })
mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
target := r.FormValue("target") target := r.FormValue("target")
cmd := exec.Command("tailscale", "ping", target) cmd := exec.Command(absify("tailscale"), "ping", target)
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.(http.Flusher).Flush() w.(http.Flusher).Flush()
cmd.Stdout = w cmd.Stdout = w
@ -139,6 +225,8 @@ func connect() (net.Conn, error) {
return c, nil return c, nil
} }
type chanListener <-chan net.Conn
func (cl chanListener) Accept() (net.Conn, error) { func (cl chanListener) Accept() (net.Conn, error) {
c, ok := <-cl c, ok := <-cl
if !ok { if !ok {

@ -19,6 +19,7 @@ import (
var ( var (
listen = flag.String("listen", "/tmp/qemu.sock", "path to listen on") listen = flag.String("listen", "/tmp/qemu.sock", "path to listen on")
nat = flag.String("nat", "easy", "type of NAT to use") nat = flag.String("nat", "easy", "type of NAT to use")
nat2 = flag.String("nat2", "hard", "type of NAT to use for second network")
portmap = flag.Bool("portmap", false, "enable portmapping") portmap = flag.Bool("portmap", false, "enable portmapping")
dgram = flag.Bool("dgram", false, "enable datagram mode; for use with macOS Hypervisor.Framework and VZFileHandleNetworkDeviceAttachment") dgram = flag.Bool("dgram", false, "enable datagram mode; for use with macOS Hypervisor.Framework and VZFileHandleNetworkDeviceAttachment")
) )
@ -52,7 +53,7 @@ func main() {
var c vnet.Config var c vnet.Config
node1 := c.AddNode(c.AddNetwork("2.1.1.1", "192.168.1.1/24", vnet.NAT(*nat))) node1 := c.AddNode(c.AddNetwork("2.1.1.1", "192.168.1.1/24", vnet.NAT(*nat)))
c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", vnet.NAT(*nat))) c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", vnet.NAT(*nat2)))
if *portmap { if *portmap {
node1.Network().AddService(vnet.NATPMP) node1.Network().AddService(vnet.NATPMP)
} }
@ -81,6 +82,7 @@ func main() {
} }
for { for {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
//continue
getStatus() getStatus()
} }
}() }()

@ -6,3 +6,6 @@ image:
qemu: image qemu: image
qemu-system-x86_64 -m 1G -drive file=tsapp.img,format=raw -boot d -netdev user,id=user.0 -device virtio-net-pci,netdev=user.0 -serial mon:stdio -audio none qemu-system-x86_64 -m 1G -drive file=tsapp.img,format=raw -boot d -netdev user,id=user.0 -device virtio-net-pci,netdev=user.0 -serial mon:stdio -audio none
qcow2: image
qemu-img convert -O qcow2 tsapp.img tsapp.qcow2

@ -190,6 +190,7 @@ func RunDERPAndSTUN(t testing.TB, logf logger.Logf, ipAddress string) (derpMap *
} }
httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d)) httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d))
httpsrv.Listener.Close()
httpsrv.Listener = ln httpsrv.Listener = ln
httpsrv.Config.ErrorLog = logger.StdLogger(logf) httpsrv.Config.ErrorLog = logger.StdLogger(logf)
httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))

@ -0,0 +1,277 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package nat
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"golang.org/x/sync/errgroup"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tstest/natlab/vnet"
)
type natTest struct {
tb testing.TB
base string // base image
tempDir string // for qcow2 images
vnet *vnet.Server
}
func newNatTest(tb testing.TB) *natTest {
nt := &natTest{
tb: tb,
tempDir: tb.TempDir(),
base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2",
}
if _, err := os.Stat(nt.base); err != nil {
tb.Skipf("skipping test; base image %q not found", nt.base)
}
return nt
}
type addNodeFunc func(c *vnet.Config) *vnet.Node
func easy(c *vnet.Config) *vnet.Node {
n := c.NumNodes() + 1
return c.AddNode(c.AddNetwork(
fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP
fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT))
}
func hard(c *vnet.Config) *vnet.Node {
n := c.NumNodes() + 1
return c.AddNode(c.AddNetwork(
fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP
fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT))
}
func (nt *natTest) runTest(node1, node2 addNodeFunc) {
t := nt.tb
var c vnet.Config
nodes := []*vnet.Node{
node1(&c),
node2(&c),
}
var err error
nt.vnet, err = vnet.New(&c)
if err != nil {
t.Fatalf("newServer: %v", err)
}
nt.tb.Cleanup(func() {
nt.vnet.Close()
})
var wg sync.WaitGroup // waiting for srv.Accept goroutine
defer wg.Wait()
sockAddr := filepath.Join(nt.tempDir, "qemu.sock")
srv, err := net.Listen("unix", sockAddr)
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer srv.Close()
wg.Add(1)
go func() {
defer wg.Done()
for {
c, err := srv.Accept()
if err != nil {
return
}
go nt.vnet.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU)
}
}()
for i, node := range nodes {
disk := fmt.Sprintf("%s/node-%d.qcow2", nt.tempDir, i)
out, err := exec.Command("qemu-img", "create",
"-f", "qcow2",
"-F", "qcow2",
"-b", nt.base,
disk).CombinedOutput()
if err != nil {
t.Fatalf("qemu-img create: %v, %s", err, out)
}
cmd := exec.Command("qemu-system-x86_64",
"-M", "microvm,isa-serial=off",
"-m", "1G",
"-nodefaults", "-no-user-config", "-nographic",
"-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz",
"-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1",
"-drive", "id=blk0,file="+disk+",format=qcow2",
"-device", "virtio-blk-device,drive=blk0",
"-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr,
"-device", "virtio-serial-device",
"-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(),
"-chardev", "stdio,id=virtiocon0,mux=on",
"-device", "virtconsole,chardev=virtiocon0",
"-mon", "chardev=virtiocon0,mode=readline",
"-audio", "none",
)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
t.Fatalf("qemu: %v", err)
}
nt.tb.Cleanup(func() {
cmd.Process.Kill()
cmd.Wait()
})
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])}
c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])}
var eg errgroup.Group
var sts [2]*ipnstate.Status
for i, c := range []*http.Client{c1, c2} {
i, c := i, c
eg.Go(func() error {
st, err := status(ctx, c)
if err != nil {
return fmt.Errorf("node%d status: %w", i, err)
}
t.Logf("node%d status: %v", i, st)
if err := up(ctx, c); err != nil {
return fmt.Errorf("node%d up: %w", i, err)
}
t.Logf("node%d up!", i)
st, err = status(ctx, c)
if err != nil {
return fmt.Errorf("node%d status: %w", i, err)
}
sts[i] = st
if st.BackendState != "Running" {
return fmt.Errorf("node%d state = %q", i, st.BackendState)
}
t.Logf("node%d up with %v", i, sts[i].Self.TailscaleIPs)
return nil
})
}
if err := eg.Wait(); err != nil {
t.Fatalf("initial setup: %v", err)
}
route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String())
t.Logf("ping route: %v, %v", route, err)
}
func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil)
if err != nil {
return nil, err
}
res, err := c.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
all, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("ReadAll: %w", err)
}
var st ipnstate.Status
if err := json.Unmarshal(all, &st); err != nil {
return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all)
}
return &st, nil
}
type routeType string
const (
routeDirect routeType = "direct"
routeDERP routeType = "derp"
routeLAN routeType = "lan"
)
func ping(ctx context.Context, c *http.Client, target string) (routeType, error) {
req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil)
if err != nil {
return "", err
}
res, err := c.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
if res.StatusCode != 200 {
return "", fmt.Errorf("unexpected status code %v", res.Status)
}
all, _ := io.ReadAll(res.Body)
var route routeType
for _, line := range strings.Split(string(all), "\n") {
if strings.Contains(line, " via DERP") {
route = routeDERP
continue
}
// pong from foo (100.82.3.4) via ADDR:PORT in 69ms
if _, rest, ok := strings.Cut(line, " via "); ok {
ipPorStr, _, _ := strings.Cut(rest, " in ")
ipPort, err := netip.ParseAddrPort(ipPorStr)
if err == nil {
if ipPort.Addr().IsPrivate() {
route = routeLAN
} else {
route = routeDirect
}
continue
}
}
}
if route == "" {
return routeType(all), nil
}
return route, nil
}
func up(ctx context.Context, c *http.Client) error {
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil)
if err != nil {
return err
}
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
all, _ := io.ReadAll(res.Body)
if res.StatusCode != 200 {
return fmt.Errorf("unexpected status code %v: %s", res.Status, all)
}
return nil
}
func TestEasyEasy(t *testing.T) {
nt := newNatTest(t)
nt.runTest(easy, easy)
}
func TestEasyHard(t *testing.T) {
nt := newNatTest(t)
nt.runTest(easy, hard)
}

@ -27,6 +27,10 @@ type Config struct {
networks []*Network networks []*Network
} }
func (c *Config) NumNodes() int {
return len(c.nodes)
}
// AddNode creates a new node in the world. // AddNode creates a new node in the world.
// //
// The opts may be of the following types: // The opts may be of the following types:
@ -110,6 +114,11 @@ type Node struct {
nets []*Network nets []*Network
} }
// MAC returns the MAC address of the node.
func (n *Node) MAC() MAC {
return n.mac
}
// Network returns the first network this node is connected to, // Network returns the first network this node is connected to,
// or nil if none. // or nil if none.
func (n *Node) Network() *Network { func (n *Node) Network() *Network {

@ -5,6 +5,7 @@ package vnet
import ( import (
"errors" "errors"
"log"
"math/rand/v2" "math/rand/v2"
"net/netip" "net/netip"
"time" "time"
@ -111,8 +112,8 @@ func (n *oneToOneNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (la
return netip.AddrPortFrom(n.lanIP, dst.Port()) return netip.AddrPortFrom(n.lanIP, dst.Port())
} }
type hardKeyOut struct { type srcDstTuple struct {
lanIP netip.Addr src netip.AddrPort
dst netip.AddrPort dst netip.AddrPort
} }
@ -137,7 +138,7 @@ type lanAddrAndTime struct {
type hardNAT struct { type hardNAT struct {
wanIP netip.Addr wanIP netip.Addr
out map[hardKeyOut]portMappingAndTime out map[srcDstTuple]portMappingAndTime
in map[hardKeyIn]lanAddrAndTime in map[hardKeyIn]lanAddrAndTime
} }
@ -148,7 +149,7 @@ func init() {
} }
func (n *hardNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) { func (n *hardNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) {
ko := hardKeyOut{src.Addr(), dst} ko := srcDstTuple{src, dst}
if pm, ok := n.out[ko]; ok { if pm, ok := n.out[ko]; ok {
// Existing flow. // Existing flow.
// TODO: bump timestamp // TODO: bump timestamp
@ -199,6 +200,7 @@ type easyNAT struct {
wanIP netip.Addr wanIP netip.Addr
out map[netip.AddrPort]portMappingAndTime out map[netip.AddrPort]portMappingAndTime
in map[uint16]lanAddrAndTime in map[uint16]lanAddrAndTime
lastOut map[srcDstTuple]time.Time // (lan:port, wan:port) => last packet out time
} }
func init() { func init() {
@ -208,6 +210,7 @@ func init() {
} }
func (n *easyNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) { func (n *easyNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) {
mak.Set(&n.lastOut, srcDstTuple{src, dst}, at)
if pm, ok := n.out[src]; ok { if pm, ok := n.out[src]; ok {
// Existing flow. // Existing flow.
// TODO: bump timestamp // TODO: bump timestamp
@ -235,5 +238,14 @@ func (n *easyNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst
if dst.Addr() != n.wanIP { if dst.Addr() != n.wanIP {
return netip.AddrPort{} // drop; not for us. shouldn't happen if natlabd routing isn't broken. return netip.AddrPort{} // drop; not for us. shouldn't happen if natlabd routing isn't broken.
} }
return n.in[dst.Port()].lanAddr lanDst = n.in[dst.Port()].lanAddr
// Stateful firewall: drop incoming packets that don't have traffic out.
// TODO(bradfitz): verify Linux does this in the router code, not in the NAT code.
if t, ok := n.lastOut[srcDstTuple{lanDst, src}]; !ok || at.Sub(t) > 300*time.Second {
log.Printf("Drop incoming packet from %v to %v; no recent outgoing packet", src, dst)
return netip.AddrPort{}
}
return lanDst
} }

@ -16,6 +16,7 @@ package vnet
import ( import (
"bufio" "bufio"
"context" "context"
"crypto/tls"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors" "errors"
@ -24,6 +25,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
@ -44,9 +46,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/net/netutil"
"tailscale.com/net/stun" "tailscale.com/net/stun"
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/set" "tailscale.com/util/set"
) )
@ -240,6 +248,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails)) log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails))
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
destPort := reqDetails.LocalPort
if !clientRemoteIP.IsValid() { if !clientRemoteIP.IsValid() {
r.Complete(true) // sends a RST r.Complete(true) // sends a RST
return return
@ -254,7 +263,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
} }
ep.SocketOptions().SetKeepAlive(true) ep.SocketOptions().SetKeepAlive(true)
if reqDetails.LocalPort == 123 { if destPort == 123 {
r.Complete(false) r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep) tc := gonet.NewTCPConn(&wq, ep)
io.WriteString(tc, "Hello from Go\nGoodbye.\n") io.WriteString(tc, "Hello from Go\nGoodbye.\n")
@ -262,7 +271,21 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
return return
} }
if reqDetails.LocalPort == 8008 && destIP == fakeTestAgentIP { if destPort == 124 {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
go func() {
defer tc.Close()
bs := bufio.NewScanner(tc)
for bs.Scan() {
line := bs.Text()
log.Printf("LOG from guest: %s", line)
}
}()
return
}
if destPort == 8008 && destIP == fakeTestAgentIP {
r.Complete(false) r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep) tc := gonet.NewTCPConn(&wq, ep)
node := n.nodesByIP[clientRemoteIP] node := n.nodesByIP[clientRemoteIP]
@ -271,11 +294,40 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
return return
} }
if destPort == 80 && destIP == fakeControlIP {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
hs := &http.Server{Handler: n.s.control}
go hs.Serve(netutil.NewOneConnListener(tc, nil))
return
}
if destPort == 443 && (destIP == fakeDERP1IP || destIP == fakeDERP2IP) {
ds := n.s.derps[0]
if destIP == fakeDERP2IP {
ds = n.s.derps[1]
}
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
tlsConn := tls.Server(tc, ds.tlsConfig)
hs := &http.Server{Handler: ds.handler}
go hs.Serve(netutil.NewOneConnListener(tlsConn, nil))
return
}
if destPort == 80 && (destIP == fakeDERP1IP || destIP == fakeDERP2IP) {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
hs := &http.Server{Handler: n.s.derps[0].handler}
go hs.Serve(netutil.NewOneConnListener(tc, nil))
return
}
var targetDial string var targetDial string
if n.s.derpIPs.Contains(destIP) { if n.s.derpIPs.Contains(destIP) {
targetDial = destIP.String() + ":" + strconv.Itoa(int(reqDetails.LocalPort)) targetDial = destIP.String() + ":" + strconv.Itoa(int(destPort))
} else if destIP == fakeControlplaneIP { } else if destIP == fakeProxyControlplaneIP {
targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(reqDetails.LocalPort)) targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(destPort))
} }
if targetDial != "" { if targetDial != "" {
c, err := net.Dial("tcp", targetDial) c, err := net.Dial("tcp", targetDial)
@ -299,8 +351,11 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
var ( var (
fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11}) fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11})
fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) fakeProxyControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) // real controlplane.tailscale.com proxy
fakeTestAgentIP = netip.AddrFrom4([4]byte{52, 52, 0, 2}) fakeTestAgentIP = netip.AddrFrom4([4]byte{52, 52, 0, 2})
fakeControlIP = netip.AddrFrom4([4]byte{52, 52, 0, 3}) // 3=C for "Control"
fakeDERP1IP = netip.AddrFrom4([4]byte{33, 4, 0, 1}) // 3340=DERP; 1=derp 1
fakeDERP2IP = netip.AddrFrom4([4]byte{33, 4, 0, 2}) // 3340=DERP; 1=derp 1
) )
type EthernetPacket struct { type EthernetPacket struct {
@ -381,9 +436,33 @@ type node struct {
lanIP netip.Addr // must be in net.lanIP prefix + unique in net lanIP netip.Addr // must be in net.lanIP prefix + unique in net
} }
type derpServer struct {
srv *derp.Server
handler http.Handler
tlsConfig *tls.Config
}
func newDERPServer() *derpServer {
// Just to get a self-signed TLS cert:
ts := httptest.NewTLSServer(nil)
ts.Close()
ds := &derpServer{
srv: derp.NewServer(key.NewNode(), logger.Discard),
tlsConfig: ts.TLS, // self-signed; test client configure to not check
}
var mux http.ServeMux
mux.Handle("/derp", derphttp.Handler(ds.srv))
mux.HandleFunc("/generate_204", derphttp.ServeNoContent)
ds.handler = &mux
return ds
}
type Server struct { type Server struct {
shutdownCtx context.Context shutdownCtx context.Context
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
blendReality bool
derpIPs set.Set[netip.Addr] derpIPs set.Set[netip.Addr]
@ -392,10 +471,50 @@ type Server struct {
networks set.Set[*network] networks set.Set[*network]
networkByWAN map[netip.Addr]*network networkByWAN map[netip.Addr]*network
control *testcontrol.Server
derps []*derpServer
mu sync.Mutex mu sync.Mutex
agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConnWaiter map[*node]chan<- struct{} // signaled after added to set
agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all
agentRoundTripper map[*node]*http.Transport agentDialer map[*node]DialFunc
}
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
var derpMap = &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
RegionID: 1,
RegionCode: "atlantis",
RegionName: "Atlantis",
Nodes: []*tailcfg.DERPNode{
{
Name: "1a",
RegionID: 1,
HostName: "derp1.tailscale",
IPv4: fakeDERP1IP.String(),
InsecureForTests: true,
CanPort80: true,
},
},
},
2: {
RegionID: 2,
RegionCode: "northpole",
RegionName: "North Pole",
Nodes: []*tailcfg.DERPNode{
{
Name: "2a",
RegionID: 2,
HostName: "derp2.tailscale",
IPv4: fakeDERP2IP.String(),
InsecureForTests: true,
CanPort80: true,
},
},
},
},
} }
func New(c *Config) (*Server, error) { func New(c *Config) (*Server, error) {
@ -404,12 +523,20 @@ func New(c *Config) (*Server, error) {
shutdownCtx: ctx, shutdownCtx: ctx,
shutdownCancel: cancel, shutdownCancel: cancel,
control: &testcontrol.Server{
DERPMap: derpMap,
ExplicitBaseURL: "http://control.tailscale",
},
derpIPs: set.Of[netip.Addr](), derpIPs: set.Of[netip.Addr](),
nodeByMAC: map[MAC]*node{}, nodeByMAC: map[MAC]*node{},
networkByWAN: map[netip.Addr]*network{}, networkByWAN: map[netip.Addr]*network{},
networks: set.Of[*network](), networks: set.Of[*network](),
} }
for range 2 {
s.derps = append(s.derps, newDERPServer())
}
if err := s.initFromConfig(c); err != nil { if err := s.initFromConfig(c); err != nil {
return nil, err return nil, err
} }
@ -418,9 +545,14 @@ func New(c *Config) (*Server, error) {
return nil, fmt.Errorf("newServer: initStack: %v", err) return nil, fmt.Errorf("newServer: initStack: %v", err)
} }
} }
return s, nil return s, nil
} }
func (s *Server) Close() {
s.shutdownCancel()
}
func (s *Server) HWAddr(mac MAC) net.HardwareAddr { func (s *Server) HWAddr(mac MAC) net.HardwareAddr {
// TODO: cache // TODO: cache
return net.HardwareAddr(mac[:]) return net.HardwareAddr(mac[:])
@ -435,7 +567,13 @@ func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) {
case "test-driver.tailscale": case "test-driver.tailscale":
return fakeTestAgentIP, true return fakeTestAgentIP, true
case "controlplane.tailscale.com": case "controlplane.tailscale.com":
return fakeControlplaneIP, true return fakeProxyControlplaneIP, true
case "control.tailscale":
return fakeControlIP, true
case "derp1.tailscale":
return fakeDERP1IP, true
case "derp2.tailscale":
return fakeDERP2IP, true
} }
return netip.Addr{}, false return netip.Addr{}, false
} }
@ -538,7 +676,10 @@ func (s *Server) routeUDPPacket(up UDPPacket) {
if up.Dst.Port() == stunPort { if up.Dst.Port() == stunPort {
// TODO(bradfitz): fake latency; time.AfterFunc the response // TODO(bradfitz): fake latency; time.AfterFunc the response
if res, ok := makeSTUNReply(up); ok { if res, ok := makeSTUNReply(up); ok {
//log.Printf("STUN reply: %+v", res)
s.routeUDPPacket(res) s.routeUDPPacket(res)
} else {
log.Printf("weird: STUN packet not handled")
} }
return return
} }
@ -622,6 +763,7 @@ func (n *network) HandleEthernetPacket(ep EthernetPacket) {
func (n *network) HandleUDPPacket(p UDPPacket) { func (n *network) HandleUDPPacket(p UDPPacket) {
dst := n.doNATIn(p.Src, p.Dst) dst := n.doNATIn(p.Src, p.Dst)
if !dst.IsValid() { if !dst.IsValid() {
log.Printf("Warning: NAT dropped packet; no mapping for %v=>%v", p.Src, p.Dst)
return return
} }
p.Dst = dst p.Dst = dst
@ -726,7 +868,10 @@ func (n *network) HandleEthernetIPv4PacketForRouter(ep EthernetPacket) {
if toForward && isUDP { if toForward && isUDP {
src := netip.AddrPortFrom(srcIP, uint16(udp.SrcPort)) src := netip.AddrPortFrom(srcIP, uint16(udp.SrcPort))
dst := netip.AddrPortFrom(dstIP, uint16(udp.DstPort)) dst := netip.AddrPortFrom(dstIP, uint16(udp.DstPort))
src0 := src
src = n.doNATOut(src, dst) src = n.doNATOut(src, dst)
_ = src0
//log.Printf("XXX UDP out %v=>%v to %v", src0, src, dst)
n.s.routeUDPPacket(UDPPacket{ n.s.routeUDPPacket(UDPPacket{
Src: src, Src: src,
@ -891,12 +1036,19 @@ func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool {
if !ok { if !ok {
return false return false
} }
if tcp.DstPort == 123 { if tcp.DstPort == 123 || tcp.DstPort == 124 {
return true return true
} }
dstIP, _ := netip.AddrFromSlice(ipv4.DstIP.To4()) dstIP, _ := netip.AddrFromSlice(ipv4.DstIP.To4())
if tcp.DstPort == 80 || tcp.DstPort == 443 { if tcp.DstPort == 80 || tcp.DstPort == 443 {
if dstIP == fakeControlplaneIP || s.derpIPs.Contains(dstIP) { switch dstIP {
case fakeControlIP, fakeDERP1IP, fakeDERP2IP:
return true
}
if dstIP == fakeProxyControlplaneIP {
return s.blendReality
}
if s.derpIPs.Contains(dstIP) {
return true return true
} }
} }
@ -1166,12 +1318,15 @@ func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok b
for { for {
ac, ok := s.takeAgentConnOne(n) ac, ok := s.takeAgentConnOne(n)
if ok { if ok {
log.Printf("got agent conn for %v", n.mac)
return ac, true return ac, true
} }
s.mu.Lock() s.mu.Lock()
ready := make(chan struct{}) ready := make(chan struct{})
mak.Set(&s.agentConnWaiter, n, ready) mak.Set(&s.agentConnWaiter, n, ready)
s.mu.Unlock() s.mu.Unlock()
log.Printf("waiting for agent conn for %v", n.mac)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, false return nil, false
@ -1190,36 +1345,40 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) {
for ac := range s.agentConns { for ac := range s.agentConns {
if ac.node == n { if ac.node == n {
s.agentConns.Delete(ac) s.agentConns.Delete(ac)
log.Printf("XXX takeAgentConnOne HIT for %v", n.mac)
return ac, true return ac, true
} }
} }
log.Printf("XXX takeAgentConnOne MISS for %v", n.mac)
return nil, false return nil, false
} }
func (s *Server) NodeAgentRoundTripper(ctx context.Context, n *Node) http.RoundTripper { func (s *Server) NodeAgentDialer(n *Node) DialFunc {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if rt, ok := s.agentRoundTripper[n.n]; ok { if d, ok := s.agentDialer[n.n]; ok {
return rt return d
} }
d := func(ctx context.Context, network, addr string) (net.Conn, error) {
var rt = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
ac, ok := s.takeAgentConn(ctx, n.n) ac, ok := s.takeAgentConn(ctx, n.n)
if !ok { if !ok {
return nil, ctx.Err() return nil, ctx.Err()
} }
return ac.tc, nil return ac.tc, nil
}, }
mak.Set(&s.agentDialer, n.n, d)
return d
} }
mak.Set(&s.agentRoundTripper, n.n, rt) func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper {
return rt return &http.Transport{
DialContext: s.NodeAgentDialer(n),
}
} }
func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) { func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) {
rt := s.NodeAgentRoundTripper(ctx, n) rt := s.NodeAgentRoundTripper(n)
req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil) req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil)
if err != nil { if err != nil {
return nil, err return nil, err

Loading…
Cancel
Save