net/tsdial: add SystemDial as a wrapper on netns.Dial

The connections returned from SystemDial are automatically closed when
there is a major link change.

Also plumb through the dialer to the noise client so that connections
are auto-reset when moving from cellular to WiFi etc.

Updates #3363

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/4555/head
Maisem Ali 3 years ago committed by Maisem Ali
parent e38d3dfc76
commit 5a1ef1bbb9

@ -332,6 +332,7 @@ func run() error {
socksListener, httpProxyListener := mustStartProxyListeners(args.socksAddr, args.httpProxyAddr) socksListener, httpProxyListener := mustStartProxyListeners(args.socksAddr, args.httpProxyAddr)
dialer := new(tsdial.Dialer) // mutated below (before used) dialer := new(tsdial.Dialer) // mutated below (before used)
dialer.Logf = logf
e, useNetstack, err := createEngine(logf, linkMon, dialer) e, useNetstack, err := createEngine(logf, linkMon, dialer)
if err != nil { if err != nil {
return fmt.Errorf("createEngine: %w", err) return fmt.Errorf("createEngine: %w", err)
@ -394,6 +395,7 @@ func run() error {
// want to keep running. // want to keep running.
signal.Ignore(syscall.SIGPIPE) signal.Ignore(syscall.SIGPIPE)
go func() { go func() {
defer dialer.Close()
select { select {
case s := <-interrupt: case s := <-interrupt:
logf("tailscaled got signal %v; shutting down", s) logf("tailscaled got signal %v; shutting down", s)

@ -38,9 +38,9 @@ import (
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/dnsfallback" "tailscale.com/net/dnsfallback"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
"tailscale.com/net/netns"
"tailscale.com/net/netutil" "tailscale.com/net/netutil"
"tailscale.com/net/tlsdial" "tailscale.com/net/tlsdial"
"tailscale.com/net/tsdial"
"tailscale.com/net/tshttpproxy" "tailscale.com/net/tshttpproxy"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -57,7 +57,8 @@ import (
// Direct is the client that connects to a tailcontrol server for a node. // Direct is the client that connects to a tailcontrol server for a node.
type Direct struct { type Direct struct {
httpc *http.Client // HTTP client used to talk to tailcontrol httpc *http.Client // HTTP client used to talk to tailcontrol
serverURL string // URL of the tailcontrol server dialer *tsdial.Dialer
serverURL string // URL of the tailcontrol server
timeNow func() time.Time timeNow func() time.Time
lastPrintMap time.Time lastPrintMap time.Time
newDecompressor func() (Decompressor, error) newDecompressor func() (Decompressor, error)
@ -106,6 +107,7 @@ type Options struct {
DebugFlags []string // debug settings to send to control DebugFlags []string // debug settings to send to control
LinkMonitor *monitor.Mon // optional link monitor LinkMonitor *monitor.Mon // optional link monitor
PopBrowserURL func(url string) // optional func to open browser PopBrowserURL func(url string) // optional func to open browser
Dialer *tsdial.Dialer // non-nil
// KeepSharerAndUserSplit controls whether the client // KeepSharerAndUserSplit controls whether the client
// understands Node.Sharer. If false, the Sharer is mapped to the User. // understands Node.Sharer. If false, the Sharer is mapped to the User.
@ -170,13 +172,12 @@ func NewDirect(opts Options) (*Direct, error) {
UseLastGood: true, UseLastGood: true,
LookupIPFallback: dnsfallback.Lookup, LookupIPFallback: dnsfallback.Lookup,
} }
dialer := netns.NewDialer(opts.Logf)
tr := http.DefaultTransport.(*http.Transport).Clone() tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment tr.Proxy = tshttpproxy.ProxyFromEnvironment
tshttpproxy.SetTransportGetProxyConnectHeader(tr) tshttpproxy.SetTransportGetProxyConnectHeader(tr)
tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), tr.TLSClientConfig) tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), tr.TLSClientConfig)
tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache) tr.DialContext = dnscache.Dialer(opts.Dialer.SystemDial, dnsCache)
tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dnsCache, tr.TLSClientConfig) tr.DialTLSContext = dnscache.TLSDialer(opts.Dialer.SystemDial, dnsCache, tr.TLSClientConfig)
tr.ForceAttemptHTTP2 = true tr.ForceAttemptHTTP2 = true
// Disable implicit gzip compression; the various // Disable implicit gzip compression; the various
// handlers (register, map, set-dns, etc) do their own // handlers (register, map, set-dns, etc) do their own
@ -202,6 +203,7 @@ func NewDirect(opts Options) (*Direct, error) {
skipIPForwardingCheck: opts.SkipIPForwardingCheck, skipIPForwardingCheck: opts.SkipIPForwardingCheck,
pinger: opts.Pinger, pinger: opts.Pinger,
popBrowser: opts.PopBrowserURL, popBrowser: opts.PopBrowserURL,
dialer: opts.Dialer,
} }
if opts.Hostinfo == nil { if opts.Hostinfo == nil {
c.SetHostinfo(hostinfo.New()) c.SetHostinfo(hostinfo.New())
@ -1278,7 +1280,7 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) {
return nil, err return nil, err
} }
nc, err = newNoiseClient(k, serverNoiseKey, c.serverURL) nc, err = newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -14,6 +14,7 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/hostinfo" "tailscale.com/hostinfo"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -30,6 +31,7 @@ func TestNewDirect(t *testing.T) {
GetMachinePrivateKey: func() (key.MachinePrivate, error) { GetMachinePrivateKey: func() (key.MachinePrivate, error) {
return k, nil return k, nil
}, },
Dialer: new(tsdial.Dialer),
} }
c, err := NewDirect(opts) c, err := NewDirect(opts)
if err != nil { if err != nil {
@ -106,6 +108,7 @@ func TestTsmpPing(t *testing.T) {
GetMachinePrivateKey: func() (key.MachinePrivate, error) { GetMachinePrivateKey: func() (key.MachinePrivate, error) {
return k, nil return k, nil
}, },
Dialer: new(tsdial.Dialer),
} }
c, err := NewDirect(opts) c, err := NewDirect(opts)

@ -18,6 +18,7 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp" "tailscale.com/control/controlhttp"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/util/mak" "tailscale.com/util/mak"
@ -46,6 +47,7 @@ func (c *noiseConn) Close() error {
// the ts2021 protocol. // the ts2021 protocol.
type noiseClient struct { type noiseClient struct {
*http.Client // HTTP client used to talk to tailcontrol *http.Client // HTTP client used to talk to tailcontrol
dialer *tsdial.Dialer
privKey key.MachinePrivate privKey key.MachinePrivate
serverPubKey key.MachinePublic serverPubKey key.MachinePublic
serverHost string // the host:port part of serverURL serverHost string // the host:port part of serverURL
@ -58,7 +60,7 @@ type noiseClient struct {
// newNoiseClient returns a new noiseClient for the provided server and machine key. // newNoiseClient returns a new noiseClient for the provided server and machine key.
// serverURL is of the form https://<host>:<port> (no trailing slash). // serverURL is of the form https://<host>:<port> (no trailing slash).
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string) (*noiseClient, error) { func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer) (*noiseClient, error) {
u, err := url.Parse(serverURL) u, err := url.Parse(serverURL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,6 +77,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s
serverPubKey: serverPubKey, serverPubKey: serverPubKey,
privKey: priKey, privKey: priKey,
serverHost: host, serverHost: host,
dialer: dialer,
} }
// Create the HTTP/2 Transport using a net/http.Transport // Create the HTTP/2 Transport using a net/http.Transport
@ -151,7 +154,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
// thousand version numbers before getting to this point. // thousand version numbers before getting to this point.
panic("capability version is too high to fit in the wire protocol") panic("capability version is too high to fit in the wire protocol")
} }
conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion)) conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion), nc.dialer.SystemDial)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -25,7 +25,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -35,7 +34,6 @@ import (
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/dnsfallback" "tailscale.com/net/dnsfallback"
"tailscale.com/net/netns"
"tailscale.com/net/netutil" "tailscale.com/net/netutil"
"tailscale.com/net/tlsdial" "tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy" "tailscale.com/net/tshttpproxy"
@ -66,7 +64,7 @@ const (
// //
// The provided ctx is only used for the initial connection, until // The provided ctx is only used for the initial connection, until
// Dial returns. It does not affect the connection once established. // Dial returns. It does not affect the connection once established.
func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*controlbase.Conn, error) { func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -80,6 +78,7 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
controlKey: controlKey, controlKey: controlKey,
version: protocolVersion, version: protocolVersion,
proxyFunc: tshttpproxy.ProxyFromEnvironment, proxyFunc: tshttpproxy.ProxyFromEnvironment,
dialer: dialer,
} }
return a.dial() return a.dial()
} }
@ -93,6 +92,7 @@ type dialParams struct {
controlKey key.MachinePublic controlKey key.MachinePublic
version uint16 version uint16
proxyFunc func(*http.Request) (*url.URL, error) // or nil proxyFunc func(*http.Request) (*url.URL, error) // or nil
dialer dnscache.DialContextFunc
// For tests only // For tests only
insecureTLS bool insecureTLS bool
@ -196,12 +196,11 @@ func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.C
LookupIPFallback: dnsfallback.Lookup, LookupIPFallback: dnsfallback.Lookup,
UseLastGood: true, UseLastGood: true,
} }
dialer := netns.NewDialer(log.Printf)
tr := http.DefaultTransport.(*http.Transport).Clone() tr := http.DefaultTransport.(*http.Transport).Clone()
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
tr.Proxy = a.proxyFunc tr.Proxy = a.proxyFunc
tshttpproxy.SetTransportGetProxyConnectHeader(tr) tshttpproxy.SetTransportGetProxyConnectHeader(tr)
tr.DialContext = dnscache.Dialer(dialer.DialContext, dns) tr.DialContext = dnscache.Dialer(a.dialer, dns)
// Disable HTTP2, since h2 can't do protocol switching. // Disable HTTP2, since h2 can't do protocol switching.
tr.TLSClientConfig.NextProtos = []string{} tr.TLSClientConfig.NextProtos = []string{}
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
@ -210,7 +209,7 @@ func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.C
tr.TLSClientConfig.InsecureSkipVerify = true tr.TLSClientConfig.InsecureSkipVerify = true
tr.TLSClientConfig.VerifyConnection = nil tr.TLSClientConfig.VerifyConnection = nil
} }
tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dns, tr.TLSClientConfig) tr.DialTLSContext = dnscache.TLSDialer(a.dialer, dns, tr.TLSClientConfig)
tr.DisableCompression = true tr.DisableCompression = true
// (mis)use httptrace to extract the underlying net.Conn from the // (mis)use httptrace to extract the underlying net.Conn from the

@ -20,6 +20,7 @@ import (
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/net/socks5" "tailscale.com/net/socks5"
"tailscale.com/net/tsdial"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@ -155,6 +156,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
controlKey: server.Public(), controlKey: server.Public(),
version: testProtocolVersion, version: testProtocolVersion,
insecureTLS: true, insecureTLS: true,
dialer: new(tsdial.Dialer).SystemDial,
} }
if proxy != nil { if proxy != nil {

@ -1036,6 +1036,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
LinkMonitor: b.e.GetLinkMonitor(), LinkMonitor: b.e.GetLinkMonitor(),
Pinger: b.e, Pinger: b.e,
PopBrowserURL: b.tellClientToBrowseToURL, PopBrowserURL: b.tellClientToBrowseToURL,
Dialer: b.Dialer(),
// Don't warn about broken Linux IP forwarding when // Don't warn about broken Linux IP forwarding when
// netstack is being used. // netstack is being used.

@ -20,8 +20,12 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/interfaces"
"tailscale.com/net/netknob" "tailscale.com/net/netknob"
"tailscale.com/net/netns"
"tailscale.com/types/logger"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/util/mak"
"tailscale.com/wgengine/monitor" "tailscale.com/wgengine/monitor"
) )
@ -30,6 +34,7 @@ import (
// (TUN, netstack), the OS network sandboxing style (macOS/iOS // (TUN, netstack), the OS network sandboxing style (macOS/iOS
// Extension, none), user-selected route acceptance prefs, etc. // Extension, none), user-selected route acceptance prefs, etc.
type Dialer struct { type Dialer struct {
Logf logger.Logf
// UseNetstackForIP if non-nil is whether NetstackDialTCP (if // UseNetstackForIP if non-nil is whether NetstackDialTCP (if
// it's non-nil) should be used to dial the provided IP. // it's non-nil) should be used to dial the provided IP.
UseNetstackForIP func(netaddr.IP) bool UseNetstackForIP func(netaddr.IP) bool
@ -46,12 +51,33 @@ type Dialer struct {
peerDialerOnce sync.Once peerDialerOnce sync.Once
peerDialer *net.Dialer peerDialer *net.Dialer
mu sync.Mutex netnsDialerOnce sync.Once
dns dnsMap netnsDialer netns.Dialer
tunName string // tun device name
linkMon *monitor.Mon mu sync.Mutex
exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?') closed bool
dnsCache *dnscache.MessageCache // nil until first first non-empty SetExitDNSDoH dns dnsMap
tunName string // tun device name
linkMon *monitor.Mon
linkMonUnregister func()
exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?')
dnsCache *dnscache.MessageCache // nil until first first non-empty SetExitDNSDoH
nextSysConnID int
activeSysConns map[int]net.Conn // active connections not yet closed
}
// sysConn wraps a net.Conn that was created using d.SystemDial.
// It exists to track which connections are still open, and should be
// closed on major link changes.
type sysConn struct {
net.Conn
id int
d *Dialer
}
func (c sysConn) Close() error {
c.d.closeSysConn(c.id)
return nil
} }
// SetTUNName sets the name of the tun device in use ("tailscale0", "utun6", // SetTUNName sets the name of the tun device in use ("tailscale0", "utun6",
@ -91,10 +117,53 @@ func (d *Dialer) SetExitDNSDoH(doh string) {
} }
} }
func (d *Dialer) Close() error {
d.mu.Lock()
defer d.mu.Unlock()
d.closed = true
if d.linkMonUnregister != nil {
d.linkMonUnregister()
d.linkMonUnregister = nil
}
for _, c := range d.activeSysConns {
c.Close()
}
d.activeSysConns = nil
return nil
}
func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) { func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if d.linkMonUnregister != nil {
go d.linkMonUnregister()
d.linkMonUnregister = nil
}
d.linkMon = mon d.linkMon = mon
d.linkMonUnregister = d.linkMon.RegisterChangeCallback(d.linkChanged)
}
func (d *Dialer) linkChanged(major bool, state *interfaces.State) {
if !major {
return
}
d.mu.Lock()
defer d.mu.Unlock()
for id, c := range d.activeSysConns {
go c.Close()
delete(d.activeSysConns, id)
}
}
func (d *Dialer) closeSysConn(id int) {
d.mu.Lock()
defer d.mu.Unlock()
c, ok := d.activeSysConns[id]
if !ok {
return
}
delete(d.activeSysConns, id)
go c.Close() // ignore the error
} }
func (d *Dialer) interfaceIndexLocked(ifName string) (index int, ok bool) { func (d *Dialer) interfaceIndexLocked(ifName string) (index int, ok bool) {
@ -197,6 +266,42 @@ func ipNetOfNetwork(n string) string {
return "ip" return "ip"
} }
// SystemDial connects to the provided network address without going over
// Tailscale. It prefers going over the default interface and closes existing
// connections if the default interface changes. It is used to connect to
// Control and (in the future, as of 2022-04-27) DERPs..
func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn, error) {
d.mu.Lock()
closed := d.closed
d.mu.Unlock()
if closed {
return nil, net.ErrClosed
}
d.netnsDialerOnce.Do(func() {
logf := d.Logf
if logf == nil {
logf = logger.Discard
}
d.netnsDialer = netns.NewDialer(logf)
})
c, err := d.netnsDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
d.mu.Lock()
defer d.mu.Unlock()
id := d.nextSysConnID
d.nextSysConnID++
mak.Set(&d.activeSysConns, id, c)
return sysConn{
id: id,
d: d,
Conn: c,
}, nil
}
// UserDial connects to the provided network address as if a user were initiating the dial. // UserDial connects to the provided network address as if a user were initiating the dial.
// (e.g. from a SOCKS or HTTP outbound proxy) // (e.g. from a SOCKS or HTTP outbound proxy)
func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) { func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) {

@ -105,6 +105,7 @@ func (s *Server) Close() error {
s.shutdownCancel() s.shutdownCancel()
s.lb.Shutdown() s.lb.Shutdown()
s.linkMon.Close() s.linkMon.Close()
s.dialer.Close()
s.localAPIListener.Close() s.localAPIListener.Close()
s.mu.Lock() s.mu.Lock()

Loading…
Cancel
Save