appc: implement app connector Server type

This change refactors & moves the bulk of the app connector logic from
./cmd/sniproxy.

A future change will delete the delta in sniproxy and wire it to this type.

Signed-off-by: Tom DNetto <tom@tailscale.com>
Updates: https://github.com/tailscale/corp/issues/15038
pull/9885/head
Tom DNetto 8 months ago committed by Tom
parent 469b7cabad
commit 02908a2d8d

@ -0,0 +1,328 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package appc implements App Connectors.
package appc
import (
"expvar"
"log"
"net"
"net/netip"
"sync"
"time"
"golang.org/x/net/dns/dnsmessage"
"tailscale.com/appctype"
"tailscale.com/metrics"
"tailscale.com/tailcfg"
"tailscale.com/types/ipproto"
"tailscale.com/types/nettype"
"tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
)
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
// target describes the predicates which route some inbound
// traffic to the app connector to a specific handler.
type target struct {
Dest netip.Prefix
Matching tailcfg.ProtoPortRange
}
// Server implements an App Connector.
type Server struct {
mu sync.RWMutex // mu guards following fields
connectors map[appctype.ConfigID]connector
}
type appcMetrics struct {
dnsResponses expvar.Int
dnsFailures expvar.Int
tcpConns expvar.Int
sniConns expvar.Int
unhandledConns expvar.Int
}
var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics {
m := appcMetrics{}
stats := new(metrics.Set)
stats.Set("tls_sessions", &m.sniConns)
clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value)
stats.Set("tcp_sessions", &m.tcpConns)
clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value)
stats.Set("dns_responses", &m.dnsResponses)
clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value)
stats.Set("dns_failed", &m.dnsFailures)
clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value)
expvar.Publish("sniproxy", stats)
return &m
})
// Configure applies the provided configuration to the app connector.
func (s *Server) Configure(cfg *appctype.AppConnectorConfig) {
s.mu.Lock()
defer s.mu.Unlock()
s.connectors = makeConnectorsFromConfig(cfg)
}
// HandleTCPFlow implements tsnet.FallbackTCPHandler.
func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
m := getMetrics()
s.mu.RLock()
defer s.mu.RUnlock()
for _, c := range s.connectors {
if handler, intercept := c.handleTCPFlow(src, dst, m); intercept {
return handler, intercept
}
}
return nil, false
}
// HandleDNS handles a DNS request to the app connector.
func (s *Server) HandleDNS(c nettype.ConnPacketConn) {
defer c.Close()
c.SetReadDeadline(time.Now().Add(5 * time.Second))
m := getMetrics()
buf := make([]byte, 1500)
n, err := c.Read(buf)
if err != nil {
log.Printf("HandleDNS: read failed: %v\n ", err)
m.dnsFailures.Add(1)
return
}
addrPortStr := c.LocalAddr().String()
host, _, err := net.SplitHostPort(addrPortStr)
if err != nil {
log.Printf("HandleDNS: bogus addrPort %q", addrPortStr)
m.dnsFailures.Add(1)
return
}
localAddr, err := netip.ParseAddr(host)
if err != nil {
log.Printf("HandleDNS: bogus local address %q", host)
m.dnsFailures.Add(1)
return
}
var msg dnsmessage.Message
err = msg.Unpack(buf[:n])
if err != nil {
log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err)
m.dnsFailures.Add(1)
return
}
s.mu.RLock()
defer s.mu.RUnlock()
for _, connector := range s.connectors {
resp, err := connector.handleDNS(&msg, localAddr)
if err != nil {
log.Printf("HandleDNS: connector handling failed: %v\n", err)
m.dnsFailures.Add(1)
return
}
if len(resp) > 0 {
// This connector handled the DNS request
_, err = c.Write(resp)
if err != nil {
log.Printf("HandleDNS: write failed: %v\n", err)
m.dnsFailures.Add(1)
return
}
m.dnsResponses.Add(1)
return
}
}
}
// connector describes a logical collection of
// services which need to be proxied.
type connector struct {
Handlers map[target]handler
}
// handleTCPFlow implements tsnet.FallbackTCPHandler.
func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) {
for t, h := range c.Handlers {
if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) {
continue
}
if !t.Dest.Contains(dst.Addr()) {
continue
}
if !t.Matching.Ports.Contains(dst.Port()) {
continue
}
switch h.(type) {
case *tcpSNIHandler:
m.sniConns.Add(1)
case *tcpRoundRobinHandler:
m.tcpConns.Add(1)
default:
log.Printf("handleTCPFlow: unhandled handler type %T", h)
}
return h.Handle, true
}
m.unhandledConns.Add(1)
return nil, false
}
// handleDNS returns the DNS response to the given query. If this
// connector is unable to handle the request, nil is returned.
func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) {
for t, h := range c.Handlers {
if t.Dest.Contains(localAddr) {
return makeDNSResponse(req, h.ReachableOn())
}
}
// Did not match, signal 'not handled' to caller
return nil, nil
}
func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) {
buf := make([]byte, 1500)
resp := dnsmessage.NewBuilder(buf,
dnsmessage.Header{
ID: req.Header.ID,
Response: true,
Authoritative: true,
})
resp.EnableCompression()
if len(req.Questions) == 0 {
buf, _ = resp.Finish()
return buf, nil
}
q := req.Questions[0]
err = resp.StartQuestions()
if err != nil {
return
}
resp.Question(q)
err = resp.StartAnswers()
if err != nil {
return
}
switch q.Type {
case dnsmessage.TypeAAAA:
for _, ip := range reachableIPs {
if ip.Is6() {
err = resp.AAAAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AAAAResource{AAAA: ip.As16()},
)
}
}
case dnsmessage.TypeA:
for _, ip := range reachableIPs {
if ip.Is4() {
err = resp.AResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AResource{A: ip.As4()},
)
}
}
case dnsmessage.TypeSOA:
err = resp.SOAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
)
case dnsmessage.TypeNS:
err = resp.NSResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.NSResource{NS: tsMBox},
)
}
if err != nil {
return nil, err
}
return resp.Finish()
}
type handler interface {
// Handle handles the given socket.
Handle(c net.Conn)
// ReachableOn returns the IP addresses this handler is reachable on.
ReachableOn() []netip.Addr
}
func installDNATHandler(d *appctype.DNATConfig, out *connector) {
// These handlers don't actually do DNAT, they just
// proxy the data over the connection.
var dialer net.Dialer
dialer.Timeout = 5 * time.Second
h := tcpRoundRobinHandler{
To: d.To,
DialContext: dialer.DialContext,
ReachableIPs: d.Addrs,
}
for _, addr := range d.Addrs {
for _, protoPort := range d.IP {
t := target{
Dest: netip.PrefixFrom(addr, addr.BitLen()),
Matching: protoPort,
}
mak.Set(&out.Handlers, t, handler(&h))
}
}
}
func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) {
var dialer net.Dialer
dialer.Timeout = 5 * time.Second
h := tcpSNIHandler{
Allowlist: c.AllowedDomains,
DialContext: dialer.DialContext,
ReachableIPs: c.Addrs,
}
for _, addr := range c.Addrs {
for _, protoPort := range c.IP {
t := target{
Dest: netip.PrefixFrom(addr, addr.BitLen()),
Matching: protoPort,
}
mak.Set(&out.Handlers, t, handler(&h))
}
}
}
func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector {
var connectors map[appctype.ConfigID]connector
for cID, d := range cfg.DNAT {
c := connectors[cID]
installDNATHandler(&d, &c)
mak.Set(&connectors, cID, c)
}
for cID, d := range cfg.SNIProxy {
c := connectors[cID]
installSNIHandler(&d, &c)
mak.Set(&connectors, cID, c)
}
return connectors
}

@ -0,0 +1,95 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package appc
import (
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/appctype"
"tailscale.com/tailcfg"
)
func TestMakeConnectorsFromConfig(t *testing.T) {
tcs := []struct {
name string
input *appctype.AppConnectorConfig
want map[appctype.ConfigID]connector
}{
{
"empty",
&appctype.AppConnectorConfig{},
nil,
},
{
"DNAT",
&appctype.AppConnectorConfig{
DNAT: map[appctype.ConfigID]appctype.DNATConfig{
"swiggity_swooty": {
Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
To: []string{"example.org"},
IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
},
},
},
map[appctype.ConfigID]connector{
"swiggity_swooty": {
Handlers: map[target]handler{
{
Dest: netip.MustParsePrefix("100.64.0.1/32"),
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
}: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
{
Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
}: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
},
},
},
},
{
"SNIProxy",
&appctype.AppConnectorConfig{
SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{
"swiggity_swooty": {
Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
AllowedDomains: []string{"example.org"},
IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
},
},
},
map[appctype.ConfigID]connector{
"swiggity_swooty": {
Handlers: map[target]handler{
{
Dest: netip.MustParsePrefix("100.64.0.1/32"),
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
}: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
{
Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
}: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
connectors := makeConnectorsFromConfig(tc.input)
if diff := cmp.Diff(connectors, tc.want,
cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"),
cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"),
cmp.Comparer(func(x, y netip.Addr) bool {
return x == y
})); diff != "" {
t.Fatalf("mismatch (-want +got):\n%s", diff)
}
})
}
}

@ -0,0 +1,104 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package appc
import (
"context"
"fmt"
"log"
"math/rand"
"net"
"net/netip"
"slices"
"inet.af/tcpproxy"
"tailscale.com/net/netutil"
)
type tcpRoundRobinHandler struct {
// To is a list of destination addresses to forward to.
// An entry may be either an IP address or a DNS name.
To []string
// DialContext is used to make the outgoing TCP connection.
DialContext func(ctx context.Context, network, address string) (net.Conn, error)
// ReachableIPs enumerates the IP addresses this handler is reachable on.
ReachableIPs []netip.Addr
}
// ReachableOn returns the IP addresses this handler is reachable on.
func (h *tcpRoundRobinHandler) ReachableOn() []netip.Addr {
return h.ReachableIPs
}
func (h *tcpRoundRobinHandler) Handle(c net.Conn) {
addrPortStr := c.LocalAddr().String()
_, port, err := net.SplitHostPort(addrPortStr)
if err != nil {
log.Printf("tcpRoundRobinHandler.Handle: bogus addrPort %q", addrPortStr)
c.Close()
return
}
var p tcpproxy.Proxy
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
return netutil.NewOneConnListener(c, nil), nil
}
dest := h.To[rand.Intn(len(h.To))]
dial := &tcpproxy.DialProxy{
Addr: fmt.Sprintf("%s:%s", dest, port),
DialContext: h.DialContext,
}
p.AddRoute(addrPortStr, dial)
p.Start()
}
type tcpSNIHandler struct {
// Allowlist enumerates the FQDNs which may be proxied via SNI. An
// empty slice means all domains are permitted.
Allowlist []string
// DialContext is used to make the outgoing TCP connection.
DialContext func(ctx context.Context, network, address string) (net.Conn, error)
// ReachableIPs enumerates the IP addresses this handler is reachable on.
ReachableIPs []netip.Addr
}
// ReachableOn returns the IP addresses this handler is reachable on.
func (h *tcpSNIHandler) ReachableOn() []netip.Addr {
return h.ReachableIPs
}
func (h *tcpSNIHandler) Handle(c net.Conn) {
addrPortStr := c.LocalAddr().String()
_, port, err := net.SplitHostPort(addrPortStr)
if err != nil {
log.Printf("tcpSNIHandler.Handle: bogus addrPort %q", addrPortStr)
c.Close()
return
}
var p tcpproxy.Proxy
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
return netutil.NewOneConnListener(c, nil), nil
}
p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
if len(h.Allowlist) > 0 {
// TODO(tom): handle subdomains
if slices.Index(h.Allowlist, sniName) < 0 {
return nil, false
}
}
return &tcpproxy.DialProxy{
Addr: net.JoinHostPort(sniName, port),
DialContext: h.DialContext,
}, true
})
p.Start()
}

@ -0,0 +1,159 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package appc
import (
"bytes"
"context"
"encoding/hex"
"io"
"net"
"net/netip"
"strings"
"testing"
"tailscale.com/net/memnet"
)
func echoConnOnce(conn net.Conn) {
defer conn.Close()
b := make([]byte, 256)
n, err := conn.Read(b)
if err != nil {
return
}
if _, err := conn.Write(b[:n]); err != nil {
return
}
}
func TestTCPRoundRobinHandler(t *testing.T) {
h := tcpRoundRobinHandler{
To: []string{"yeet.com"},
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if network != "tcp" {
t.Errorf("network = %s, want %s", network, "tcp")
}
if addr != "yeet.com:22" {
t.Errorf("addr = %s, want %s", addr, "yeet.com:22")
}
c, s := memnet.NewConn("outbound", 1024)
go echoConnOnce(s)
return c, nil
},
}
cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024)
h.Handle(sSock)
// Test data write and read, the other end will echo back
// a single stanza
want := "hello"
if _, err := io.WriteString(cSock, want); err != nil {
t.Fatal(err)
}
got := make([]byte, len(want))
if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
// The other end closed the socket after the first echo, so
// any following read should error.
io.WriteString(cSock, "deadass heres some data on god fr")
if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil {
t.Error("read succeeded on closed socket")
}
}
// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com
const tlsStart = `45000239ff1840004006f9f5c0a801f2
c726b5efcf9e01bbe803b21394e3b752
801801f641dc00000101080ade3474f2
2fb93ee71603010200010001fc030303
c3acbd19d2624765bb19af4bce03365e
1d197f5bb939cdadeff26b0f8e7a0620
295b04127b82bae46aac4ff58cffef25
eba75a4b7a6de729532c411bd9dd0d2c
00203a3a130113021303c02bc02fc02c
c030cca9cca8c013c014009c009d002f
003501000193caca0000000a000a0008
1a1a001d001700180010000e000c0268
3208687474702f312e31002b0007062a
2a03040303ff01000100000d00120010
04030804040105030805050108060601
000b00020100002300000033002b0029
1a1a000100001d0020d3c76bef062979
a812ce935cfb4dbe6b3a84dc5ba9226f
23b0f34af9d1d03b4a001b0003020002
00120000446900050003026832000000
170015000012706b67732e7461696c73
63616c652e636f6d002d000201010005
00050100000000001700003a3a000100
0015002d000000000000000000000000
00000000000000000000000000000000
00000000000000000000000000000000
0000290094006f0069e76f2016f963ad
38c8632d1f240cd75e00e25fdef295d4
7042b26f3a9a543b1c7dc74939d77803
20527d423ff996997bda2c6383a14f49
219eeef8a053e90a32228df37ddbe126
eccf6b085c93890d08341d819aea6111
0d909f4cd6b071d9ea40618e74588a33
90d494bbb5c3002120d5a164a16c9724
c9ef5e540d8d6f007789a7acf9f5f16f
bf6a1907a6782ed02b`
func fakeSNIHeader() []byte {
b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1))
if err != nil {
panic(err)
}
return b[0x34:] // trim IP + TCP header
}
func TestTCPSNIHandler(t *testing.T) {
h := tcpSNIHandler{
Allowlist: []string{"pkgs.tailscale.com"},
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if network != "tcp" {
t.Errorf("network = %s, want %s", network, "tcp")
}
if addr != "pkgs.tailscale.com:443" {
t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443")
}
c, s := memnet.NewConn("outbound", 1024)
go echoConnOnce(s)
return c, nil
},
}
cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024)
h.Handle(sSock)
// Fake a TLS handshake record with an SNI in it.
if _, err := cSock.Write(fakeSNIHeader()); err != nil {
t.Fatal(err)
}
// Test read, the other end will echo back
// a single stanza, which is at least the beginning of the SNI header.
want := fakeSNIHeader()[:5]
if _, err := cSock.Write(want); err != nil {
t.Fatal(err)
}
got := make([]byte, len(want))
if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
Loading…
Cancel
Save