ipn/ipnlocal,wgengine/netstack: move LocalBackend specifc serving logic to LocalBackend

The netstack code had a bunch of logic to figure out if the LocalBackend should handle an
incoming connection and then would call the function directly on LocalBackend. Move that
logic to LocalBackend and refactor the methods to return conn handlers.

Updates #cleanup

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/8332/head
Maisem Ali 11 months ago committed by Maisem Ali
parent 5b110685fb
commit fe95d81b43

@ -32,6 +32,7 @@ import (
"go4.org/mem"
"go4.org/netipx"
"golang.org/x/exp/slices"
"gvisor.dev/gvisor/pkg/tcpip"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/control/controlclient"
"tailscale.com/doctor"
@ -2828,14 +2829,14 @@ func (b *LocalBackend) GetPeerAPIPort(ip netip.Addr) (port uint16, ok bool) {
return 0, false
}
// ServePeerAPIConnection serves an already-accepted connection c.
// handlePeerAPIConn serves an already-accepted connection c.
//
// The remote parameter is the remote address.
// The local parameter is the local address (either a Tailscale IPv4
// or IPv6 IP and the peerapi port for that address).
//
// The connection will be closed by ServePeerAPIConnection.
func (b *LocalBackend) ServePeerAPIConnection(remote, local netip.AddrPort, c net.Conn) {
// The connection will be closed by handlePeerAPIConn.
func (b *LocalBackend) handlePeerAPIConn(remote, local netip.AddrPort, c net.Conn) {
b.mu.Lock()
defer b.mu.Unlock()
for _, pln := range b.peerAPIListeners {
@ -2849,6 +2850,48 @@ func (b *LocalBackend) ServePeerAPIConnection(remote, local netip.AddrPort, c ne
return
}
func (b *LocalBackend) isLocalIP(ip netip.Addr) bool {
nm := b.NetMap()
return nm != nil && slices.Contains(nm.Addresses, netip.PrefixFrom(ip, ip.BitLen()))
}
var (
magicDNSIP = tsaddr.TailscaleServiceIP()
magicDNSIPv6 = tsaddr.TailscaleServiceIPv6()
)
// TCPHandlerForDst returns a TCP handler for connections to dst, or nil if
// no handler is needed. It also returns a list of TCP socket options to
// apply to the socket before calling the handler.
func (b *LocalBackend) TCPHandlerForDst(src, dst netip.AddrPort) (handler func(c net.Conn) error, opts []tcpip.SettableSocketOption) {
if dst.Port() == 80 && (dst.Addr() == magicDNSIP || dst.Addr() == magicDNSIPv6) {
return b.HandleQuad100Port80Conn, opts
}
if !b.isLocalIP(dst.Addr()) {
return nil, nil
}
if dst.Port() == 22 && b.ShouldRunSSH() {
// Use a higher keepalive idle time for SSH connections, as they are
// typically long lived and idle connections are more likely to be
// intentional. Ideally we would turn this off entirely, but we can't
// tell the difference between a long lived connection that is idle
// vs a connection that is dead because the peer has gone away.
// We pick 72h as that is typically sufficient for a long weekend.
opts = append(opts, ptr.To(tcpip.KeepaliveIdleOption(72*time.Hour)))
return b.handleSSHConn, opts
}
if port, ok := b.GetPeerAPIPort(dst.Addr()); ok && dst.Port() == port {
return func(c net.Conn) error {
b.handlePeerAPIConn(src, dst, c)
return nil
}, opts
}
if handler := b.tcpHandlerForServe(dst.Port(), src); handler != nil {
return handler, opts
}
return nil, nil
}
func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) {
for _, pln := range b.peerAPIListeners {
proto := tailcfg.PeerAPI4
@ -4674,7 +4717,7 @@ func checkSELinux() {
}
}
func (b *LocalBackend) HandleSSHConn(c net.Conn) (err error) {
func (b *LocalBackend) handleSSHConn(c net.Conn) (err error) {
s, err := b.sshServerOrInit()
if err != nil {
return err
@ -4685,10 +4728,10 @@ func (b *LocalBackend) HandleSSHConn(c net.Conn) (err error) {
// HandleQuad100Port80Conn serves http://100.100.100.100/ on port 80 (and
// the equivalent tsaddr.TailscaleServiceIPv6 address).
func (b *LocalBackend) HandleQuad100Port80Conn(c net.Conn) {
func (b *LocalBackend) HandleQuad100Port80Conn(c net.Conn) error {
var s http.Server
s.Handler = http.HandlerFunc(b.handleQuad100Port80Conn)
s.Serve(netutil.NewOneConnListener(c, nil))
return s.Serve(netutil.NewOneConnListener(c, nil))
}
func validQuad100Host(h string) bool {

@ -780,7 +780,7 @@ func (h *peerAPIHandler) handleServeIngress(w http.ResponseWriter, r *http.Reque
return
}
getConn := func() (net.Conn, bool) {
getConnOrReset := func() (net.Conn, bool) {
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
h.logf("ingress: failed hijacking conn")
@ -798,7 +798,7 @@ func (h *peerAPIHandler) handleServeIngress(w http.ResponseWriter, r *http.Reque
http.Error(w, "denied", http.StatusForbidden)
}
h.ps.b.HandleIngressTCPConn(h.peerNode, target, srcAddr, getConn, sendRST)
h.ps.b.HandleIngressTCPConn(h.peerNode, target, srcAddr, getConnOrReset, sendRST)
}
func (h *peerAPIHandler) handleServeInterfaces(w http.ResponseWriter, r *http.Request) {

@ -162,12 +162,13 @@ func (s *serveListener) handleServeListenersAccept(ln net.Listener) error {
return err
}
srcAddr := conn.RemoteAddr().(*net.TCPAddr).AddrPort()
getConn := func() (net.Conn, bool) { return conn, true }
sendRST := func() {
handler := s.b.tcpHandlerForServe(s.ap.Port(), srcAddr)
if handler == nil {
s.b.logf("serve RST for %v", srcAddr)
conn.Close()
continue
}
go s.b.HandleInterceptedTCPConn(s.ap.Port(), srcAddr, getConn, sendRST)
go handler(conn)
}
}
@ -256,7 +257,7 @@ func (b *LocalBackend) ServeConfig() ipn.ServeConfigView {
return b.serveConfig
}
func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ipn.HostPort, srcAddr netip.AddrPort, getConn func() (net.Conn, bool), sendRST func()) {
func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ipn.HostPort, srcAddr netip.AddrPort, getConnOrReset func() (net.Conn, bool), sendRST func()) {
b.mu.Lock()
sc := b.serveConfig
b.mu.Unlock()
@ -289,7 +290,7 @@ func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ip
if b.getTCPHandlerForFunnelFlow != nil {
handler := b.getTCPHandlerForFunnelFlow(srcAddr, dport)
if handler != nil {
c, ok := getConn()
c, ok := getConnOrReset()
if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
return
@ -298,35 +299,40 @@ func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ip
return
}
}
// TODO(bradfitz): pass ingressPeer etc in context to HandleInterceptedTCPConn,
// TODO(bradfitz): pass ingressPeer etc in context to tcpHandlerForServe,
// extend serveHTTPContext or similar.
b.HandleInterceptedTCPConn(dport, srcAddr, getConn, sendRST)
handler := b.tcpHandlerForServe(dport, srcAddr)
if handler == nil {
sendRST()
return
}
c, ok := getConnOrReset()
if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
return
}
handler(c)
}
func (b *LocalBackend) HandleInterceptedTCPConn(dport uint16, srcAddr netip.AddrPort, getConn func() (net.Conn, bool), sendRST func()) {
// tcpHandlerForServe returns a handler for a TCP connection to be served via
// the ipn.ServeConfig.
func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) (handler func(net.Conn) error) {
b.mu.Lock()
sc := b.serveConfig
b.mu.Unlock()
if !sc.Valid() {
b.logf("[unexpected] localbackend: got TCP conn w/o serveConfig; from %v to port %v", srcAddr, dport)
sendRST()
return
return nil
}
tcph, ok := sc.TCP().GetOk(dport)
if !ok {
b.logf("[unexpected] localbackend: got TCP conn without TCP config for port %v; from %v", dport, srcAddr)
sendRST()
return
return nil
}
if tcph.HTTPS() {
conn, ok := getConn()
if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
return
}
hs := &http.Server{
TLSConfig: &tls.Config{
GetCertificate: b.getTLSServeCertForPort(dport),
@ -339,64 +345,57 @@ func (b *LocalBackend) HandleInterceptedTCPConn(dport uint16, srcAddr netip.Addr
})
},
}
hs.ServeTLS(netutil.NewOneConnListener(conn, nil), "", "")
return
return func(c net.Conn) error {
return hs.ServeTLS(netutil.NewOneConnListener(c, nil), "", "")
}
}
if backDst := tcph.TCPForward(); backDst != "" {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst)
cancel()
if err != nil {
b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err)
sendRST()
return
}
conn, ok := getConn()
if !ok {
b.logf("localbackend: getConn didn't complete from %v to port %v", srcAddr, dport)
backConn.Close()
return
}
defer conn.Close()
defer backConn.Close()
if sni := tcph.TerminateTLS(); sni != "" {
conn = tls.Server(conn, &tls.Config{
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
pair, err := b.GetCertPEM(ctx, sni)
if err != nil {
return nil, err
}
cert, err := tls.X509KeyPair(pair.CertPEM, pair.KeyPEM)
if err != nil {
return nil, err
}
return &cert, nil
},
})
}
return func(conn net.Conn) error {
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst)
cancel()
if err != nil {
b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err)
return nil
}
defer backConn.Close()
if sni := tcph.TerminateTLS(); sni != "" {
conn = tls.Server(conn, &tls.Config{
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
pair, err := b.GetCertPEM(ctx, sni)
if err != nil {
return nil, err
}
cert, err := tls.X509KeyPair(pair.CertPEM, pair.KeyPEM)
if err != nil {
return nil, err
}
return &cert, nil
},
})
}
// TODO(bradfitz): do the RegisterIPPortIdentity and
// UnregisterIPPortIdentity stuff that netstack does
errc := make(chan error, 1)
go func() {
_, err := io.Copy(backConn, conn)
errc <- err
}()
go func() {
_, err := io.Copy(conn, backConn)
errc <- err
}()
<-errc
return
// TODO(bradfitz): do the RegisterIPPortIdentity and
// UnregisterIPPortIdentity stuff that netstack does
errc := make(chan error, 1)
go func() {
_, err := io.Copy(backConn, conn)
errc <- err
}()
go func() {
_, err := io.Copy(conn, backConn)
errc <- err
}()
return <-errc
}
}
b.logf("closing TCP conn to port %v (from %v) with actionless TCPPortHandler", dport, srcAddr)
sendRST()
return nil
}
func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) {

@ -537,10 +537,6 @@ func (ns *Impl) isLocalIP(ip netip.Addr) bool {
return ns.atomicIsLocalIPFunc.Load()(ip)
}
func (ns *Impl) processSSH() bool {
return ns.lb != nil && ns.lb.ShouldRunSSH()
}
func (ns *Impl) peerAPIPortAtomic(ip netip.Addr) *atomic.Uint32 {
if ip.Is4() {
return &ns.peerapiPort4Atomic
@ -840,7 +836,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// request until we're sure that the connection can be handled by this
// endpoint. This function sets up the TCP connection and should be
// called immediately before a connection is handled.
createConn := func(opts ...tcpip.SettableSocketOption) *gonet.TCPConn {
getConnOrReset := func(opts ...tcpip.SettableSocketOption) *gonet.TCPConn {
ep, err := r.CreateEndpoint(&wq)
if err != nil {
ns.logf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err)
@ -879,7 +875,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// DNS
if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
c := createConn()
c := getConnOrReset()
if c == nil {
return
}
@ -888,53 +884,13 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
}
if ns.lb != nil {
if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) {
// Use a higher keepalive idle time for SSH connections, as they are
// typically long lived and idle connections are more likely to be
// intentional. Ideally we would turn this off entirely, but we can't
// tell the difference between a long lived connection that is idle
// vs a connection that is dead because the peer has gone away.
// We pick 72h as that is typically sufficient for a long weekend.
idle := tcpip.KeepaliveIdleOption(72 * time.Hour)
c := createConn(&idle)
if c == nil {
return
}
if err := ns.lb.HandleSSHConn(c); err != nil {
ns.logf("ssh error: %v", err)
}
return
}
if port, ok := ns.lb.GetPeerAPIPort(dialIP); ok {
if reqDetails.LocalPort == port && ns.isLocalIP(dialIP) {
c := createConn()
if c == nil {
return
}
src := netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort)
dst := netip.AddrPortFrom(dialIP, port)
ns.lb.ServePeerAPIConnection(src, dst, c)
return
}
}
if reqDetails.LocalPort == 80 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
c := createConn()
handler, opts := ns.lb.TCPHandlerForDst(clientRemoteAddrPort, dstAddrPort)
if handler != nil {
c := getConnOrReset(opts...) // will send a RST if it fails
if c == nil {
return
}
ns.lb.HandleQuad100Port80Conn(c)
return
}
if ns.lb.ShouldInterceptTCPPort(reqDetails.LocalPort) && ns.isLocalIP(dialIP) {
getTCPConn := func() (_ net.Conn, ok bool) {
c := createConn()
return c, c != nil
}
sendRST := func() {
r.Complete(true)
}
ns.lb.HandleInterceptedTCPConn(reqDetails.LocalPort, clientRemoteAddrPort, getTCPConn, sendRST)
handler(c)
return
}
}
@ -946,7 +902,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
r.Complete(true)
return
}
c := createConn() // will send a RST if it fails
c := getConnOrReset() // will send a RST if it fails
if c == nil {
return
}
@ -959,7 +915,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
}
dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort))
if !ns.forwardTCP(createConn, clientRemoteIP, &wq, dialAddr) {
if !ns.forwardTCP(getConnOrReset, clientRemoteIP, &wq, dialAddr) {
r.Complete(true) // sends a RST
}
}

Loading…
Cancel
Save