From b0cb39cda1bdd0b72b4ec926c902806df7e44ce7 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Tue, 7 Mar 2023 14:52:06 -0800 Subject: [PATCH] tsnet: only intercept TCP flows that have listeners Previously, it would accept all TCP connections and then close the ones it did not care about. Make it only ever accept the connections that it cares about. Signed-off-by: Maisem Ali --- tsnet/tsnet.go | 18 +++++------------- tsnet/tsnet_test.go | 5 +++++ wgengine/netstack/netstack.go | 35 +++++++++++++++++++++++++---------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index ac685b9b4..c226a5852 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -491,7 +491,7 @@ func (s *Server) start() (reterr error) { return fmt.Errorf("netstack.Create: %w", err) } ns.ProcessLocalIPs = true - ns.ForwardTCPIn = s.forwardTCP + ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow s.netstack = ns s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { @@ -660,20 +660,12 @@ func (s *Server) listenerForDstAddr(netBase string, dst netip.AddrPort) (_ *list return nil, false } -func (s *Server) forwardTCP(c net.Conn, port uint16) { - dstStr := c.LocalAddr().String() - ap, err := netip.ParseAddrPort(dstStr) - if err != nil { - s.logf("unexpected dst addr %q", dstStr) - c.Close() - return - } - ln, ok := s.listenerForDstAddr("tcp", ap) +func (s *Server) getTCPHandlerForFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + ln, ok := s.listenerForDstAddr("tcp", dst) if !ok { - c.Close() - return + return nil, true // don't handle, don't forward to localhost } - ln.handle(c) + return ln.handle, true } func (s *Server) getUDPHandlerForFlow(src, dst netip.AddrPort) (handler func(nettype.ConnPacketConn), intercept bool) { diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index bc5adc2e2..ab55b7b60 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -175,6 +175,11 @@ func TestConn(t *testing.T) { if string(got) != want { t.Errorf("got %q, want %q", got, want) } + + _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8082", s1ip)) // some random port + if err == nil { + t.Fatalf("unexpected success; should have seen a connection refused error") + } } func TestLoopbackLocalAPI(t *testing.T) { diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 131c0d392..54a039dd6 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -79,12 +79,18 @@ func init() { // and implements wgengine.FakeImpl to act as a userspace network // stack when Tailscale is running in fake mode. type Impl struct { - // ForwardTCPIn, if non-nil, handles forwarding an inbound TCP connection. + // GetTCPHandlerForFlow conditionally handles an incoming TCP flow for the + // provided (src/port, dst/port) 4-tuple. + // + // A nil value is equivalent to a func returning (nil, false). // - // TODO(bradfitz): convert this to the GetUDPHandlerForFlow pattern below to - // provide mechanism for tsnet to reject a port other than accepting it and - // closing it. - ForwardTCPIn func(c net.Conn, port uint16) + // If func returns intercept=false, the default forwarding behavior (if + // ProcessLocalIPs and/or ProcesssSubnetIPs) takes place. + // + // When intercept=true, the behavior depends on whether the returned handler + // is non-nil: if nil, the connection is rejected. If non-nil, handler takes + // over the TCP conn. + GetTCPHandlerForFlow func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) // GetUDPHandlerForFlow conditionally handles an incoming UDP flow for the // provided (src/port, dst/port) 4-tuple. @@ -795,6 +801,8 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) isTailscaleIP := tsaddr.IsTailscaleIP(dialIP) + dstAddrPort := netip.AddrPortFrom(dialIP, reqDetails.LocalPort) + if viaRange.Contains(dialIP) { isTailscaleIP = false dialIP = tsaddr.UnmapVia(dialIP) @@ -913,13 +921,20 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } } - if ns.ForwardTCPIn != nil { - c := createConn() - if c == nil { + if ns.GetTCPHandlerForFlow != nil { + handler, ok := ns.GetTCPHandlerForFlow(clientRemoteAddrPort, dstAddrPort) + if ok { + if handler == nil { + r.Complete(true) + return + } + c := createConn() // will send a RST if it fails + if c == nil { + return + } + handler(c) return } - ns.ForwardTCPIn(c, reqDetails.LocalPort) - return } if isTailscaleIP { dialIP = netaddr.IPv4(127, 0, 0, 1)