tsnet: support registering fallback TCP flow handlers

For the app connector use-case, it doesnt make sense to use listeners, because then you would
need to register thousands of listeners (for each proto/service/port combo) to handle ranges.

Instead, we plumb through the TCPHandlerForFlow abstraction, to avoid using the listeners
abstraction that would end up being a bit messy.

Signed-off-by: Tom DNetto <tom@tailscale.com>
Updates: https://github.com/tailscale/corp/issues/15038
pull/9751/head
Tom DNetto 1 year ago committed by Tom
parent 9f05018419
commit fffafc65d6

@ -53,6 +53,7 @@ import (
"tailscale.com/types/nettype"
"tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/testenv"
"tailscale.com/wgengine"
"tailscale.com/wgengine/netstack"
@ -135,10 +136,24 @@ type Server struct {
mu sync.Mutex
listeners map[listenKey]*listener
fallbackTCPHandlers set.HandleSet[FallbackTCPHandler]
dialer *tsdial.Dialer
closed bool
}
// FallbackTCPHandler describes the callback which
// conditionally handles an incoming TCP flow for the
// provided (src/port, dst/port) 4-tuple. These are registered
// as handlers of last resort, and are called only if no
// listener could handle the incoming flow.
//
// If the callback returns intercept=false, the flow is rejected.
//
// When intercept=true, the behavior depends on whether the returned handler
// is non-nil: if nil, the connection is rejected. If non-nil, handler takes
// over the TCP conn.
type FallbackTCPHandler func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool)
// Dial connects to the address on the tailnet.
// It will start the server if it has not been started yet.
func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@ -755,6 +770,14 @@ func (s *Server) getTCPHandlerForFunnelFlow(src netip.AddrPort, dstPort uint16)
func (s *Server) getTCPHandlerForFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
ln, ok := s.listenerForDstAddr("tcp", dst, false)
if !ok {
s.mu.Lock()
defer s.mu.Unlock()
for _, handler := range s.fallbackTCPHandlers {
connHandler, intercept := handler(src, dst)
if intercept {
return connHandler, intercept
}
}
return nil, true // don't handle, don't forward to localhost
}
return ln.handle, true
@ -858,6 +881,24 @@ func (s *Server) ListenTLS(network, addr string) (net.Listener, error) {
}), nil
}
// RegisterFallbackTCPHandler registers a callback which will be called
// to handle a TCP flow to this tsnet node, for which no listeners will handle.
//
// If multiple fallback handlers are registered, they will be called in an
// undefined order. See FallbackTCPHandler for details on handling a flow.
//
// The returned function can be used to deregister this callback.
func (s *Server) RegisterFallbackTCPHandler(cb FallbackTCPHandler) func() {
s.mu.Lock()
defer s.mu.Unlock()
hnd := s.fallbackTCPHandlers.Add(cb)
return func() {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.fallbackTCPHandlers, hnd)
}
}
// getCert is the GetCertificate function used by ListenTLS.
//
// It calls GetCertificate on the localClient, passing in the ClientHelloInfo.

@ -630,3 +630,45 @@ type bufferedConn struct {
func (c *bufferedConn) Read(b []byte) (int, error) {
return c.reader.Read(b)
}
func TestFallbackTCPHandler(t *testing.T) {
tstest.ResourceCheck(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
controlURL := startControl(t)
s1, s1ip := startServer(t, ctx, controlURL, "s1")
s2, _ := startServer(t, ctx, controlURL, "s2")
lc2, err := s2.LocalClient()
if err != nil {
t.Fatal(err)
}
// ping to make sure the connection is up.
res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
if err != nil {
t.Fatal(err)
}
t.Logf("ping success: %#+v", res)
s1TcpConnCount := 0
deregister := s1.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
s1TcpConnCount++
return nil, false
})
if _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil {
t.Fatal("Expected dial error because fallback handler did not intercept")
}
if s1TcpConnCount != 1 {
t.Errorf("s1TcpConnCount = %d, want %d", s1TcpConnCount, 1)
}
deregister()
if _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil {
t.Fatal("Expected dial error because nothing would intercept")
}
if s1TcpConnCount != 1 {
t.Errorf("s1TcpConnCount = %d, want %d", s1TcpConnCount, 1)
}
}

Loading…
Cancel
Save