ipn/ipnlocal,wgengine/netstack: move LocalBackend specifc serving logic to LocalBackend

The netstack code had a bunch of logic to figure out if the LocalBackend should handle an
incoming connection and then would call the function directly on LocalBackend. Move that
logic to LocalBackend and refactor the methods to return conn handlers.

Updates #cleanup

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/8332/head
Maisem Ali 1 year ago committed by Maisem Ali
parent 5b110685fb
commit fe95d81b43

@ -32,6 +32,7 @@ import (
"go4.org/mem" "go4.org/mem"
"go4.org/netipx" "go4.org/netipx"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"gvisor.dev/gvisor/pkg/tcpip"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/control/controlclient" "tailscale.com/control/controlclient"
"tailscale.com/doctor" "tailscale.com/doctor"
@ -2828,14 +2829,14 @@ func (b *LocalBackend) GetPeerAPIPort(ip netip.Addr) (port uint16, ok bool) {
return 0, false return 0, false
} }
// ServePeerAPIConnection serves an already-accepted connection c. // handlePeerAPIConn serves an already-accepted connection c.
// //
// The remote parameter is the remote address. // The remote parameter is the remote address.
// The local parameter is the local address (either a Tailscale IPv4 // The local parameter is the local address (either a Tailscale IPv4
// or IPv6 IP and the peerapi port for that address). // or IPv6 IP and the peerapi port for that address).
// //
// The connection will be closed by ServePeerAPIConnection. // The connection will be closed by handlePeerAPIConn.
func (b *LocalBackend) ServePeerAPIConnection(remote, local netip.AddrPort, c net.Conn) { func (b *LocalBackend) handlePeerAPIConn(remote, local netip.AddrPort, c net.Conn) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
for _, pln := range b.peerAPIListeners { for _, pln := range b.peerAPIListeners {
@ -2849,6 +2850,48 @@ func (b *LocalBackend) ServePeerAPIConnection(remote, local netip.AddrPort, c ne
return return
} }
func (b *LocalBackend) isLocalIP(ip netip.Addr) bool {
nm := b.NetMap()
return nm != nil && slices.Contains(nm.Addresses, netip.PrefixFrom(ip, ip.BitLen()))
}
var (
magicDNSIP = tsaddr.TailscaleServiceIP()
magicDNSIPv6 = tsaddr.TailscaleServiceIPv6()
)
// TCPHandlerForDst returns a TCP handler for connections to dst, or nil if
// no handler is needed. It also returns a list of TCP socket options to
// apply to the socket before calling the handler.
func (b *LocalBackend) TCPHandlerForDst(src, dst netip.AddrPort) (handler func(c net.Conn) error, opts []tcpip.SettableSocketOption) {
if dst.Port() == 80 && (dst.Addr() == magicDNSIP || dst.Addr() == magicDNSIPv6) {
return b.HandleQuad100Port80Conn, opts
}
if !b.isLocalIP(dst.Addr()) {
return nil, nil
}
if dst.Port() == 22 && b.ShouldRunSSH() {
// Use a higher keepalive idle time for SSH connections, as they are
// typically long lived and idle connections are more likely to be
// intentional. Ideally we would turn this off entirely, but we can't
// tell the difference between a long lived connection that is idle
// vs a connection that is dead because the peer has gone away.
// We pick 72h as that is typically sufficient for a long weekend.
opts = append(opts, ptr.To(tcpip.KeepaliveIdleOption(72*time.Hour)))
return b.handleSSHConn, opts
}
if port, ok := b.GetPeerAPIPort(dst.Addr()); ok && dst.Port() == port {
return func(c net.Conn) error {
b.handlePeerAPIConn(src, dst, c)
return nil
}, opts
}
if handler := b.tcpHandlerForServe(dst.Port(), src); handler != nil {
return handler, opts
}
return nil, nil
}
func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) { func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) {
for _, pln := range b.peerAPIListeners { for _, pln := range b.peerAPIListeners {
proto := tailcfg.PeerAPI4 proto := tailcfg.PeerAPI4
@ -4674,7 +4717,7 @@ func checkSELinux() {
} }
} }
func (b *LocalBackend) HandleSSHConn(c net.Conn) (err error) { func (b *LocalBackend) handleSSHConn(c net.Conn) (err error) {
s, err := b.sshServerOrInit() s, err := b.sshServerOrInit()
if err != nil { if err != nil {
return err return err
@ -4685,10 +4728,10 @@ func (b *LocalBackend) HandleSSHConn(c net.Conn) (err error) {
// HandleQuad100Port80Conn serves http://100.100.100.100/ on port 80 (and // HandleQuad100Port80Conn serves http://100.100.100.100/ on port 80 (and
// the equivalent tsaddr.TailscaleServiceIPv6 address). // the equivalent tsaddr.TailscaleServiceIPv6 address).
func (b *LocalBackend) HandleQuad100Port80Conn(c net.Conn) { func (b *LocalBackend) HandleQuad100Port80Conn(c net.Conn) error {
var s http.Server var s http.Server
s.Handler = http.HandlerFunc(b.handleQuad100Port80Conn) s.Handler = http.HandlerFunc(b.handleQuad100Port80Conn)
s.Serve(netutil.NewOneConnListener(c, nil)) return s.Serve(netutil.NewOneConnListener(c, nil))
} }
func validQuad100Host(h string) bool { func validQuad100Host(h string) bool {

@ -780,7 +780,7 @@ func (h *peerAPIHandler) handleServeIngress(w http.ResponseWriter, r *http.Reque
return return
} }
getConn := func() (net.Conn, bool) { getConnOrReset := func() (net.Conn, bool) {
conn, _, err := w.(http.Hijacker).Hijack() conn, _, err := w.(http.Hijacker).Hijack()
if err != nil { if err != nil {
h.logf("ingress: failed hijacking conn") h.logf("ingress: failed hijacking conn")
@ -798,7 +798,7 @@ func (h *peerAPIHandler) handleServeIngress(w http.ResponseWriter, r *http.Reque
http.Error(w, "denied", http.StatusForbidden) http.Error(w, "denied", http.StatusForbidden)
} }
h.ps.b.HandleIngressTCPConn(h.peerNode, target, srcAddr, getConn, sendRST) h.ps.b.HandleIngressTCPConn(h.peerNode, target, srcAddr, getConnOrReset, sendRST)
} }
func (h *peerAPIHandler) handleServeInterfaces(w http.ResponseWriter, r *http.Request) { func (h *peerAPIHandler) handleServeInterfaces(w http.ResponseWriter, r *http.Request) {

@ -162,12 +162,13 @@ func (s *serveListener) handleServeListenersAccept(ln net.Listener) error {
return err return err
} }
srcAddr := conn.RemoteAddr().(*net.TCPAddr).AddrPort() srcAddr := conn.RemoteAddr().(*net.TCPAddr).AddrPort()
getConn := func() (net.Conn, bool) { return conn, true } handler := s.b.tcpHandlerForServe(s.ap.Port(), srcAddr)
sendRST := func() { if handler == nil {
s.b.logf("serve RST for %v", srcAddr) s.b.logf("serve RST for %v", srcAddr)
conn.Close() conn.Close()
continue
} }
go s.b.HandleInterceptedTCPConn(s.ap.Port(), srcAddr, getConn, sendRST) go handler(conn)
} }
} }
@ -256,7 +257,7 @@ func (b *LocalBackend) ServeConfig() ipn.ServeConfigView {
return b.serveConfig return b.serveConfig
} }
func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ipn.HostPort, srcAddr netip.AddrPort, getConn func() (net.Conn, bool), sendRST func()) { func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ipn.HostPort, srcAddr netip.AddrPort, getConnOrReset func() (net.Conn, bool), sendRST func()) {
b.mu.Lock() b.mu.Lock()
sc := b.serveConfig sc := b.serveConfig
b.mu.Unlock() b.mu.Unlock()
@ -289,7 +290,7 @@ func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ip
if b.getTCPHandlerForFunnelFlow != nil { if b.getTCPHandlerForFunnelFlow != nil {
handler := b.getTCPHandlerForFunnelFlow(srcAddr, dport) handler := b.getTCPHandlerForFunnelFlow(srcAddr, dport)
if handler != nil { if handler != nil {
c, ok := getConn() c, ok := getConnOrReset()
if !ok { if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport) b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
return return
@ -298,35 +299,40 @@ func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ip
return return
} }
} }
// TODO(bradfitz): pass ingressPeer etc in context to HandleInterceptedTCPConn, // TODO(bradfitz): pass ingressPeer etc in context to tcpHandlerForServe,
// extend serveHTTPContext or similar. // extend serveHTTPContext or similar.
b.HandleInterceptedTCPConn(dport, srcAddr, getConn, sendRST) handler := b.tcpHandlerForServe(dport, srcAddr)
if handler == nil {
sendRST()
return
}
c, ok := getConnOrReset()
if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
return
}
handler(c)
} }
func (b *LocalBackend) HandleInterceptedTCPConn(dport uint16, srcAddr netip.AddrPort, getConn func() (net.Conn, bool), sendRST func()) { // tcpHandlerForServe returns a handler for a TCP connection to be served via
// the ipn.ServeConfig.
func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) (handler func(net.Conn) error) {
b.mu.Lock() b.mu.Lock()
sc := b.serveConfig sc := b.serveConfig
b.mu.Unlock() b.mu.Unlock()
if !sc.Valid() { if !sc.Valid() {
b.logf("[unexpected] localbackend: got TCP conn w/o serveConfig; from %v to port %v", srcAddr, dport) b.logf("[unexpected] localbackend: got TCP conn w/o serveConfig; from %v to port %v", srcAddr, dport)
sendRST() return nil
return
} }
tcph, ok := sc.TCP().GetOk(dport) tcph, ok := sc.TCP().GetOk(dport)
if !ok { if !ok {
b.logf("[unexpected] localbackend: got TCP conn without TCP config for port %v; from %v", dport, srcAddr) b.logf("[unexpected] localbackend: got TCP conn without TCP config for port %v; from %v", dport, srcAddr)
sendRST() return nil
return
} }
if tcph.HTTPS() { if tcph.HTTPS() {
conn, ok := getConn()
if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
return
}
hs := &http.Server{ hs := &http.Server{
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
GetCertificate: b.getTLSServeCertForPort(dport), GetCertificate: b.getTLSServeCertForPort(dport),
@ -339,28 +345,22 @@ func (b *LocalBackend) HandleInterceptedTCPConn(dport uint16, srcAddr netip.Addr
}) })
}, },
} }
hs.ServeTLS(netutil.NewOneConnListener(conn, nil), "", "") return func(c net.Conn) error {
return return hs.ServeTLS(netutil.NewOneConnListener(c, nil), "", "")
}
} }
if backDst := tcph.TCPForward(); backDst != "" { if backDst := tcph.TCPForward(); backDst != "" {
return func(conn net.Conn) error {
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst) backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst)
cancel() cancel()
if err != nil { if err != nil {
b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err) b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err)
sendRST() return nil
return
}
conn, ok := getConn()
if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
backConn.Close()
return
} }
defer conn.Close()
defer backConn.Close() defer backConn.Close()
if sni := tcph.TerminateTLS(); sni != "" { if sni := tcph.TerminateTLS(); sni != "" {
conn = tls.Server(conn, &tls.Config{ conn = tls.Server(conn, &tls.Config{
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
@ -381,7 +381,6 @@ func (b *LocalBackend) HandleInterceptedTCPConn(dport uint16, srcAddr netip.Addr
// TODO(bradfitz): do the RegisterIPPortIdentity and // TODO(bradfitz): do the RegisterIPPortIdentity and
// UnregisterIPPortIdentity stuff that netstack does // UnregisterIPPortIdentity stuff that netstack does
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
_, err := io.Copy(backConn, conn) _, err := io.Copy(backConn, conn)
@ -391,12 +390,12 @@ func (b *LocalBackend) HandleInterceptedTCPConn(dport uint16, srcAddr netip.Addr
_, err := io.Copy(conn, backConn) _, err := io.Copy(conn, backConn)
errc <- err errc <- err
}() }()
<-errc return <-errc
return }
} }
b.logf("closing TCP conn to port %v (from %v) with actionless TCPPortHandler", dport, srcAddr) b.logf("closing TCP conn to port %v (from %v) with actionless TCPPortHandler", dport, srcAddr)
sendRST() return nil
} }
func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) { func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) {

@ -537,10 +537,6 @@ func (ns *Impl) isLocalIP(ip netip.Addr) bool {
return ns.atomicIsLocalIPFunc.Load()(ip) return ns.atomicIsLocalIPFunc.Load()(ip)
} }
func (ns *Impl) processSSH() bool {
return ns.lb != nil && ns.lb.ShouldRunSSH()
}
func (ns *Impl) peerAPIPortAtomic(ip netip.Addr) *atomic.Uint32 { func (ns *Impl) peerAPIPortAtomic(ip netip.Addr) *atomic.Uint32 {
if ip.Is4() { if ip.Is4() {
return &ns.peerapiPort4Atomic return &ns.peerapiPort4Atomic
@ -840,7 +836,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// request until we're sure that the connection can be handled by this // request until we're sure that the connection can be handled by this
// endpoint. This function sets up the TCP connection and should be // endpoint. This function sets up the TCP connection and should be
// called immediately before a connection is handled. // called immediately before a connection is handled.
createConn := func(opts ...tcpip.SettableSocketOption) *gonet.TCPConn { getConnOrReset := func(opts ...tcpip.SettableSocketOption) *gonet.TCPConn {
ep, err := r.CreateEndpoint(&wq) ep, err := r.CreateEndpoint(&wq)
if err != nil { if err != nil {
ns.logf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err) ns.logf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err)
@ -879,7 +875,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// DNS // DNS
if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) { if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
c := createConn() c := getConnOrReset()
if c == nil { if c == nil {
return return
} }
@ -888,53 +884,13 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
} }
if ns.lb != nil { if ns.lb != nil {
if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) { handler, opts := ns.lb.TCPHandlerForDst(clientRemoteAddrPort, dstAddrPort)
// Use a higher keepalive idle time for SSH connections, as they are if handler != nil {
// typically long lived and idle connections are more likely to be c := getConnOrReset(opts...) // will send a RST if it fails
// intentional. Ideally we would turn this off entirely, but we can't
// tell the difference between a long lived connection that is idle
// vs a connection that is dead because the peer has gone away.
// We pick 72h as that is typically sufficient for a long weekend.
idle := tcpip.KeepaliveIdleOption(72 * time.Hour)
c := createConn(&idle)
if c == nil { if c == nil {
return return
} }
if err := ns.lb.HandleSSHConn(c); err != nil { handler(c)
ns.logf("ssh error: %v", err)
}
return
}
if port, ok := ns.lb.GetPeerAPIPort(dialIP); ok {
if reqDetails.LocalPort == port && ns.isLocalIP(dialIP) {
c := createConn()
if c == nil {
return
}
src := netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort)
dst := netip.AddrPortFrom(dialIP, port)
ns.lb.ServePeerAPIConnection(src, dst, c)
return
}
}
if reqDetails.LocalPort == 80 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
c := createConn()
if c == nil {
return
}
ns.lb.HandleQuad100Port80Conn(c)
return
}
if ns.lb.ShouldInterceptTCPPort(reqDetails.LocalPort) && ns.isLocalIP(dialIP) {
getTCPConn := func() (_ net.Conn, ok bool) {
c := createConn()
return c, c != nil
}
sendRST := func() {
r.Complete(true)
}
ns.lb.HandleInterceptedTCPConn(reqDetails.LocalPort, clientRemoteAddrPort, getTCPConn, sendRST)
return return
} }
} }
@ -946,7 +902,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
r.Complete(true) r.Complete(true)
return return
} }
c := createConn() // will send a RST if it fails c := getConnOrReset() // will send a RST if it fails
if c == nil { if c == nil {
return return
} }
@ -959,7 +915,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
} }
dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort)) dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort))
if !ns.forwardTCP(createConn, clientRemoteIP, &wq, dialAddr) { if !ns.forwardTCP(getConnOrReset, clientRemoteIP, &wq, dialAddr) {
r.Complete(true) // sends a RST r.Complete(true) // sends a RST
} }
} }

Loading…
Cancel
Save