diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 418840ea9..729a96bce 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -743,47 +743,64 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { ns.removeSubnetAddress(dialIP) } }() + var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) - if err != nil { - ns.logf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err) - r.Complete(true) // sends a RST - return - } - r.Complete(false) - - // SetKeepAlive so that idle connections to peers that have forgotten about - // the connection or gone completely offline eventually time out. - // Applications might be setting this on a forwarded connection, but from - // userspace we can not see those, so the best we can do is to always - // perform them with conservative timing. - // TODO(tailscale/tailscale#4522): Netstack defaults match the Linux - // defaults, and results in a little over two hours before the socket would - // be closed due to keepalive. A shorter default might be better, or seeking - // a default from the host IP stack. This also might be a useful - // user-tunable, as in userspace mode this can have broad implications such - // as lingering connections to fork style daemons. On the other side of the - // fence, the long duration timers are low impact values for battery powered - // peers. - ep.SocketOptions().SetKeepAlive(true) - - // The ForwarderRequest.CreateEndpoint above asynchronously - // starts the TCP handshake. Note that the gonet.TCPConn - // methods c.RemoteAddr() and c.LocalAddr() will return nil - // until the handshake actually completes. But we have the - // remote address in reqDetails instead, so we don't use - // gonet.TCPConn.RemoteAddr. The byte copies in both - // directions to/from the gonet.TCPConn in forwardTCP will - // block until the TCP handshake is complete. - c := gonet.NewTCPConn(&wq, ep) + // We can't actually create the endpoint or complete the inbound + // request until we're sure that the connection can be handled by this + // endpoint. This function sets up the TCP connection and should be + // called immediately before a connection is handled. + createConn := func() *gonet.TCPConn { + ep, err := r.CreateEndpoint(&wq) + if err != nil { + ns.logf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err) + r.Complete(true) // sends a RST + return nil + } + r.Complete(false) + + // SetKeepAlive so that idle connections to peers that have forgotten about + // the connection or gone completely offline eventually time out. + // Applications might be setting this on a forwarded connection, but from + // userspace we can not see those, so the best we can do is to always + // perform them with conservative timing. + // TODO(tailscale/tailscale#4522): Netstack defaults match the Linux + // defaults, and results in a little over two hours before the socket would + // be closed due to keepalive. A shorter default might be better, or seeking + // a default from the host IP stack. This also might be a useful + // user-tunable, as in userspace mode this can have broad implications such + // as lingering connections to fork style daemons. On the other side of the + // fence, the long duration timers are low impact values for battery powered + // peers. + ep.SocketOptions().SetKeepAlive(true) + + // The ForwarderRequest.CreateEndpoint above asynchronously + // starts the TCP handshake. Note that the gonet.TCPConn + // methods c.RemoteAddr() and c.LocalAddr() will return nil + // until the handshake actually completes. But we have the + // remote address in reqDetails instead, so we don't use + // gonet.TCPConn.RemoteAddr. The byte copies in both + // directions to/from the gonet.TCPConn in forwardTCP will + // block until the TCP handshake is complete. + return gonet.NewTCPConn(&wq, ep) + } + + // DNS if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) { + c := createConn() + if c == nil { + return + } go ns.dns.HandleTCPConn(c, netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort)) return } if ns.lb != nil { if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) { + c := createConn() + if c == nil { + return + } if err := ns.lb.HandleSSHConn(c); err != nil { ns.logf("ssh error: %v", err) } @@ -791,6 +808,11 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } if port, ok := ns.lb.GetPeerAPIPort(dialIP); ok { if reqDetails.LocalPort == port && ns.isLocalIP(dialIP) { + c := createConn() + if c == nil { + return + } + src := netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort) dst := netip.AddrPortFrom(dialIP, port) ns.lb.ServePeerAPIConnection(src, dst, c) @@ -798,12 +820,20 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } } if reqDetails.LocalPort == 80 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) { + c := createConn() + if c == nil { + return + } ns.lb.HandleQuad100Port80Conn(c) return } } if ns.ForwardTCPIn != nil { + c := createConn() + if c == nil { + return + } ns.ForwardTCPIn(c, reqDetails.LocalPort) return } @@ -811,11 +841,13 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { dialIP = netaddr.IPv4(127, 0, 0, 1) } dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort)) - ns.forwardTCP(c, clientRemoteIP, &wq, dialAddr) + + if !ns.forwardTCP(createConn, clientRemoteIP, &wq, dialAddr) { + r.Complete(true) // sends a RST + } } -func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) { - defer client.Close() +func (ns *Impl) forwardTCP(getClient func() *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) (handled bool) { dialAddrStr := dialAddr.String() if debugNetstack { ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) @@ -823,6 +855,7 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq ctx, cancel := context.WithCancel(context.Background()) defer cancel() + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) // TODO(bradfitz): right EventMask? wq.EventRegister(&waitEntry) defer wq.EventUnregister(&waitEntry) @@ -840,13 +873,29 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq } cancel() }() + + // Attempt to dial the outbound connection before we accept the inbound one. var stdDialer net.Dialer server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) if err != nil { - ns.logf("netstack: could not connect to local server at %s: %v", dialAddrStr, err) + ns.logf("netstack: could not connect to local server at %s: %v", dialAddr.String(), err) return } defer server.Close() + + // If we get here, either the getClient call below will succeed and + // return something we can Close, or it will fail and will properly + // respond to the client with a RST. Either way, the caller no longer + // needs to clean up the client connection. + handled = true + + // We dialed the connection; we can complete the client's TCP handshake. + client := getClient() + if client == nil { + return + } + defer client.Close() + backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort()) ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) @@ -865,6 +914,7 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq ns.logf("proxy connection closed with error: %v", err) } ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) + return } func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {