diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 94ee56085..0570f07db 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -332,6 +332,7 @@ func run() error { socksListener, httpProxyListener := mustStartProxyListeners(args.socksAddr, args.httpProxyAddr) dialer := new(tsdial.Dialer) // mutated below (before used) + dialer.Logf = logf e, useNetstack, err := createEngine(logf, linkMon, dialer) if err != nil { return fmt.Errorf("createEngine: %w", err) @@ -394,6 +395,7 @@ func run() error { // want to keep running. signal.Ignore(syscall.SIGPIPE) go func() { + defer dialer.Close() select { case s := <-interrupt: logf("tailscaled got signal %v; shutting down", s) diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index c4535a7f7..3a377f7f9 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -38,9 +38,9 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" "tailscale.com/net/interfaces" - "tailscale.com/net/netns" "tailscale.com/net/netutil" "tailscale.com/net/tlsdial" + "tailscale.com/net/tsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -57,7 +57,8 @@ import ( // Direct is the client that connects to a tailcontrol server for a node. type Direct struct { httpc *http.Client // HTTP client used to talk to tailcontrol - serverURL string // URL of the tailcontrol server + dialer *tsdial.Dialer + serverURL string // URL of the tailcontrol server timeNow func() time.Time lastPrintMap time.Time newDecompressor func() (Decompressor, error) @@ -106,6 +107,7 @@ type Options struct { DebugFlags []string // debug settings to send to control LinkMonitor *monitor.Mon // optional link monitor PopBrowserURL func(url string) // optional func to open browser + Dialer *tsdial.Dialer // non-nil // KeepSharerAndUserSplit controls whether the client // understands Node.Sharer. If false, the Sharer is mapped to the User. @@ -170,13 +172,12 @@ func NewDirect(opts Options) (*Direct, error) { UseLastGood: true, LookupIPFallback: dnsfallback.Lookup, } - dialer := netns.NewDialer(opts.Logf) tr := http.DefaultTransport.(*http.Transport).Clone() tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), tr.TLSClientConfig) - tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache) - tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dnsCache, tr.TLSClientConfig) + tr.DialContext = dnscache.Dialer(opts.Dialer.SystemDial, dnsCache) + tr.DialTLSContext = dnscache.TLSDialer(opts.Dialer.SystemDial, dnsCache, tr.TLSClientConfig) tr.ForceAttemptHTTP2 = true // Disable implicit gzip compression; the various // handlers (register, map, set-dns, etc) do their own @@ -202,6 +203,7 @@ func NewDirect(opts Options) (*Direct, error) { skipIPForwardingCheck: opts.SkipIPForwardingCheck, pinger: opts.Pinger, popBrowser: opts.PopBrowserURL, + dialer: opts.Dialer, } if opts.Hostinfo == nil { c.SetHostinfo(hostinfo.New()) @@ -1278,7 +1280,7 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) { return nil, err } - nc, err = newNoiseClient(k, serverNoiseKey, c.serverURL) + nc, err = newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer) if err != nil { return nil, err } diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index d1e678ed1..0063eec17 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -14,6 +14,7 @@ import ( "inet.af/netaddr" "tailscale.com/hostinfo" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -30,6 +31,7 @@ func TestNewDirect(t *testing.T) { GetMachinePrivateKey: func() (key.MachinePrivate, error) { return k, nil }, + Dialer: new(tsdial.Dialer), } c, err := NewDirect(opts) if err != nil { @@ -106,6 +108,7 @@ func TestTsmpPing(t *testing.T) { GetMachinePrivateKey: func() (key.MachinePrivate, error) { return k, nil }, + Dialer: new(tsdial.Dialer), } c, err := NewDirect(opts) diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 54c1e0ce6..8fc0d714e 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -18,6 +18,7 @@ import ( "golang.org/x/net/http2" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/util/mak" @@ -46,6 +47,7 @@ func (c *noiseConn) Close() error { // the ts2021 protocol. type noiseClient struct { *http.Client // HTTP client used to talk to tailcontrol + dialer *tsdial.Dialer privKey key.MachinePrivate serverPubKey key.MachinePublic serverHost string // the host:port part of serverURL @@ -58,7 +60,7 @@ type noiseClient struct { // newNoiseClient returns a new noiseClient for the provided server and machine key. // serverURL is of the form https://: (no trailing slash). -func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string) (*noiseClient, error) { +func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer) (*noiseClient, error) { u, err := url.Parse(serverURL) if err != nil { return nil, err @@ -75,6 +77,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s serverPubKey: serverPubKey, privKey: priKey, serverHost: host, + dialer: dialer, } // Create the HTTP/2 Transport using a net/http.Transport @@ -151,7 +154,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { // thousand version numbers before getting to this point. panic("capability version is too high to fit in the wire protocol") } - conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion)) + conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion), nc.dialer.SystemDial) if err != nil { return nil, err } diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 4273eeb0f..ca407c101 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -25,7 +25,6 @@ import ( "errors" "fmt" "io" - "log" "net" "net/http" "net/http/httptrace" @@ -35,7 +34,6 @@ import ( "tailscale.com/control/controlbase" "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" - "tailscale.com/net/netns" "tailscale.com/net/netutil" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" @@ -66,7 +64,7 @@ const ( // // The provided ctx is only used for the initial connection, until // Dial returns. It does not affect the connection once established. -func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*controlbase.Conn, error) { +func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -80,6 +78,7 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr controlKey: controlKey, version: protocolVersion, proxyFunc: tshttpproxy.ProxyFromEnvironment, + dialer: dialer, } return a.dial() } @@ -93,6 +92,7 @@ type dialParams struct { controlKey key.MachinePublic version uint16 proxyFunc func(*http.Request) (*url.URL, error) // or nil + dialer dnscache.DialContextFunc // For tests only insecureTLS bool @@ -196,12 +196,11 @@ func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.C LookupIPFallback: dnsfallback.Lookup, UseLastGood: true, } - dialer := netns.NewDialer(log.Printf) tr := http.DefaultTransport.(*http.Transport).Clone() defer tr.CloseIdleConnections() tr.Proxy = a.proxyFunc tshttpproxy.SetTransportGetProxyConnectHeader(tr) - tr.DialContext = dnscache.Dialer(dialer.DialContext, dns) + tr.DialContext = dnscache.Dialer(a.dialer, dns) // Disable HTTP2, since h2 can't do protocol switching. tr.TLSClientConfig.NextProtos = []string{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} @@ -210,7 +209,7 @@ func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.C tr.TLSClientConfig.InsecureSkipVerify = true tr.TLSClientConfig.VerifyConnection = nil } - tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dns, tr.TLSClientConfig) + tr.DialTLSContext = dnscache.TLSDialer(a.dialer, dns, tr.TLSClientConfig) tr.DisableCompression = true // (mis)use httptrace to extract the underlying net.Conn from the diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 1d2adf124..5f942c895 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -20,6 +20,7 @@ import ( "tailscale.com/control/controlbase" "tailscale.com/net/socks5" + "tailscale.com/net/tsdial" "tailscale.com/types/key" ) @@ -155,6 +156,7 @@ func testControlHTTP(t *testing.T, proxy proxy) { controlKey: server.Public(), version: testProtocolVersion, insecureTLS: true, + dialer: new(tsdial.Dialer).SystemDial, } if proxy != nil { diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 000f0d91c..60e2ecbfa 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1034,6 +1034,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { LinkMonitor: b.e.GetLinkMonitor(), Pinger: b.e, PopBrowserURL: b.tellClientToBrowseToURL, + Dialer: b.Dialer(), // Don't warn about broken Linux IP forwarding when // netstack is being used. diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 961567232..0b2468c1c 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -20,8 +20,12 @@ import ( "inet.af/netaddr" "tailscale.com/net/dnscache" + "tailscale.com/net/interfaces" "tailscale.com/net/netknob" + "tailscale.com/net/netns" + "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/util/mak" "tailscale.com/wgengine/monitor" ) @@ -30,6 +34,7 @@ import ( // (TUN, netstack), the OS network sandboxing style (macOS/iOS // Extension, none), user-selected route acceptance prefs, etc. type Dialer struct { + Logf logger.Logf // UseNetstackForIP if non-nil is whether NetstackDialTCP (if // it's non-nil) should be used to dial the provided IP. UseNetstackForIP func(netaddr.IP) bool @@ -46,12 +51,33 @@ type Dialer struct { peerDialerOnce sync.Once peerDialer *net.Dialer - mu sync.Mutex - dns dnsMap - tunName string // tun device name - linkMon *monitor.Mon - exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?') - dnsCache *dnscache.MessageCache // nil until first first non-empty SetExitDNSDoH + netnsDialerOnce sync.Once + netnsDialer netns.Dialer + + mu sync.Mutex + closed bool + dns dnsMap + tunName string // tun device name + linkMon *monitor.Mon + linkMonUnregister func() + exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?') + dnsCache *dnscache.MessageCache // nil until first first non-empty SetExitDNSDoH + nextSysConnID int + activeSysConns map[int]net.Conn // active connections not yet closed +} + +// sysConn wraps a net.Conn that was created using d.SystemDial. +// It exists to track which connections are still open, and should be +// closed on major link changes. +type sysConn struct { + net.Conn + id int + d *Dialer +} + +func (c sysConn) Close() error { + c.d.closeSysConn(c.id) + return nil } // SetTUNName sets the name of the tun device in use ("tailscale0", "utun6", @@ -91,10 +117,53 @@ func (d *Dialer) SetExitDNSDoH(doh string) { } } +func (d *Dialer) Close() error { + d.mu.Lock() + defer d.mu.Unlock() + d.closed = true + if d.linkMonUnregister != nil { + d.linkMonUnregister() + d.linkMonUnregister = nil + } + for _, c := range d.activeSysConns { + c.Close() + } + d.activeSysConns = nil + return nil +} + func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) { d.mu.Lock() defer d.mu.Unlock() + if d.linkMonUnregister != nil { + go d.linkMonUnregister() + d.linkMonUnregister = nil + } d.linkMon = mon + d.linkMonUnregister = d.linkMon.RegisterChangeCallback(d.linkChanged) +} + +func (d *Dialer) linkChanged(major bool, state *interfaces.State) { + if !major { + return + } + d.mu.Lock() + defer d.mu.Unlock() + for id, c := range d.activeSysConns { + go c.Close() + delete(d.activeSysConns, id) + } +} + +func (d *Dialer) closeSysConn(id int) { + d.mu.Lock() + defer d.mu.Unlock() + c, ok := d.activeSysConns[id] + if !ok { + return + } + delete(d.activeSysConns, id) + go c.Close() // ignore the error } func (d *Dialer) interfaceIndexLocked(ifName string) (index int, ok bool) { @@ -197,6 +266,42 @@ func ipNetOfNetwork(n string) string { return "ip" } +// SystemDial connects to the provided network address without going over +// Tailscale. It prefers going over the default interface and closes existing +// connections if the default interface changes. It is used to connect to +// Control and (in the future, as of 2022-04-27) DERPs.. +func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn, error) { + d.mu.Lock() + closed := d.closed + d.mu.Unlock() + if closed { + return nil, net.ErrClosed + } + + d.netnsDialerOnce.Do(func() { + logf := d.Logf + if logf == nil { + logf = logger.Discard + } + d.netnsDialer = netns.NewDialer(logf) + }) + c, err := d.netnsDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + d.mu.Lock() + defer d.mu.Unlock() + id := d.nextSysConnID + d.nextSysConnID++ + mak.Set(&d.activeSysConns, id, c) + + return sysConn{ + id: id, + d: d, + Conn: c, + }, nil +} + // UserDial connects to the provided network address as if a user were initiating the dial. // (e.g. from a SOCKS or HTTP outbound proxy) func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) { diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index ae6b5c4b0..efbbb91ac 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -105,6 +105,7 @@ func (s *Server) Close() error { s.shutdownCancel() s.lb.Shutdown() s.linkMon.Close() + s.dialer.Close() s.localAPIListener.Close() s.mu.Lock()