wgengine/netstack: only accept connection after dialing (#5503)

If we accept a forwarded TCP connection before dialing, we can
erroneously signal to a client that we support IPv6 (or IPv4) without
that actually being possible. Instead, we only complete the client's TCP
handshake after we've dialed the outbound connection; if that fails, we
respond with a RST.

Updates #5425 (maybe fixes!)

Signed-off-by: Andrew Dunham <andrew@tailscale.com>
pull/5566/head
Andrew Dunham 2 years ago committed by GitHub
parent 2f702b150e
commit 9240f5c1e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -743,12 +743,19 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
ns.removeSubnetAddress(dialIP) ns.removeSubnetAddress(dialIP)
} }
}() }()
var wq waiter.Queue var wq waiter.Queue
// We can't actually create the endpoint or complete the inbound
// request until we're sure that the connection can be handled by this
// endpoint. This function sets up the TCP connection and should be
// called immediately before a connection is handled.
createConn := func() *gonet.TCPConn {
ep, err := r.CreateEndpoint(&wq) ep, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
ns.logf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err) ns.logf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err)
r.Complete(true) // sends a RST r.Complete(true) // sends a RST
return return nil
} }
r.Complete(false) r.Complete(false)
@ -775,15 +782,25 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// gonet.TCPConn.RemoteAddr. The byte copies in both // gonet.TCPConn.RemoteAddr. The byte copies in both
// directions to/from the gonet.TCPConn in forwardTCP will // directions to/from the gonet.TCPConn in forwardTCP will
// block until the TCP handshake is complete. // block until the TCP handshake is complete.
c := gonet.NewTCPConn(&wq, ep) return gonet.NewTCPConn(&wq, ep)
}
// DNS
if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) { if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
c := createConn()
if c == nil {
return
}
go ns.dns.HandleTCPConn(c, netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort)) go ns.dns.HandleTCPConn(c, netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort))
return return
} }
if ns.lb != nil { if ns.lb != nil {
if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) { if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) {
c := createConn()
if c == nil {
return
}
if err := ns.lb.HandleSSHConn(c); err != nil { if err := ns.lb.HandleSSHConn(c); err != nil {
ns.logf("ssh error: %v", err) ns.logf("ssh error: %v", err)
} }
@ -791,6 +808,11 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
} }
if port, ok := ns.lb.GetPeerAPIPort(dialIP); ok { if port, ok := ns.lb.GetPeerAPIPort(dialIP); ok {
if reqDetails.LocalPort == port && ns.isLocalIP(dialIP) { if reqDetails.LocalPort == port && ns.isLocalIP(dialIP) {
c := createConn()
if c == nil {
return
}
src := netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort) src := netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort)
dst := netip.AddrPortFrom(dialIP, port) dst := netip.AddrPortFrom(dialIP, port)
ns.lb.ServePeerAPIConnection(src, dst, c) ns.lb.ServePeerAPIConnection(src, dst, c)
@ -798,12 +820,20 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
} }
} }
if reqDetails.LocalPort == 80 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) { if reqDetails.LocalPort == 80 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
c := createConn()
if c == nil {
return
}
ns.lb.HandleQuad100Port80Conn(c) ns.lb.HandleQuad100Port80Conn(c)
return return
} }
} }
if ns.ForwardTCPIn != nil { if ns.ForwardTCPIn != nil {
c := createConn()
if c == nil {
return
}
ns.ForwardTCPIn(c, reqDetails.LocalPort) ns.ForwardTCPIn(c, reqDetails.LocalPort)
return return
} }
@ -811,11 +841,13 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
dialIP = netaddr.IPv4(127, 0, 0, 1) dialIP = netaddr.IPv4(127, 0, 0, 1)
} }
dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort)) dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort))
ns.forwardTCP(c, clientRemoteIP, &wq, dialAddr)
if !ns.forwardTCP(createConn, clientRemoteIP, &wq, dialAddr) {
r.Complete(true) // sends a RST
}
} }
func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) { func (ns *Impl) forwardTCP(getClient func() *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) (handled bool) {
defer client.Close()
dialAddrStr := dialAddr.String() dialAddrStr := dialAddr.String()
if debugNetstack { if debugNetstack {
ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr)
@ -823,6 +855,7 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) // TODO(bradfitz): right EventMask? waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) // TODO(bradfitz): right EventMask?
wq.EventRegister(&waitEntry) wq.EventRegister(&waitEntry)
defer wq.EventUnregister(&waitEntry) defer wq.EventUnregister(&waitEntry)
@ -840,13 +873,29 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq
} }
cancel() cancel()
}() }()
// Attempt to dial the outbound connection before we accept the inbound one.
var stdDialer net.Dialer var stdDialer net.Dialer
server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr)
if err != nil { if err != nil {
ns.logf("netstack: could not connect to local server at %s: %v", dialAddrStr, err) ns.logf("netstack: could not connect to local server at %s: %v", dialAddr.String(), err)
return return
} }
defer server.Close() defer server.Close()
// If we get here, either the getClient call below will succeed and
// return something we can Close, or it will fail and will properly
// respond to the client with a RST. Either way, the caller no longer
// needs to clean up the client connection.
handled = true
// We dialed the connection; we can complete the client's TCP handshake.
client := getClient()
if client == nil {
return
}
defer client.Close()
backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort()) backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort())
ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
@ -865,6 +914,7 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netip.Addr, wq
ns.logf("proxy connection closed with error: %v", err) ns.logf("proxy connection closed with error: %v", err)
} }
ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr)
return
} }
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {

Loading…
Cancel
Save