From a7c80c332a881a91fd9ef5dfdf3cacd7aceb1966 Mon Sep 17 00:00:00 2001 From: Tom DNetto Date: Thu, 19 Oct 2023 17:07:07 -0700 Subject: [PATCH] cmd/sniproxy: implement support for control configuration, multiple addresses * Implement missing tests for sniproxy * Wire sniproxy to new appc package * Add support to tsnet for routing subnet router traffic into netstack, so it can be handled Updates: https://github.com/tailscale/corp/issues/15038 Signed-off-by: Tom DNetto --- appc/appc.go | 8 +- cmd/sniproxy/sniproxy.go | 419 +++++++----------- cmd/sniproxy/sniproxy_test.go | 186 ++++++++ tstest/integration/testcontrol/testcontrol.go | 12 + 4 files changed, 354 insertions(+), 271 deletions(-) diff --git a/appc/appc.go b/appc/appc.go index 321f4dcf7..d66ece8a3 100644 --- a/appc/appc.go +++ b/appc/appc.go @@ -67,6 +67,7 @@ func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { s.mu.Lock() defer s.mu.Unlock() s.connectors = makeConnectorsFromConfig(cfg) + log.Printf("installed app connector config: %+v", s.connectors) } // HandleTCPFlow implements tsnet.FallbackTCPHandler. @@ -193,8 +194,7 @@ func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (re } func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { - buf := make([]byte, 1500) - resp := dnsmessage.NewBuilder(buf, + resp := dnsmessage.NewBuilder(response, dnsmessage.Header{ ID: req.Header.ID, Response: true, @@ -203,8 +203,8 @@ func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (respon resp.EnableCompression() if len(req.Questions) == 0 { - buf, _ = resp.Finish() - return buf, nil + response, _ = resp.Finish() + return response, nil } q := req.Questions[0] err = resp.StartQuestions() diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index 5be6e5afe..e94566772 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -10,30 +10,34 @@ package main import ( "context" "errors" - "expvar" "flag" "fmt" "log" "net" "net/http" + "net/netip" "os" + "sort" "strconv" "strings" - "time" "github.com/peterbourgon/ff/v3" "golang.org/x/net/dns/dnsmessage" - "inet.af/tcpproxy" + "tailscale.com/appc" "tailscale.com/client/tailscale" "tailscale.com/hostinfo" - "tailscale.com/metrics" - "tailscale.com/net/netutil" + "tailscale.com/ipn" + "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" "tailscale.com/types/nettype" - "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" ) +const configCapKey = "tailscale.com/sniproxy" + var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") // portForward is the state for a single port forwarding entry, as passed to the --forward flag. @@ -68,6 +72,7 @@ func parseForward(value string) (*portForward, error) { } func main() { + // Parse flags fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) var ( ports = fs.String("ports", "443", "comma-separated list of ports to proxy") @@ -77,334 +82,214 @@ func main() { debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") hostname = fs.String("hostname", "", "Hostname to register the service under") ) - err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) if err != nil { log.Fatal("ff.Parse") } - if *ports == "" { - log.Fatal("no ports") - } - hostinfo.SetApp("sniproxy") + var ts tsnet.Server + defer ts.Close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) +} + +// run actually runs the sniproxy. Its separate from main() to assist in testing. +func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { + // Wire up Tailscale node + app connector server + hostinfo.SetApp("sniproxy") var s server - s.ts.Port = uint16(*wgPort) - s.ts.Hostname = *hostname - defer s.ts.Close() + s.ts = ts + + s.ts.Port = uint16(wgPort) + s.ts.Hostname = hostname lc, err := s.ts.LocalClient() if err != nil { - log.Fatal(err) + log.Fatalf("LocalClient() failed: %v", err) } s.lc = lc - s.initMetrics() - - for _, portStr := range strings.Split(*ports, ",") { - ln, err := s.ts.Listen("tcp", ":"+portStr) - if err != nil { - log.Fatal(err) - } - log.Printf("Serving on port %v ...", portStr) - go s.serve(ln) - } - - for _, forwStr := range strings.Split(*forwards, ",") { - if forwStr == "" { - continue - } - forw, err := parseForward(forwStr) - if err != nil { - log.Fatal(err) - } - - ln, err := s.ts.Listen("tcp", ":"+strconv.Itoa(forw.Port)) - if err != nil { - log.Fatal(err) - } - log.Printf("Serving on port %d to %s...", forw.Port, forw.Destination) - - // Add an entry to the expvar LabelMap for Prometheus metrics, - // and create a clientmetric to report that same value. - service := portNumberToName(forw) - s.numTCPsessions.SetInt64(service, 0) - metric := fmt.Sprintf("sniproxy_tcp_sessions_%s", service) - clientmetric.NewCounterFunc(metric, func() int64 { - return s.numTCPsessions.Get(service).Value() - }) - - go s.forward(ln, forw) - } + s.ts.RegisterFallbackTCPHandler(s.appc.HandleTCPFlow) + // Start special-purpose listeners: dns, http promotion, debug server ln, err := s.ts.Listen("udp", ":53") if err != nil { - log.Fatal(err) + log.Fatalf("failed listening on port 53: %v", err) } + defer ln.Close() go s.serveDNS(ln) - - if *promoteHTTPS { + if promoteHTTPS { ln, err := s.ts.Listen("tcp", ":80") if err != nil { - log.Fatal(err) + log.Fatalf("failed listening on port 80: %v", err) } + defer ln.Close() log.Printf("Promoting HTTP to HTTPS ...") go s.promoteHTTPS(ln) } - - if *debugPort != 0 { + if debugPort != 0 { mux := http.NewServeMux() tsweb.Debugger(mux) - dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort)) + dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) if err != nil { - log.Fatal(err) + log.Fatalf("failed listening on debug port: %v", err) } + defer dln.Close() go func() { - log.Fatal(http.Serve(dln, mux)) + log.Fatalf("debug serve: %v", http.Serve(dln, mux)) }() } - select {} -} - -type server struct { - ts tsnet.Server - lc *tailscale.LocalClient - - numTLSsessions expvar.Int - numTCPsessions *metrics.LabelMap - numBadAddrPort expvar.Int - dnsResponses expvar.Int - dnsFailures expvar.Int - httpPromoted expvar.Int -} - -func (s *server) serve(ln net.Listener) { - for { - c, err := ln.Accept() - if err != nil { - log.Fatal(err) - } - go s.serveConn(c) + // Finally, start mainloop to configure app connector based on information + // in the netmap. + // We set the NotifyInitialNetMap flag so we will always get woken with the + // current netmap, before only being woken on changes. + bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) + if err != nil { + log.Fatalf("watching IPN bus: %v", err) } -} - -func (s *server) forward(ln net.Listener, forw *portForward) { + defer bus.Close() for { - c, err := ln.Accept() + msg, err := bus.Next() if err != nil { - log.Fatal(err) + if errors.Is(err, context.Canceled) { + return + } + log.Fatalf("reading IPN bus: %v", err) } - go s.forwardConn(c, forw) - } -} -func (s *server) serveDNS(ln net.Listener) { - for { - c, err := ln.Accept() - if err != nil { - log.Fatal(err) + // NetMap contains app-connector configuration + if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { + sn := nm.SelfNode.AsStruct() + + var c appctype.AppConnectorConfig + nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) + if err != nil { + log.Printf("failed to read app connector configuration from coordination server: %v", err) + } else if len(nmConf) > 0 { + c = nmConf[0] + } + + if c.AdvertiseRoutes { + if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { + log.Printf("failed to advertise routes: %v", err) + } + } + + // Backwards compatibility: combine any configuration from control with flags specified + // on the command line. This is intentionally done after we advertise any routes + // because its never correct to advertise the nodes native IP addresses. + s.mergeConfigFromFlags(&c, ports, forwards) + s.appc.Configure(&c) } - go s.serveDNSConn(c.(nettype.ConnPacketConn)) } } -func (s *server) serveDNSConn(c nettype.ConnPacketConn) { - defer c.Close() - c.SetReadDeadline(time.Now().Add(5 * time.Second)) - buf := make([]byte, 1500) - n, err := c.Read(buf) - if err != nil { - log.Printf("c.Read failed: %v\n ", err) - s.dnsFailures.Add(1) - return - } - - var msg dnsmessage.Message - err = msg.Unpack(buf[:n]) - if err != nil { - log.Printf("dnsmessage unpack failed: %v\n ", err) - s.dnsFailures.Add(1) - return - } - - buf, err = s.dnsResponse(&msg) - if err != nil { - log.Printf("s.dnsResponse failed: %v\n", err) - s.dnsFailures.Add(1) - return - } - - _, err = c.Write(buf) - if err != nil { - log.Printf("c.Write failed: %v\n", err) - s.dnsFailures.Add(1) - return - } - - s.dnsResponses.Add(1) +type server struct { + appc appc.Server + ts *tsnet.Server + lc *tailscale.LocalClient } -func (s *server) serveConn(c net.Conn) { - addrPortStr := c.LocalAddr().String() - _, port, err := net.SplitHostPort(addrPortStr) - if err != nil { - log.Printf("bogus addrPort %q", addrPortStr) - s.numBadAddrPort.Add(1) - c.Close() - return +func (s *server) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { + // Collect the set of addresses to advertise, using a map + // to avoid duplicate entries. + addrs := map[netip.Addr]struct{}{} + for _, c := range c.SNIProxy { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } + } + for _, c := range c.DNAT { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } } - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - - var p tcpproxy.Proxy - p.ListenFunc = func(net, laddr string) (net.Listener, error) { - return netutil.NewOneConnListener(c, nil), nil + var routes []netip.Prefix + for a := range addrs { + routes = append(routes, netip.PrefixFrom(a, a.BitLen())) } - p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) { - s.numTLSsessions.Add(1) - return &tcpproxy.DialProxy{ - Addr: net.JoinHostPort(sniName, port), - DialContext: dialer.DialContext, - }, true + sort.SliceStable(routes, func(i, j int) bool { + return routes[i].Addr().Less(routes[j].Addr()) // determinism r us }) - p.Start() -} -// portNumberToName returns a human-readable name for several port numbers commonly forwarded, -// and "tcp###" for everything else. It is used for metric label names. -func portNumberToName(forw *portForward) string { - switch forw.Port { - case 22: - return "ssh" - case 1433: - return "sqlserver" - case 3306: - return "mysql" - case 3389: - return "rdp" - case 5432: - return "postgres" - default: - return fmt.Sprintf("%s%d", forw.Proto, forw.Port) - } + _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AdvertiseRoutes: routes, + }, + AdvertiseRoutesSet: true, + }) + return err } -// forwardConn sets up a forwarder for a TCP connection. It does not inspect of the data -// like the SNI forwarding does, it merely forwards all data to the destination specified -// in the --forward=tcp/22/github.com argument. -func (s *server) forwardConn(c net.Conn, forw *portForward) { - addrPortStr := c.LocalAddr().String() - - var dialer net.Dialer - dialer.Timeout = 30 * time.Second +func (s *server) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { + ip4, ip6 := s.ts.TailscaleIPs() - var p tcpproxy.Proxy - p.ListenFunc = func(net, laddr string) (net.Listener, error) { - return netutil.NewOneConnListener(c, nil), nil + sniConfigFromFlags := appctype.SNIProxyConfig{ + Addrs: []netip.Addr{ip4, ip6}, } - - dial := &tcpproxy.DialProxy{ - Addr: fmt.Sprintf("%s:%d", forw.Destination, forw.Port), - DialContext: dialer.DialContext, + if ports != "" { + for _, portStr := range strings.Split(ports, ",") { + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + log.Fatalf("invalid port: %s", portStr) + } + sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, + }) + } } - p.AddRoute(addrPortStr, dial) - s.numTCPsessions.Add(portNumberToName(forw), 1) - p.Start() -} + var forwardConfigFromFlags []appctype.DNATConfig + for _, forwStr := range strings.Split(forwards, ",") { + if forwStr == "" { + continue + } + forw, err := parseForward(forwStr) + if err != nil { + log.Printf("invalid forwarding spec: %v", err) + continue + } -func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) { - resp := dnsmessage.NewBuilder(buf, - dnsmessage.Header{ - ID: req.Header.ID, - Response: true, - Authoritative: true, + forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ + Addrs: []netip.Addr{ip4, ip6}, + To: []string{forw.Destination}, + IP: []tailcfg.ProtoPortRange{ + { + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, + }, + }, }) - resp.EnableCompression() - - if len(req.Questions) == 0 { - buf, _ = resp.Finish() - return } - q := req.Questions[0] - err = resp.StartQuestions() - if err != nil { - return + if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { + return // no config specified on the command line } - resp.Question(q) - ip4, ip6 := s.ts.TailscaleIPs() - err = resp.StartAnswers() - if err != nil { - return - } - - switch q.Type { - case dnsmessage.TypeAAAA: - err = resp.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AAAAResource{AAAA: ip6.As16()}, - ) - - case dnsmessage.TypeA: - err = resp.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AResource{A: ip4.As4()}, - ) - case dnsmessage.TypeSOA: - err = resp.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ) - case dnsmessage.TypeNS: - err = resp.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ) + mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) + for i, forward := range forwardConfigFromFlags { + mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) } +} - if err != nil { - return +func (s *server) serveDNS(ln net.Listener) { + for { + c, err := ln.Accept() + if err != nil { + log.Printf("serveDNS accept: %v", err) + return + } + go s.appc.HandleDNS(c.(nettype.ConnPacketConn)) } - - return resp.Finish() } func (s *server) promoteHTTPS(ln net.Listener) { err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - s.httpPromoted.Add(1) http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) })) log.Fatalf("promoteHTTPS http.Serve: %v", err) } - -// initMetrics sets up local prometheus metrics, and creates clientmetrics to report those -// same counters. -func (s *server) initMetrics() { - stats := new(metrics.Set) - - stats.Set("tls_sessions", &s.numTLSsessions) - clientmetric.NewCounterFunc("sniproxy_tls_sessions", s.numTLSsessions.Value) - - s.numTCPsessions = &metrics.LabelMap{Label: "proto"} - stats.Set("tcp_sessions", s.numTCPsessions) - // clientmetric doesn't have a good way to implement a Map type. - // We create clientmetrics dynamically when parsing the --forwards argument - - stats.Set("bad_addrport", &s.numBadAddrPort) - clientmetric.NewCounterFunc("sniproxy_bad_addrport", s.numBadAddrPort.Value) - - stats.Set("dns_responses", &s.dnsResponses) - clientmetric.NewCounterFunc("sniproxy_dns_responses", s.dnsResponses.Value) - - stats.Set("dns_failed", &s.dnsFailures) - clientmetric.NewCounterFunc("sniproxy_dns_failed", s.dnsFailures.Value) - - stats.Set("http_promoted", &s.httpPromoted) - clientmetric.NewCounterFunc("sniproxy_http_promoted", s.httpPromoted.Value) - - expvar.Publish("sniproxy", stats) -} diff --git a/cmd/sniproxy/sniproxy_test.go b/cmd/sniproxy/sniproxy_test.go index 15cc2ec21..1e9396cf1 100644 --- a/cmd/sniproxy/sniproxy_test.go +++ b/cmd/sniproxy/sniproxy_test.go @@ -4,10 +4,30 @@ package main import ( + "context" + "encoding/json" + "flag" + "fmt" + "net" + "net/http/httptest" + "net/netip" + "os" + "path/filepath" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" + "tailscale.com/ipn/store/mem" + "tailscale.com/net/netns" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" + "tailscale.com/types/key" + "tailscale.com/types/logger" ) func TestPortForwardingArguments(t *testing.T) { @@ -35,3 +55,169 @@ func TestPortForwardingArguments(t *testing.T) { } } } + +var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs") +var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs") + +func startControl(t *testing.T) (control *testcontrol.Server, controlURL string) { + // Corp#4520: don't use netns for tests. + netns.SetEnabled(false) + t.Cleanup(func() { + netns.SetEnabled(true) + }) + + derpLogf := logger.Discard + if *verboseDERP { + derpLogf = t.Logf + } + derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1") + control = &testcontrol.Server{ + DERPMap: derpMap, + DNSConfig: &tailcfg.DNSConfig{ + Proxied: true, + }, + MagicDNSDomain: "tail-scale.ts.net", + } + control.HTTPTestServer = httptest.NewUnstartedServer(control) + control.HTTPTestServer.Start() + t.Cleanup(control.HTTPTestServer.Close) + controlURL = control.HTTPTestServer.URL + t.Logf("testcontrol listening on %s", controlURL) + return control, controlURL +} + +func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (*tsnet.Server, key.NodePublic, netip.Addr) { + t.Helper() + + tmp := filepath.Join(t.TempDir(), hostname) + os.MkdirAll(tmp, 0755) + s := &tsnet.Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: hostname, + Store: new(mem.Store), + Ephemeral: true, + } + if !*verboseNodes { + s.Logf = logger.Discard + } + t.Cleanup(func() { s.Close() }) + + status, err := s.Up(ctx) + if err != nil { + t.Fatal(err) + } + return s, status.Self.PublicKey, status.TailscaleIPs[0] +} + +func TestSNIProxyWithNetmapConfig(t *testing.T) { + c, controlURL := startControl(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create a listener to proxy connections to. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + // Start sniproxy + sni, nodeKey, ip := startNode(t, ctx, controlURL, "snitest") + go run(ctx, sni, 0, sni.Hostname, false, 0, "", "") + + // Configure the mock coordination server to send down app connector config. + config := &appctype.AppConnectorConfig{ + DNAT: map[appctype.ConfigID]appctype.DNATConfig{ + "nic_test": { + Addrs: []netip.Addr{ip}, + To: []string{"127.0.0.1"}, + IP: []tailcfg.ProtoPortRange{ + { + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(ln.Addr().(*net.TCPAddr).Port), Last: uint16(ln.Addr().(*net.TCPAddr).Port)}, + }, + }, + }, + }, + } + b, err := json.Marshal(config) + if err != nil { + t.Fatal(err) + } + c.SetNodeCapMap(nodeKey, tailcfg.NodeCapMap{ + configCapKey: []tailcfg.RawMessage{tailcfg.RawMessage(b)}, + }) + + // Lets spin up a second node (to represent the client). + client, _, _ := startNode(t, ctx, controlURL, "client") + + // Make sure that the sni node has received its config. + l, err := sni.LocalClient() + if err != nil { + t.Fatal(err) + } + gotConfigured := false + for i := 0; i < 100; i++ { + s, err := l.StatusWithoutPeers(ctx) + if err != nil { + t.Fatal(err) + } + if len(s.Self.CapMap) > 0 { + gotConfigured = true + break // we got it + } + time.Sleep(10 * time.Millisecond) + } + if !gotConfigured { + t.Error("sni node never received its configuration from the coordination server!") + } + + // Lets make the client open a connection to the sniproxy node, and + // make sure it results in a connection to our test listener. + w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + r, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + r.Close() +} + +func TestSNIProxyWithFlagConfig(t *testing.T) { + _, controlURL := startControl(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create a listener to proxy connections to. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + // Start sniproxy + sni, _, ip := startNode(t, ctx, controlURL, "snitest") + go run(ctx, sni, 0, sni.Hostname, false, 0, "", fmt.Sprintf("tcp/%d/localhost", ln.Addr().(*net.TCPAddr).Port)) + + // Lets spin up a second node (to represent the client). + client, _, _ := startNode(t, ctx, controlURL, "client") + + // Lets make the client open a connection to the sniproxy node, and + // make sure it results in a connection to our test listener. + w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + r, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + r.Close() +} diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 6998956e7..8a3e3604d 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -75,6 +75,9 @@ type Server struct { // masquerade address to use for that peer. masquerades map[key.NodePublic]map[key.NodePublic]netip.Addr // node => peer => SelfNodeV{4,6}MasqAddrForThisPeer IP + // nodeCapMaps overrides the capability map sent down to a client. + nodeCapMaps map[key.NodePublic]tailcfg.NodeCapMap + // suppressAutoMapResponses is the set of nodes that should not be sent // automatic map responses from serveMap. (They should only get manually sent ones) suppressAutoMapResponses set.Set[key.NodePublic] @@ -369,6 +372,14 @@ func (s *Server) SetMasqueradeAddresses(pairs []MasqueradePair) { s.updateLocked("SetMasqueradeAddresses", s.nodeIDsLocked(0)) } +// SetNodeCapMap overrides the capability map the specified client receives. +func (s *Server) SetNodeCapMap(nodeKey key.NodePublic, capMap tailcfg.NodeCapMap) { + s.mu.Lock() + defer s.mu.Unlock() + mak.Set(&s.nodeCapMaps, nodeKey, capMap) + s.updateLocked("SetNodeCapMap", s.nodeIDsLocked(0)) +} + // nodeIDsLocked returns the node IDs of all nodes in the server, except // for the node with the given ID. func (s *Server) nodeIDsLocked(except tailcfg.NodeID) []tailcfg.NodeID { @@ -881,6 +892,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, // node key rotated away (once test server supports that) return nil, nil } + node.CapMap = s.nodeCapMaps[nk] node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP) user, _ := s.getUser(nk)