diff --git a/appc/appc.go b/appc/appc.go new file mode 100644 index 000000000..2e1a53f74 --- /dev/null +++ b/appc/appc.go @@ -0,0 +1,328 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package appc implements App Connectors. +package appc + +import ( + "expvar" + "log" + "net" + "net/netip" + "sync" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/appctype" + "tailscale.com/metrics" + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/types/nettype" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" +) + +var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") + +// target describes the predicates which route some inbound +// traffic to the app connector to a specific handler. +type target struct { + Dest netip.Prefix + Matching tailcfg.ProtoPortRange +} + +// Server implements an App Connector. +type Server struct { + mu sync.RWMutex // mu guards following fields + connectors map[appctype.ConfigID]connector +} + +type appcMetrics struct { + dnsResponses expvar.Int + dnsFailures expvar.Int + tcpConns expvar.Int + sniConns expvar.Int + unhandledConns expvar.Int +} + +var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { + m := appcMetrics{} + + stats := new(metrics.Set) + stats.Set("tls_sessions", &m.sniConns) + clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) + stats.Set("tcp_sessions", &m.tcpConns) + clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) + stats.Set("dns_responses", &m.dnsResponses) + clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) + stats.Set("dns_failed", &m.dnsFailures) + clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) + expvar.Publish("sniproxy", stats) + + return &m +}) + +// Configure applies the provided configuration to the app connector. +func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { + s.mu.Lock() + defer s.mu.Unlock() + s.connectors = makeConnectorsFromConfig(cfg) +} + +// HandleTCPFlow implements tsnet.FallbackTCPHandler. +func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + m := getMetrics() + s.mu.RLock() + defer s.mu.RUnlock() + + for _, c := range s.connectors { + if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { + return handler, intercept + } + } + + return nil, false +} + +// HandleDNS handles a DNS request to the app connector. +func (s *Server) HandleDNS(c nettype.ConnPacketConn) { + defer c.Close() + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + m := getMetrics() + + buf := make([]byte, 1500) + n, err := c.Read(buf) + if err != nil { + log.Printf("HandleDNS: read failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + addrPortStr := c.LocalAddr().String() + host, _, err := net.SplitHostPort(addrPortStr) + if err != nil { + log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) + m.dnsFailures.Add(1) + return + } + localAddr, err := netip.ParseAddr(host) + if err != nil { + log.Printf("HandleDNS: bogus local address %q", host) + m.dnsFailures.Add(1) + return + } + + var msg dnsmessage.Message + err = msg.Unpack(buf[:n]) + if err != nil { + log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, connector := range s.connectors { + resp, err := connector.handleDNS(&msg, localAddr) + if err != nil { + log.Printf("HandleDNS: connector handling failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + if len(resp) > 0 { + // This connector handled the DNS request + _, err = c.Write(resp) + if err != nil { + log.Printf("HandleDNS: write failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + + m.dnsResponses.Add(1) + return + } + } +} + +// connector describes a logical collection of +// services which need to be proxied. +type connector struct { + Handlers map[target]handler +} + +// handleTCPFlow implements tsnet.FallbackTCPHandler. +func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { + for t, h := range c.Handlers { + if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { + continue + } + if !t.Dest.Contains(dst.Addr()) { + continue + } + if !t.Matching.Ports.Contains(dst.Port()) { + continue + } + + switch h.(type) { + case *tcpSNIHandler: + m.sniConns.Add(1) + case *tcpRoundRobinHandler: + m.tcpConns.Add(1) + default: + log.Printf("handleTCPFlow: unhandled handler type %T", h) + } + + return h.Handle, true + } + + m.unhandledConns.Add(1) + return nil, false +} + +// handleDNS returns the DNS response to the given query. If this +// connector is unable to handle the request, nil is returned. +func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { + for t, h := range c.Handlers { + if t.Dest.Contains(localAddr) { + return makeDNSResponse(req, h.ReachableOn()) + } + } + + // Did not match, signal 'not handled' to caller + return nil, nil +} + +func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { + buf := make([]byte, 1500) + resp := dnsmessage.NewBuilder(buf, + dnsmessage.Header{ + ID: req.Header.ID, + Response: true, + Authoritative: true, + }) + resp.EnableCompression() + + if len(req.Questions) == 0 { + buf, _ = resp.Finish() + return buf, nil + } + q := req.Questions[0] + err = resp.StartQuestions() + if err != nil { + return + } + resp.Question(q) + + err = resp.StartAnswers() + if err != nil { + return + } + + switch q.Type { + case dnsmessage.TypeAAAA: + for _, ip := range reachableIPs { + if ip.Is6() { + err = resp.AAAAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: ip.As16()}, + ) + } + } + + case dnsmessage.TypeA: + for _, ip := range reachableIPs { + if ip.Is4() { + err = resp.AResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: ip.As4()}, + ) + } + } + + case dnsmessage.TypeSOA: + err = resp.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ) + case dnsmessage.TypeNS: + err = resp.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ) + } + + if err != nil { + return nil, err + } + return resp.Finish() +} + +type handler interface { + // Handle handles the given socket. + Handle(c net.Conn) + + // ReachableOn returns the IP addresses this handler is reachable on. + ReachableOn() []netip.Addr +} + +func installDNATHandler(d *appctype.DNATConfig, out *connector) { + // These handlers don't actually do DNAT, they just + // proxy the data over the connection. + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpRoundRobinHandler{ + To: d.To, + DialContext: dialer.DialContext, + ReachableIPs: d.Addrs, + } + + for _, addr := range d.Addrs { + for _, protoPort := range d.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpSNIHandler{ + Allowlist: c.AllowedDomains, + DialContext: dialer.DialContext, + ReachableIPs: c.Addrs, + } + + for _, addr := range c.Addrs { + for _, protoPort := range c.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { + var connectors map[appctype.ConfigID]connector + + for cID, d := range cfg.DNAT { + c := connectors[cID] + installDNATHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + for cID, d := range cfg.SNIProxy { + c := connectors[cID] + installSNIHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + + return connectors +} diff --git a/appc/appc_test.go b/appc/appc_test.go new file mode 100644 index 000000000..d14a5bbf0 --- /dev/null +++ b/appc/appc_test.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appc + +import ( + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/appctype" + "tailscale.com/tailcfg" +) + +func TestMakeConnectorsFromConfig(t *testing.T) { + tcs := []struct { + name string + input *appctype.AppConnectorConfig + want map[appctype.ConfigID]connector + }{ + { + "empty", + &appctype.AppConnectorConfig{}, + nil, + }, + { + "DNAT", + &appctype.AppConnectorConfig{ + DNAT: map[appctype.ConfigID]appctype.DNATConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + To: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + { + "SNIProxy", + &appctype.AppConnectorConfig{ + SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + AllowedDomains: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + connectors := makeConnectorsFromConfig(tc.input) + + if diff := cmp.Diff(connectors, tc.want, + cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), + cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), + cmp.Comparer(func(x, y netip.Addr) bool { + return x == y + })); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/appc/handlers.go b/appc/handlers.go new file mode 100644 index 000000000..0d017309b --- /dev/null +++ b/appc/handlers.go @@ -0,0 +1,104 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appc + +import ( + "context" + "fmt" + "log" + "math/rand" + "net" + "net/netip" + "slices" + + "inet.af/tcpproxy" + "tailscale.com/net/netutil" +) + +type tcpRoundRobinHandler struct { + // To is a list of destination addresses to forward to. + // An entry may be either an IP address or a DNS name. + To []string + + // DialContext is used to make the outgoing TCP connection. + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + + // ReachableIPs enumerates the IP addresses this handler is reachable on. + ReachableIPs []netip.Addr +} + +// ReachableOn returns the IP addresses this handler is reachable on. +func (h *tcpRoundRobinHandler) ReachableOn() []netip.Addr { + return h.ReachableIPs +} + +func (h *tcpRoundRobinHandler) Handle(c net.Conn) { + addrPortStr := c.LocalAddr().String() + _, port, err := net.SplitHostPort(addrPortStr) + if err != nil { + log.Printf("tcpRoundRobinHandler.Handle: bogus addrPort %q", addrPortStr) + c.Close() + return + } + + var p tcpproxy.Proxy + p.ListenFunc = func(net, laddr string) (net.Listener, error) { + return netutil.NewOneConnListener(c, nil), nil + } + + dest := h.To[rand.Intn(len(h.To))] + dial := &tcpproxy.DialProxy{ + Addr: fmt.Sprintf("%s:%s", dest, port), + DialContext: h.DialContext, + } + + p.AddRoute(addrPortStr, dial) + p.Start() +} + +type tcpSNIHandler struct { + // Allowlist enumerates the FQDNs which may be proxied via SNI. An + // empty slice means all domains are permitted. + Allowlist []string + + // DialContext is used to make the outgoing TCP connection. + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + + // ReachableIPs enumerates the IP addresses this handler is reachable on. + ReachableIPs []netip.Addr +} + +// ReachableOn returns the IP addresses this handler is reachable on. +func (h *tcpSNIHandler) ReachableOn() []netip.Addr { + return h.ReachableIPs +} + +func (h *tcpSNIHandler) Handle(c net.Conn) { + addrPortStr := c.LocalAddr().String() + _, port, err := net.SplitHostPort(addrPortStr) + if err != nil { + log.Printf("tcpSNIHandler.Handle: bogus addrPort %q", addrPortStr) + c.Close() + return + } + + var p tcpproxy.Proxy + p.ListenFunc = func(net, laddr string) (net.Listener, error) { + return netutil.NewOneConnListener(c, nil), nil + } + p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) { + if len(h.Allowlist) > 0 { + // TODO(tom): handle subdomains + if slices.Index(h.Allowlist, sniName) < 0 { + return nil, false + } + } + + return &tcpproxy.DialProxy{ + Addr: net.JoinHostPort(sniName, port), + DialContext: h.DialContext, + }, true + }) + p.Start() +} diff --git a/appc/handlers_test.go b/appc/handlers_test.go new file mode 100644 index 000000000..d8229d004 --- /dev/null +++ b/appc/handlers_test.go @@ -0,0 +1,159 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appc + +import ( + "bytes" + "context" + "encoding/hex" + "io" + "net" + "net/netip" + "strings" + "testing" + + "tailscale.com/net/memnet" +) + +func echoConnOnce(conn net.Conn) { + defer conn.Close() + + b := make([]byte, 256) + n, err := conn.Read(b) + if err != nil { + return + } + + if _, err := conn.Write(b[:n]); err != nil { + return + } +} + +func TestTCPRoundRobinHandler(t *testing.T) { + h := tcpRoundRobinHandler{ + To: []string{"yeet.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "yeet.com:22" { + t.Errorf("addr = %s, want %s", addr, "yeet.com:22") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) + h.Handle(sSock) + + // Test data write and read, the other end will echo back + // a single stanza + want := "hello" + if _, err := io.WriteString(cSock, want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } + + // The other end closed the socket after the first echo, so + // any following read should error. + io.WriteString(cSock, "deadass heres some data on god fr") + if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { + t.Error("read succeeded on closed socket") + } +} + +// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com +const tlsStart = `45000239ff1840004006f9f5c0a801f2 +c726b5efcf9e01bbe803b21394e3b752 +801801f641dc00000101080ade3474f2 +2fb93ee71603010200010001fc030303 +c3acbd19d2624765bb19af4bce03365e +1d197f5bb939cdadeff26b0f8e7a0620 +295b04127b82bae46aac4ff58cffef25 +eba75a4b7a6de729532c411bd9dd0d2c +00203a3a130113021303c02bc02fc02c +c030cca9cca8c013c014009c009d002f +003501000193caca0000000a000a0008 +1a1a001d001700180010000e000c0268 +3208687474702f312e31002b0007062a +2a03040303ff01000100000d00120010 +04030804040105030805050108060601 +000b00020100002300000033002b0029 +1a1a000100001d0020d3c76bef062979 +a812ce935cfb4dbe6b3a84dc5ba9226f +23b0f34af9d1d03b4a001b0003020002 +00120000446900050003026832000000 +170015000012706b67732e7461696c73 +63616c652e636f6d002d000201010005 +00050100000000001700003a3a000100 +0015002d000000000000000000000000 +00000000000000000000000000000000 +00000000000000000000000000000000 +0000290094006f0069e76f2016f963ad +38c8632d1f240cd75e00e25fdef295d4 +7042b26f3a9a543b1c7dc74939d77803 +20527d423ff996997bda2c6383a14f49 +219eeef8a053e90a32228df37ddbe126 +eccf6b085c93890d08341d819aea6111 +0d909f4cd6b071d9ea40618e74588a33 +90d494bbb5c3002120d5a164a16c9724 +c9ef5e540d8d6f007789a7acf9f5f16f +bf6a1907a6782ed02b` + +func fakeSNIHeader() []byte { + b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) + if err != nil { + panic(err) + } + return b[0x34:] // trim IP + TCP header +} + +func TestTCPSNIHandler(t *testing.T) { + h := tcpSNIHandler{ + Allowlist: []string{"pkgs.tailscale.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "pkgs.tailscale.com:443" { + t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) + h.Handle(sSock) + + // Fake a TLS handshake record with an SNI in it. + if _, err := cSock.Write(fakeSNIHeader()); err != nil { + t.Fatal(err) + } + + // Test read, the other end will echo back + // a single stanza, which is at least the beginning of the SNI header. + want := fakeSNIHeader()[:5] + if _, err := cSock.Write(want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %q, want %q", got, want) + } +}