diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index ac685b9b4..c226a5852 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -491,7 +491,7 @@ func (s *Server) start() (reterr error) { return fmt.Errorf("netstack.Create: %w", err) } ns.ProcessLocalIPs = true - ns.ForwardTCPIn = s.forwardTCP + ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow s.netstack = ns s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { @@ -660,20 +660,12 @@ func (s *Server) listenerForDstAddr(netBase string, dst netip.AddrPort) (_ *list return nil, false } -func (s *Server) forwardTCP(c net.Conn, port uint16) { - dstStr := c.LocalAddr().String() - ap, err := netip.ParseAddrPort(dstStr) - if err != nil { - s.logf("unexpected dst addr %q", dstStr) - c.Close() - return - } - ln, ok := s.listenerForDstAddr("tcp", ap) +func (s *Server) getTCPHandlerForFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + ln, ok := s.listenerForDstAddr("tcp", dst) if !ok { - c.Close() - return + return nil, true // don't handle, don't forward to localhost } - ln.handle(c) + return ln.handle, true } func (s *Server) getUDPHandlerForFlow(src, dst netip.AddrPort) (handler func(nettype.ConnPacketConn), intercept bool) { diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index bc5adc2e2..ab55b7b60 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -175,6 +175,11 @@ func TestConn(t *testing.T) { if string(got) != want { t.Errorf("got %q, want %q", got, want) } + + _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8082", s1ip)) // some random port + if err == nil { + t.Fatalf("unexpected success; should have seen a connection refused error") + } } func TestLoopbackLocalAPI(t *testing.T) { diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 131c0d392..54a039dd6 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -79,12 +79,18 @@ func init() { // and implements wgengine.FakeImpl to act as a userspace network // stack when Tailscale is running in fake mode. type Impl struct { - // ForwardTCPIn, if non-nil, handles forwarding an inbound TCP connection. + // GetTCPHandlerForFlow conditionally handles an incoming TCP flow for the + // provided (src/port, dst/port) 4-tuple. + // + // A nil value is equivalent to a func returning (nil, false). // - // TODO(bradfitz): convert this to the GetUDPHandlerForFlow pattern below to - // provide mechanism for tsnet to reject a port other than accepting it and - // closing it. - ForwardTCPIn func(c net.Conn, port uint16) + // If func returns intercept=false, the default forwarding behavior (if + // ProcessLocalIPs and/or ProcesssSubnetIPs) takes place. + // + // When intercept=true, the behavior depends on whether the returned handler + // is non-nil: if nil, the connection is rejected. If non-nil, handler takes + // over the TCP conn. + GetTCPHandlerForFlow func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) // GetUDPHandlerForFlow conditionally handles an incoming UDP flow for the // provided (src/port, dst/port) 4-tuple. @@ -795,6 +801,8 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) isTailscaleIP := tsaddr.IsTailscaleIP(dialIP) + dstAddrPort := netip.AddrPortFrom(dialIP, reqDetails.LocalPort) + if viaRange.Contains(dialIP) { isTailscaleIP = false dialIP = tsaddr.UnmapVia(dialIP) @@ -913,13 +921,20 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } } - if ns.ForwardTCPIn != nil { - c := createConn() - if c == nil { + if ns.GetTCPHandlerForFlow != nil { + handler, ok := ns.GetTCPHandlerForFlow(clientRemoteAddrPort, dstAddrPort) + if ok { + if handler == nil { + r.Complete(true) + return + } + c := createConn() // will send a RST if it fails + if c == nil { + return + } + handler(c) return } - ns.ForwardTCPIn(c, reqDetails.LocalPort) - return } if isTailscaleIP { dialIP = netaddr.IPv4(127, 0, 0, 1)