diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index 8804fcb5c..b389c93e7 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -107,6 +107,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK // If a cert is expired, or expires sooner than minValidity, it will be renewed // synchronously. Otherwise it will be renewed asynchronously. func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string, minValidity time.Duration) (*TLSCertKeyPair, error) { + b.mu.Lock() + getCertForTest := b.getCertForTest + b.mu.Unlock() + + if getCertForTest != nil { + testenv.AssertInTest() + return getCertForTest(domain) + } + if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } @@ -303,6 +312,16 @@ func (b *LocalBackend) getCertStore() (certStore, error) { return certFileStore{dir: dir, testRoots: testX509Roots}, nil } +// ConfigureCertsForTest sets a certificate retrieval function to be used by +// this local backend, skipping the usual ACME certificate registration. Should +// only be used in tests. +func (b *LocalBackend) ConfigureCertsForTest(getCert func(hostname string) (*TLSCertKeyPair, error)) { + testenv.AssertInTest() + b.mu.Lock() + b.getCertForTest = getCert + b.mu.Unlock() +} + // certFileStore implements certStore by storing the cert & key files in the named directory. type certFileStore struct { dir string diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index cebb96130..147be08eb 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -399,6 +399,10 @@ type LocalBackend struct { // hardwareAttested is whether backend should use a hardware-backed key to // bind the node identity to this device. hardwareAttested atomic.Bool + + // getCertForTest is used to retrieve TLS certificates in tests. + // See [LocalBackend.ConfigureCertsForTest]. + getCertForTest func(hostname string) (*TLSCertKeyPair, error) } // SetHardwareAttested enables hardware attestation key signatures in map diff --git a/tsnet/example/tsnet-services/tsnet-services.go b/tsnet/example/tsnet-services/tsnet-services.go new file mode 100644 index 000000000..9aaa2ebba --- /dev/null +++ b/tsnet/example/tsnet-services/tsnet-services.go @@ -0,0 +1,82 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tsnet-services example demonstrates how to use tsnet with Services. +// +// To run this example yourself: +// +// 1. Add access controls which (i) define a new ACL tag, (ii) allow the demo +// node to host the Service, and (iii) allow peers on the tailnet to reach +// the Service. A sample ACL policy is provided below. +// +// 2. Generate an auth key using the Tailscale admin panel. When doing so, add +// your new tag to your key (Service hosts must be tagged nodes). +// +// https://tailscale.com/kb/1085/auth-keys#generate-an-auth-key +// +// 3. Define a Service. For the purposes of this demo, it must be defined to +// listen on TCP port 443. Note that you only need to follow Step 1 in the +// following document. +// +// https://tailscale.com/kb/1552/tailscale-services#step-1-define-a-tailscale-service +// +// 4. Run the demo on the command line: +// TS_AUTHKEY= go run tsnet-services.go -service +// +// Sample ACL policy for step 1: +// +// "tagOwners": { +// "tag:tsnet-demo-host": ["autogroup:member"], +// }, +// "autoApprovers": { +// "services": { +// "svc:tsnet-demo": ["tag:tsnet-demo-host"], +// }, +// }, +// "grants": [ +// "src": ["*"], +// "dst": ["tag:tsnet-demo-host"], +// "ip": ["*"], +// ], +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + + "tailscale.com/tsnet" +) + +var ( + svcName = flag.String("service", "", "the name of your Service, e.g. svc:demo-service") +) + +func main() { + flag.Parse() + if *svcName == "" { + log.Fatal("a Service name must be provided") + } + + s := &tsnet.Server{ + Hostname: "tsnet-services-demo", + } + defer s.Close() + + ln, err := s.ListenService(*svcName, tsnet.ServiceModeHTTP{ + HTTPS: true, + Port: 443, + }) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + log.Printf("Listening on https://%v\n", ln.FQDN) + + err = http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "

Hello, tailnet!

") + })) + log.Fatal(err) +} diff --git a/tsnet/example_tsnet_listen_service_multiple_ports_test.go b/tsnet/example_tsnet_listen_service_multiple_ports_test.go new file mode 100644 index 000000000..2b8f01526 --- /dev/null +++ b/tsnet/example_tsnet_listen_service_multiple_ports_test.go @@ -0,0 +1,66 @@ +package tsnet_test + +import ( + "fmt" + "log" + "net/http" + _ "net/http/pprof" + "strings" + + "tailscale.com/tsnet" +) + +// This example function is in a separate file for the "net/http/pprof" import. + +// ExampleServer_ListenService_multiplePorts demonstrates how to advertise a +// Service on multiple ports. In this example, we run an HTTPS server on 443 and +// an HTTP server handling pprof requests to the same runtime on 6060. +func ExampleServer_ListenService_multiplePorts() { + s := &tsnet.Server{ + Hostname: "tsnet-services-demo", + } + defer s.Close() + + ln, err := s.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + HTTPS: true, + Port: 443, + }) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + pprofLn, err := s.ListenService("svc:my-service", tsnet.ServiceModeTCP{ + Port: 6060, + }) + if err != nil { + log.Fatal(err) + } + defer pprofLn.Close() + + go func() { + log.Printf("Listening for pprof requests on http://%v:%d\n", pprofLn.FQDN, 6060) + + handler := func(w http.ResponseWriter, r *http.Request) { + // The pprof listener is separate from our main server, so we can + // allow users to leave off the /debug/pprof prefix. We'll just + // attach it here, then pass along to the pprof handlers, which have + // been added implicitly due to our import of net/http/pprof. + if !strings.HasPrefix("/debug/pprof", r.URL.Path) { + r.URL.Path = "/debug/pprof" + r.URL.Path + } + http.DefaultServeMux.ServeHTTP(w, r) + } + if err := http.Serve(pprofLn, http.HandlerFunc(handler)); err != nil { + log.Fatal("error serving pprof:", err) + } + }() + + log.Printf("Listening on https://%v\n", ln.FQDN) + + // Specifying a handler here means pprof endpoints will not be served by + // this server (since we are not using http.DefaultServeMux). + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "

Hello, tailnet!

") + }))) +} diff --git a/tsnet/example_tsnet_test.go b/tsnet/example_tsnet_test.go index c5a20ab77..2a3236b3b 100644 --- a/tsnet/example_tsnet_test.go +++ b/tsnet/example_tsnet_test.go @@ -8,6 +8,8 @@ import ( "fmt" "log" "net/http" + "net/http/httputil" + "net/url" "os" "path/filepath" @@ -200,3 +202,56 @@ func ExampleServer_ListenFunnel_funnelOnly() { fmt.Fprintln(w, "Hi there! Welcome to the tailnet!") }))) } + +// ExampleServer_ListenService demonstrates how to advertise an HTTPS Service. +func ExampleServer_ListenService() { + s := &tsnet.Server{ + Hostname: "tsnet-services-demo", + } + defer s.Close() + + ln, err := s.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + HTTPS: true, + Port: 443, + }) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + log.Printf("Listening on https://%v\n", ln.FQDN) + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "

Hello, tailnet!

") + }))) +} + +// ExampleServer_ListenService_reverseProxy demonstrates how to advertise a +// Service targeting a reverse proxy. This is useful when the backing server is +// external to the tsnet application. +func ExampleServer_ListenService_reverseProxy() { + // targetAddress represents the address of the backing server. + const targetAddress = "1.2.3.4:80" + + // We will use a reverse proxy to direct traffic to the backing server. + reverseProxy := httputil.NewSingleHostReverseProxy(&url.URL{ + Scheme: "http", + Host: targetAddress, + }) + + s := &tsnet.Server{ + Hostname: "tsnet-services-demo", + } + defer s.Close() + + ln, err := s.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + HTTPS: true, + Port: 443, + }) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + log.Printf("Listening on https://%v\n", ln.FQDN) + log.Fatal(http.Serve(ln, reverseProxy)) +} diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 9efad32b3..fe353b58c 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -52,6 +52,7 @@ import ( "tailscale.com/net/proxymux" "tailscale.com/net/socks5" "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/types/bools" "tailscale.com/types/logger" @@ -158,8 +159,6 @@ type Server struct { // that the control server will allow the node to adopt that tag. AdvertiseTags []string - getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error) - initOnce sync.Once initErr error lb *ipnlocal.LocalBackend @@ -1106,9 +1105,6 @@ func (s *Server) RegisterFallbackTCPHandler(cb FallbackTCPHandler) func() { // It calls GetCertificate on the localClient, passing in the ClientHelloInfo. // For testing, if s.getCertForTesting is set, it will call that instead. func (s *Server) getCert(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - if s.getCertForTesting != nil { - return s.getCertForTesting(hi) - } lc, err := s.LocalClient() if err != nil { return nil, err @@ -1259,6 +1255,266 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L return tls.NewListener(ln, tlsConfig), nil } +// ServiceMode defines how a Service is run. Currently supported modes are: +// - [ServiceModeTCP] +// - [ServiceModeHTTP] +// +// For more information, see [Server.ListenService]. +type ServiceMode interface { + // serviceMode is a no-op used to identify a type as a ServiceMode. + serviceMode() + + // network is the network this Service will advertise on. Per Go convention, + // this should be lowercase, e.g. 'tcp'. + network() string +} + +// serviceModeWithPort is a convenience type to extract the port from +// ServiceMode types which have one. +type serviceModeWithPort interface { + ServiceMode + port() uint16 +} + +// ServiceModeTCP is used to configure a TCP Service via [Server.ListenService]. +type ServiceModeTCP struct { + // Port is the TCP port to advertise. If this Service needs to advertise + // multiple ports, call ListenService multiple times. + Port uint16 + + // TerminateTLS means that TLS connections will be terminated before being + // forwarded to the listener. In this case, the only server name indicator + // (SNI) permitted is the Service's fully-qualified domain name. + TerminateTLS bool + + // PROXYProtocolVersion indicates whether to send a PROXY protocol header + // before forwarding the connection to the listener and which version of the + // protocol to use. + // + // For more information, see + // https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + PROXYProtocolVersion int +} + +func (ServiceModeTCP) serviceMode() {} + +func (ServiceModeTCP) network() string { return "tcp" } + +func (m ServiceModeTCP) port() uint16 { return m.Port } + +// ServiceModeHTTP is used to configure an HTTP Service via +// [Server.ListenService]. +type ServiceModeHTTP struct { + // Port is the TCP port to advertise. If this Service needs to advertise + // multiple ports, call ListenService multiple times. + Port uint16 + + // HTTPS, if true, means that the listener should handle connections as + // HTTPS connections. In this case, the only server name indicator (SNI) + // permitted is the Service's fully-qualified domain name. + HTTPS bool + + // AcceptAppCaps defines the app capabilities to forward to the server. The + // keys in this map are the mount points for each set of capabilities. + // + // By example, + // + // AcceptAppCaps: map[string][]string{ + // "/": {"example.com/cap/all-paths"}, + // "/foo": {"example.com/cap/all-paths", "example.com/cap/foo"}, + // } + // + // would forward `example.com/cap/all-paths` to all paths on the server and + // `example.com/cap/foo` only to paths beginning with /foo. + // + // For more information on app capabilities, see + // https://tailscale.com/kb/1537/grants-app-capabilities + AcceptAppCaps map[string][]string + + // PROXYProtocolVersion indicates whether to send a PROXY protocol header + // before forwarding the connection to the listener and which version of the + // protocol to use. + // + // For more information, see + // https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + PROXYProtocol int +} + +func (ServiceModeHTTP) serviceMode() {} + +func (ServiceModeHTTP) network() string { return "tcp" } + +func (m ServiceModeHTTP) port() uint16 { return m.Port } + +func (m ServiceModeHTTP) capsMap() map[string][]tailcfg.PeerCapability { + capsMap := map[string][]tailcfg.PeerCapability{} + for path, capNames := range m.AcceptAppCaps { + caps := make([]tailcfg.PeerCapability, 0, len(capNames)) + for _, c := range capNames { + caps = append(caps, tailcfg.PeerCapability(c)) + } + capsMap[path] = caps + } + return capsMap +} + +// A ServiceListener is a network listener for a Tailscale Service. For more +// information about Services, see +// https://tailscale.com/kb/1552/tailscale-services +type ServiceListener struct { + net.Listener + addr addr + + // FQDN is the fully-qualifed domain name of this Service. + FQDN string +} + +// Addr returns the listener's network address. This will be the Service's +// fully-qualified domain name (FQDN) and the port. +// +// A hostname is not truly a network address, but Services listen on multiple +// addresses (the IPv4 and IPv6 virtual IPs). +func (sl ServiceListener) Addr() net.Addr { + return sl.addr +} + +// ErrUntaggedServiceHost is returned by ListenService when run on a node +// without any ACL tags. A node must use a tag-based identity to act as a +// Service host. For more information, see: +// https://tailscale.com/kb/1552/tailscale-services#prerequisites +var ErrUntaggedServiceHost = errors.New("service hosts must be tagged nodes") + +// ListenService creates a network listener for a Tailscale Service. This will +// advertise this node as hosting the Service. Note that: +// - Approval must still be granted by an admin or by ACL auto-approval rules. +// - Service hosts must be tagged nodes. +// - A valid Service host must advertise all ports defined for the Service. +// +// To advertise a Service with multiple ports, run ListenService multiple times. +// For more information about Services, see +// https://tailscale.com/kb/1552/tailscale-services +func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, error) { + if err := tailcfg.ServiceName(name).Validate(); err != nil { + return nil, err + } + if mode == nil { + return nil, errors.New("mode may not be nil") + } + svcName := name + + // TODO(hwh33,tailscale/corp#35859): support TUN mode + + ctx := context.Background() + _, err := s.Up(ctx) + if err != nil { + return nil, err + } + + st := s.lb.StatusWithoutPeers() + if st.Self.Tags == nil || st.Self.Tags.Len() == 0 { + return nil, ErrUntaggedServiceHost + } + + advertisedServices := s.lb.Prefs().AdvertiseServices().AsSlice() + if !slices.Contains(advertisedServices, svcName) { + // TODO(hwh33,tailscale/corp#35860): clean these prefs up when (a) we + // exit early due to error or (b) when the returned listener is closed. + _, err = s.lb.EditPrefs(&ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: append(advertisedServices, svcName), + }, + }) + if err != nil { + return nil, fmt.Errorf("updating advertised Services: %w", err) + } + } + + srvConfig := new(ipn.ServeConfig) + sc, srvConfigETag, err := s.lb.ServeConfigETag() + if err != nil { + return nil, fmt.Errorf("fetching current serve config: %w", err) + } + if sc.Valid() { + srvConfig = sc.AsStruct() + } + + fqdn := tailcfg.ServiceName(svcName).WithoutPrefix() + "." + st.CurrentTailnet.MagicDNSSuffix + + // svcAddr is used to implement Addr() on the returned listener. + svcAddr := addr{ + network: mode.network(), + // A hostname is not a network address, but Services listen on + // multiple addresses (the IPv4 and IPv6 virtual IPs), and there's + // no clear winner here between the two. Therefore prefer the FQDN. + // + // In the case of TCP or HTTP Services, the port will be added below. + addr: fqdn, + } + if m, ok := mode.(serviceModeWithPort); ok { + if m.port() == 0 { + return nil, errors.New("must specify a port to advertise") + } + svcAddr.addr += ":" + strconv.Itoa(int(m.port())) + } + + // Start listening on a local TCP socket. + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + return nil, fmt.Errorf("starting local listener: %w", err) + } + + switch m := mode.(type) { + case ServiceModeTCP: + // Forward all connections from service-hostname:port to our socket. + srvConfig.SetTCPForwardingForService( + m.Port, ln.Addr().String(), m.TerminateTLS, + tailcfg.ServiceName(svcName), m.PROXYProtocolVersion, st.CurrentTailnet.MagicDNSSuffix) + case ServiceModeHTTP: + // For HTTP Services, proxy all connections to our socket. + mds := st.CurrentTailnet.MagicDNSSuffix + haveRootHandler := false + // We need to add a separate proxy for each mount point in the caps map. + for path, caps := range m.capsMap() { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + h := ipn.HTTPHandler{ + AcceptAppCaps: caps, + Proxy: ln.Addr().String(), + } + if path == "/" { + haveRootHandler = true + } else { + h.Proxy += path + } + srvConfig.SetWebHandler(&h, svcName, m.Port, path, m.HTTPS, mds) + } + // We always need a root handler. + if !haveRootHandler { + h := ipn.HTTPHandler{Proxy: ln.Addr().String()} + srvConfig.SetWebHandler(&h, svcName, m.Port, "/", m.HTTPS, mds) + } + default: + ln.Close() + return nil, fmt.Errorf("unknown ServiceMode type %T", m) + } + + if err := s.lb.SetServeConfig(srvConfig, srvConfigETag); err != nil { + ln.Close() + return nil, err + } + + // TODO(hwh33,tailscale/corp#35860): clean up state (advertising prefs, + // serve config changes) when the returned listener is closed. + + return &ServiceListener{ + Listener: ln, + FQDN: fqdn, + addr: svcAddr, + }, nil +} + type listenOn string const ( @@ -1420,7 +1676,12 @@ func (ln *listener) Accept() (net.Conn, error) { } } -func (ln *listener) Addr() net.Addr { return addr{ln} } +func (ln *listener) Addr() net.Addr { + return addr{ + network: ln.keys[0].network, + addr: ln.addr, + } +} func (ln *listener) Close() error { ln.s.mu.Lock() @@ -1460,10 +1721,12 @@ func (ln *listener) handle(c net.Conn) { // Server returns the tsnet Server associated with the listener. func (ln *listener) Server() *Server { return ln.s } -type addr struct{ ln *listener } +type addr struct { + network, addr string +} -func (a addr) Network() string { return a.ln.keys[0].network } -func (a addr) String() string { return a.ln.addr } +func (a addr) Network() string { return a.network } +func (a addr) String() string { return a.addr } // cleanupListener wraps a net.Listener with a function to be run on Close. type cleanupListener struct { diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index af8fa765d..93685993b 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -14,6 +14,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/json" + "encoding/pem" "errors" "flag" "fmt" @@ -28,6 +29,7 @@ import ( "path/filepath" "reflect" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -38,10 +40,12 @@ import ( dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" "golang.org/x/net/proxy" + "tailscale.com/client/local" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" "tailscale.com/tailcfg" @@ -51,6 +55,8 @@ import ( "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/views" + "tailscale.com/util/mak" "tailscale.com/util/must" ) @@ -136,7 +142,7 @@ func startControl(t *testing.T) (controlURL string, control *testcontrol.Server) type testCertIssuer struct { mu sync.Mutex - certs map[string]*tls.Certificate + certs map[string]ipnlocal.TLSCertKeyPair // keyed by hostname root *x509.Certificate rootKey *ecdsa.PrivateKey @@ -168,18 +174,18 @@ func newCertIssuer() *testCertIssuer { panic(err) } return &testCertIssuer{ - certs: make(map[string]*tls.Certificate), root: rootCA, rootKey: rootKey, + certs: map[string]ipnlocal.TLSCertKeyPair{}, } } -func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (tci *testCertIssuer) getCert(hostname string) (*ipnlocal.TLSCertKeyPair, error) { tci.mu.Lock() defer tci.mu.Unlock() - cert, ok := tci.certs[chi.ServerName] + cert, ok := tci.certs[hostname] if ok { - return cert, nil + return &cert, nil } certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -188,7 +194,7 @@ func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, } certTmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), - DNSNames: []string{chi.ServerName}, + DNSNames: []string{hostname}, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour), } @@ -196,12 +202,22 @@ func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, if err != nil { return nil, err } - cert = &tls.Certificate{ - Certificate: [][]byte{certDER, tci.root.Raw}, - PrivateKey: certPrivKey, + keyDER, err := x509.MarshalPKCS8PrivateKey(certPrivKey) + if err != nil { + return nil, err } - tci.certs[chi.ServerName] = cert - return cert, nil + cert = ipnlocal.TLSCertKeyPair{ + CertPEM: pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }), + KeyPEM: pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: keyDER, + }), + } + tci.certs[hostname] = cert + return &cert, nil } func (tci *testCertIssuer) Pool() *x509.CertPool { @@ -218,12 +234,11 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) tmp := filepath.Join(t.TempDir(), hostname) os.MkdirAll(tmp, 0755) s := &Server{ - Dir: tmp, - ControlURL: controlURL, - Hostname: hostname, - Store: new(mem.Store), - Ephemeral: true, - getCertForTesting: testCertRoot.getCert, + Dir: tmp, + ControlURL: controlURL, + Hostname: hostname, + Store: new(mem.Store), + Ephemeral: true, } if *verboseNodes { s.Logf = t.Logf @@ -234,6 +249,8 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) if err != nil { t.Fatal(err) } + s.lb.ConfigureCertsForTest(testCertRoot.getCert) + return s, status.TailscaleIPs[0], status.Self.PublicKey } @@ -259,12 +276,11 @@ func TestDialBlocks(t *testing.T) { tmp := filepath.Join(t.TempDir(), "s2") os.MkdirAll(tmp, 0755) s2 := &Server{ - Dir: tmp, - ControlURL: controlURL, - Hostname: "s2", - Store: new(mem.Store), - Ephemeral: true, - getCertForTesting: testCertRoot.getCert, + Dir: tmp, + ControlURL: controlURL, + Hostname: "s2", + Store: new(mem.Store), + Ephemeral: true, } if *verboseNodes { s2.Logf = log.Printf @@ -842,6 +858,366 @@ func TestFunnelClose(t *testing.T) { }) } +func TestListenService(t *testing.T) { + // First test an error case which doesn't require all of the fancy setup. + t.Run("untagged_node_error", func(t *testing.T) { + ctx := t.Context() + + controlURL, _ := startControl(t) + serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") + + ln, err := serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080}) + if ln != nil { + ln.Close() + } + if !errors.Is(err, ErrUntaggedServiceHost) { + t.Fatalf("expected %v, got %v", ErrUntaggedServiceHost, err) + } + }) + + // Now on to the fancier tests. + + type dialFn func(context.Context, string, string) (net.Conn, error) + + // TCP helpers + acceptAndEcho := func(t *testing.T, ln net.Listener) { + t.Helper() + conn, err := ln.Accept() + if err != nil { + t.Error("accept error:", err) + return + } + defer conn.Close() + if _, err := io.Copy(conn, conn); err != nil { + t.Error("copy error:", err) + } + } + assertEcho := func(t *testing.T, conn net.Conn) { + t.Helper() + msg := "echo" + buf := make([]byte, 1024) + if _, err := conn.Write([]byte(msg)); err != nil { + t.Fatal("write failed:", err) + } + n, err := conn.Read(buf) + if err != nil { + t.Fatal("read failed:", err) + } + got := string(buf[:n]) + if got != msg { + t.Fatalf("unexpected response:\n\twant: %s\n\tgot: %s", msg, got) + } + } + + // HTTP helpers + checkAndEcho := func(t *testing.T, ln net.Listener, check func(r *http.Request)) { + t.Helper() + if check == nil { + check = func(*http.Request) {} + } + http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + check(r) + if _, err := io.Copy(w, r.Body); err != nil { + t.Error("copy error:", err) + w.WriteHeader(http.StatusInternalServerError) + } + })) + } + assertEchoHTTP := func(t *testing.T, hostname, path string, dial dialFn) { + t.Helper() + c := http.Client{ + Transport: &http.Transport{ + DialContext: dial, + }, + } + msg := "echo" + resp, err := c.Post("http://"+hostname+path, "text/plain", strings.NewReader(msg)) + if err != nil { + t.Fatal("posting request:", err) + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("reading body:", err) + } + got := string(b) + if got != msg { + t.Fatalf("unexpected response:\n\twant: %s\n\tgot: %s", msg, got) + } + } + + tests := []struct { + name string + + // modes is used as input to [Server.ListenService]. + // + // If this slice has multiple modes, then ListenService will be invoked + // multiple times. The number of listeners provided to the run function + // (below) will always match the number of elements in this slice. + modes []ServiceMode + + extraSetup func(t *testing.T, control *testcontrol.Server) + + // run executes the test. This function does not need to close any of + // the input resources, but it should close any new resources it opens. + // listeners[i] corresponds to inputs[i]. + run func(t *testing.T, listeners []*ServiceListener, peer *Server) + }{ + { + name: "basic_TCP", + modes: []ServiceMode{ + ServiceModeTCP{Port: 99}, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + go acceptAndEcho(t, listeners[0]) + + target := fmt.Sprintf("%s:%d", listeners[0].FQDN, 99) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + + assertEcho(t, conn) + }, + }, + { + name: "TLS_terminated_TCP", + modes: []ServiceMode{ + ServiceModeTCP{ + TerminateTLS: true, + Port: 443, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + go acceptAndEcho(t, listeners[0]) + + target := fmt.Sprintf("%s:%d", listeners[0].FQDN, 443) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + + assertEcho(t, tls.Client(conn, &tls.Config{ + ServerName: listeners[0].FQDN, + RootCAs: testCertRoot.Pool(), + })) + }, + }, + { + name: "identity_headers", + modes: []ServiceMode{ + ServiceModeHTTP{ + Port: 80, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + expectHeader := "Tailscale-User-Name" + go checkAndEcho(t, listeners[0], func(r *http.Request) { + if _, ok := r.Header[expectHeader]; !ok { + t.Error("did not see expected header:", expectHeader) + } + }) + assertEchoHTTP(t, listeners[0].FQDN, "", peer.Dial) + }, + }, + { + name: "identity_headers_TLS", + modes: []ServiceMode{ + ServiceModeHTTP{ + HTTPS: true, + Port: 80, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + expectHeader := "Tailscale-User-Name" + go checkAndEcho(t, listeners[0], func(r *http.Request) { + if _, ok := r.Header[expectHeader]; !ok { + t.Error("did not see expected header:", expectHeader) + } + }) + + dial := func(ctx context.Context, network, addr string) (net.Conn, error) { + tcpConn, err := peer.Dial(ctx, network, addr) + if err != nil { + return nil, err + } + return tls.Client(tcpConn, &tls.Config{ + ServerName: listeners[0].FQDN, + RootCAs: testCertRoot.Pool(), + }), nil + } + + assertEchoHTTP(t, listeners[0].FQDN, "", dial) + }, + }, + { + name: "app_capabilities", + modes: []ServiceMode{ + ServiceModeHTTP{ + Port: 80, + AcceptAppCaps: map[string][]string{ + "/": {"example.com/cap/all-paths"}, + "/foo": {"example.com/cap/all-paths", "example.com/cap/foo"}, + }, + }, + }, + extraSetup: func(t *testing.T, control *testcontrol.Server) { + control.SetGlobalAppCaps(tailcfg.PeerCapMap{ + "example.com/cap/all-paths": []tailcfg.RawMessage{`true`}, + "example.com/cap/foo": []tailcfg.RawMessage{`true`}, + }) + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + allPathsCap := "example.com/cap/all-paths" + fooCap := "example.com/cap/foo" + checkCaps := func(r *http.Request) { + rawCaps, ok := r.Header["Tailscale-App-Capabilities"] + if !ok { + t.Error("no app capabilities header") + return + } + if len(rawCaps) != 1 { + t.Error("expected one app capabilities header value, got", len(rawCaps)) + return + } + var caps map[string][]any + if err := json.Unmarshal([]byte(rawCaps[0]), &caps); err != nil { + t.Error("error unmarshaling app caps:", err) + return + } + if _, ok := caps[allPathsCap]; !ok { + t.Errorf("got app caps, but %v is not present; saw:\n%v", allPathsCap, caps) + } + if strings.HasPrefix(r.URL.Path, "/foo") { + if _, ok := caps[fooCap]; !ok { + t.Errorf("%v should be present for /foo request; saw:\n%v", fooCap, caps) + } + } else { + if _, ok := caps[fooCap]; ok { + t.Errorf("%v should not be present for non-/foo request; saw:\n%v", fooCap, caps) + } + } + } + + go checkAndEcho(t, listeners[0], checkCaps) + assertEchoHTTP(t, listeners[0].FQDN, "", peer.Dial) + assertEchoHTTP(t, listeners[0].FQDN, "/foo", peer.Dial) + assertEchoHTTP(t, listeners[0].FQDN, "/foo/bar", peer.Dial) + }, + }, + { + name: "multiple_ports", + modes: []ServiceMode{ + ServiceModeTCP{ + Port: 99, + }, + ServiceModeHTTP{ + Port: 80, + }, + }, + run: func(t *testing.T, listeners []*ServiceListener, peer *Server) { + go acceptAndEcho(t, listeners[0]) + + target := fmt.Sprintf("%s:%d", listeners[0].FQDN, 99) + conn := must.Get(peer.Dial(t.Context(), "tcp", target)) + defer conn.Close() + assertEcho(t, conn) + + go checkAndEcho(t, listeners[1], nil) + assertEchoHTTP(t, listeners[1].FQDN, "", peer.Dial) + }, + }, + } + + for _, tt := range tests { + // Overview: + // - start test control + // - start 2 tsnet nodes: + // one to act as Service host and a second to act as a peer client + // - configure necessary state on control mock + // - start a Service listener from the host + // - call tt.run with our test bed + // + // This ends up also testing the Service forwarding logic in + // LocalBackend, but that's useful too. + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + + controlURL, control := startControl(t) + serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") + serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") + + const serviceName = tailcfg.ServiceName("svc:foo") + const serviceVIP = "100.11.22.33" + + // == Set up necessary state in our mock == + + // The Service host must have the 'service-host' capability, which + // is a mapping from the Service name to the Service VIP. + var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] + mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) + j := must.Get(json.Marshal(serviceHostCaps)) + cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() + mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) + control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) + + // The Service host must be allowed to advertise the Service VIP. + control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ + netip.MustParsePrefix(serviceVIP + `/32`), + }) + + // The Service host must be a tagged node (any tag will do). + serviceHostNode := control.Node(serviceHost.lb.NodeKey()) + serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") + control.UpdateNode(serviceHostNode) + + // The service client must accept routes advertised by other nodes + // (RouteAll is equivalent to --accept-routes). + must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ + RouteAllSet: true, + Prefs: ipn.Prefs{ + RouteAll: true, + }, + })) + + // Set up DNS for our Service. + control.DNSConfig.ExtraRecords = append(control.DNSConfig.ExtraRecords, tailcfg.DNSRecord{ + Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, + Value: serviceVIP, + }) + + if tt.extraSetup != nil { + tt.extraSetup(t, control) + } + + // Force netmap updates to avoid race conditions. The nodes need to + // see our control updates before we can start the test. + must.Do(control.ForceNetmapUpdates()) + netmapUpToDate := func(s *Server) bool { + nm := s.lb.NetMap() + return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { + return r.Value == serviceVIP + }) + } + for !netmapUpToDate(serviceClient) { + time.Sleep(10 * time.Millisecond) + } + for !netmapUpToDate(serviceHost) { + time.Sleep(10 * time.Millisecond) + } + + // == Done setting up mock state == + + // Start the Service listeners. + listeners := make([]*ServiceListener, 0, len(tt.modes)) + for _, input := range tt.modes { + ln := must.Get(serviceHost.ListenService(serviceName.String(), input)) + defer ln.Close() + listeners = append(listeners, ln) + } + + tt.run(t, listeners, serviceClient) + }) + } +} + func TestListenerClose(t *testing.T) { tstest.Shard(t) ctx := context.Background() diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 19964c91f..7b5f5ffaf 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -110,6 +110,16 @@ type Server struct { // nodeCapMaps overrides the capability map sent down to a client. nodeCapMaps map[key.NodePublic]tailcfg.NodeCapMap + // globalAppCaps configures global app capabilities, equivalent to: + // "grants": [ + // { + // "src": ["*"], + // "dst": ["*"], + // "app": + // } + // ] + globalAppCaps tailcfg.PeerCapMap + // suppressAutoMapResponses is the set of nodes that should not be sent // automatic map responses from serveMap. (They should only get manually sent ones) suppressAutoMapResponses set.Set[key.NodePublic] @@ -289,6 +299,40 @@ func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { return sendUpdate(oldUpdatesCh, updateDebugInjection) } +// ForceNetmapUpdates issues updated netmaps to all connected nodes. It is an +// error for a node to disconnect while this function runs. The intended use +// case is ensuring that state changes propagate before running a test. This +// function cannot guarantee that nodes have processed the issued updates, so +// tests should confirm processing by querying the nodes. By example: +// +// if err := s.ForceNetmapUpdates(); err != nil { +// // handle error +// } +// for !expectedChangesPresent(node.NetMap()) { +// time.Sleep(10 * time.Millisecond) +// } +func (s *Server) ForceNetmapUpdates() error { + s.mu.Lock() + connectedNodes := map[key.NodePublic]*tailcfg.Node{} + for k, n := range s.nodes { + if _, ok := s.updates[n.ID]; ok { + connectedNodes[k] = n + } + } + s.mu.Unlock() + + for k, n := range connectedNodes { + mr, err := s.MapResponse(&tailcfg.MapRequest{NodeKey: k}) + if err != nil { + return fmt.Errorf("generating map response for %v (%v): %w", n.ID, n.Hostinfo.Hostname(), err) + } + if !s.addDebugMessage(k, mr) { + return fmt.Errorf("sending map response to %v (%v): update channel full or missing for node", n.ID, n.Hostinfo.Hostname()) + } + } + return nil +} + // Mark the Node key of every node as expired func (s *Server) SetExpireAllNodes(expired bool) { s.mu.Lock() @@ -531,6 +575,21 @@ func (s *Server) SetNodeCapMap(nodeKey key.NodePublic, capMap tailcfg.NodeCapMap s.updateLocked("SetNodeCapMap", s.nodeIDsLocked(0)) } +// SetGlobalAppCaps configures global app capabilities. This is equivalent to +// +// "grants": [ +// { +// "src": ["*"], +// "dst": ["*"], +// "app": +// } +// ] +func (s *Server) SetGlobalAppCaps(appCaps tailcfg.PeerCapMap) { + s.mu.Lock() + s.globalAppCaps = appCaps + s.mu.Unlock() +} + // nodeIDsLocked returns the node IDs of all nodes in the server, except // for the node with the given ID. func (s *Server) nodeIDsLocked(except tailcfg.NodeID) []tailcfg.NodeID { @@ -838,6 +897,9 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. CapMap: capMap, Capabilities: slices.Collect(maps.Keys(capMap)), } + if s.MagicDNSDomain != "" { + node.Name = node.Name + "." + s.MagicDNSDomain + "." + } s.nodes[nk] = node } requireAuth := s.RequireAuth @@ -1261,9 +1323,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, dns := s.DNSConfig if dns != nil && s.MagicDNSDomain != "" { dns = dns.Clone() - dns.CertDomains = []string{ - node.Hostinfo.Hostname() + "." + s.MagicDNSDomain, - } + dns.CertDomains = append(dns.CertDomains, node.Hostinfo.Hostname()+"."+s.MagicDNSDomain) } res = &tailcfg.MapResponse{ @@ -1279,6 +1339,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, s.mu.Lock() nodeMasqs := s.masquerades[node.Key] jailed := maps.Clone(s.peerIsJailed[node.Key]) + globalAppCaps := s.globalAppCaps s.mu.Unlock() for _, p := range s.AllNodes() { if p.StableID == node.StableID { @@ -1330,6 +1391,18 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, v6Prefix, } + if globalAppCaps != nil { + res.PacketFilter = append(res.PacketFilter, tailcfg.FilterRule{ + SrcIPs: []string{"*"}, + CapGrant: []tailcfg.CapGrant{ + { + Dsts: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + CapMap: globalAppCaps, + }, + }, + }) + } + // If the server is tracking TKA state, and there's a single TKA head, // add it to the MapResponse. if s.tkaStorage != nil {