diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index cda882c4c..425cb1641 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -71,7 +71,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/control/controlclient from tailscale.com/ipn/ipnlocal+ tailscale.com/derp from tailscale.com/derp/derphttp+ tailscale.com/derp/derphttp from tailscale.com/net/netcheck+ - tailscale.com/derp/derpmap from tailscale.com/cmd/tailscaled + tailscale.com/derp/derpmap from tailscale.com/cmd/tailscaled+ tailscale.com/disco from tailscale.com/derp+ tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/internal/deepprint from tailscale.com/ipn/ipnlocal+ @@ -89,6 +89,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/logtail/filch from tailscale.com/logpolicy tailscale.com/metrics from tailscale.com/derp tailscale.com/net/dnscache from tailscale.com/control/controlclient+ + tailscale.com/net/dnsfallback from tailscale.com/control/controlclient tailscale.com/net/flowtrack from tailscale.com/wgengine/filter+ 💣 tailscale.com/net/interfaces from tailscale.com/cmd/tailscaled+ tailscale.com/net/netcheck from tailscale.com/wgengine/magicsock diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index be04e143c..6aab70760 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -36,6 +36,7 @@ import ( "tailscale.com/health" "tailscale.com/log/logheap" "tailscale.com/net/dnscache" + "tailscale.com/net/dnsfallback" "tailscale.com/net/netns" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" @@ -126,16 +127,18 @@ func NewDirect(opts Options) (*Direct, error) { httpc := opts.HTTPTestClient if httpc == nil { dnsCache := &dnscache.Resolver{ - Forward: dnscache.Get().Forward, // use default cache's forwarder - UseLastGood: true, + Forward: dnscache.Get().Forward, // use default 
cache's forwarder + UseLastGood: true, + LookupIPFallback: dnsfallback.Lookup, } dialer := netns.NewDialer() tr := http.DefaultTransport.(*http.Transport).Clone() tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) + tr.TLSClientConfig = tlsdial.Config(serverURL.Host, tr.TLSClientConfig) tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache) + tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dnsCache, tr.TLSClientConfig) tr.ForceAttemptHTTP2 = true - tr.TLSClientConfig = tlsdial.Config(serverURL.Host, tr.TLSClientConfig) httpc = &http.Client{Transport: tr} } diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index a1a52107d..953a64c51 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -8,6 +8,8 @@ package dnscache import ( "context" + "crypto/tls" + "errors" "fmt" "log" "net" @@ -18,6 +20,7 @@ import ( "time" "golang.org/x/sync/singleflight" + "inet.af/netaddr" ) var single = &Resolver{ @@ -55,6 +58,10 @@ type Resolver struct { // If nil, net.DefaultResolver is used. Forward *net.Resolver + // LookupIPFallback optionally provides a backup DNS mechanism + // to use if Forward returns an error or no results. + LookupIPFallback func(ctx context.Context, host string) ([]netaddr.IP, error) + // TTL is how long to keep entries cached // // If zero, a default (currently 10 minutes) is used. 
@@ -198,6 +205,18 @@ func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, err error) { ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host)) defer cancel() ips, err := r.fwd().LookupIPAddr(ctx, host) + if (err != nil || len(ips) == 0) && r.LookupIPFallback != nil { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + var fips []netaddr.IP + fips, err = r.LookupIPFallback(ctx, host) + if err == nil { + ips = nil + for _, fip := range fips { + ips = append(ips, *fip.IPAddr()) + } + } + } if err != nil { return nil, nil, err } @@ -269,13 +288,33 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Dialer returns a wrapped DialContext func that uses the provided dnsCache. func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { - return func(ctx context.Context, network, address string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (retConn net.Conn, ret error) { host, port, err := net.SplitHostPort(address) if err != nil { // Bogus. But just let the real dialer return an error rather than // inventing a similar one. return fwd(ctx, network, address) } + defer func() { + // On any failure, assume our DNS is wrong and try our fallback, if any. 
+ if ret == nil || dnsCache.LookupIPFallback == nil { + return + } + ips, err := dnsCache.LookupIPFallback(ctx, host) + if err != nil { + // Return with original error + return + } + for _, ip := range ips { + dst := net.JoinHostPort(ip.String(), port) + if c, err := fwd(ctx, network, dst); err == nil { + retConn = c + ret = nil + return + } + } + }() + ip, ip6, err := dnsCache.LookupIP(ctx, host) if err != nil { return nil, fmt.Errorf("failed to resolve %q: %w", host, err) @@ -300,3 +339,63 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { return fwd(ctx, network, dst) } } + +var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake") + +// TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext. +// It returns a *tls.Conn type on success. +// On TLS cert validation failure, it can invoke a backup DNS resolution strategy. +func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc { + tcpDialer := Dialer(fwd, dnsCache) + return func(ctx context.Context, network, address string) (net.Conn, error) { + host, _, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + tcpConn, err := tcpDialer(ctx, network, address) + if err != nil { + return nil, err + } + + cfg := cloneTLSConfig(tlsConfigBase) + if cfg.ServerName == "" { + cfg.ServerName = host + } + tlsConn := tls.Client(tcpConn, cfg) + + errc := make(chan error, 2) + handshakeCtx, handshakeTimeoutCancel := context.WithTimeout(ctx, 5*time.Second) + defer handshakeTimeoutCancel() + done := make(chan bool) + defer close(done) + go func() { + select { + case <-done: + case <-handshakeCtx.Done(): + errc <- errTLSHandshakeTimeout + } + }() + go func() { + err := tlsConn.Handshake() + // Publish the result before canceling: canceling first wakes the watcher goroutine above, which could enqueue errTLSHandshakeTimeout ahead of a successful handshake. + errc <- err + handshakeTimeoutCancel() + }() + if err := <-errc; err != nil { + tcpConn.Close() + // TODO: if err != errTLSHandshakeTimeout, + // assume it might be some captive portal or + // otherwise 
incorrect DNS and try the backup + // DNS mechanism. + return nil, err + } + return tlsConn, nil + } +} + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} diff --git a/net/dnsfallback/dnsfallback.go b/net/dnsfallback/dnsfallback.go new file mode 100644 index 000000000..9039dee1e --- /dev/null +++ b/net/dnsfallback/dnsfallback.go @@ -0,0 +1,106 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dnsfallback contains a DNS fallback mechanism +// for starting up Tailscale when the system DNS is broken or otherwise unavailable. +package dnsfallback + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "math/rand" + "net" + "net/http" + "net/url" + "time" + + "inet.af/netaddr" + "tailscale.com/derp/derpmap" + "tailscale.com/net/netns" + "tailscale.com/net/tshttpproxy" +) + +// Lookup resolves host by querying the DERP nodes listed in derpmap.Prod(), in random order, via bootstrapDNSMap. +// It returns the first non-empty answer for host, or an error if ctx is done or no candidate yields one. +func Lookup(ctx context.Context, host string) ([]netaddr.IP, error) { + type nameIP struct { + dnsName string + ip netaddr.IP + } + + var cands []nameIP + dm := derpmap.Prod() + for _, dr := range dm.Regions { + for _, n := range dr.Nodes { + if ip, err := netaddr.ParseIP(n.IPv4); err == nil { + cands = append(cands, nameIP{n.HostName, ip}) + } + if ip, err := netaddr.ParseIP(n.IPv6); err == nil { + cands = append(cands, nameIP{n.HostName, ip}) + } + } + } + rand.Shuffle(len(cands), func(i, j int) { + cands[i], cands[j] = cands[j], cands[i] + }) + if len(cands) == 0 { + return nil, fmt.Errorf("no DNS fallback options for %q", host) + } + for ctx.Err() == nil && len(cands) > 0 { + cand := cands[0] + cands = cands[1:] // always advance; otherwise a failing or empty candidate is retried until ctx expires + log.Printf("trying bootstrapDNS(%q, %q) for %q ...", cand.dnsName, cand.ip, host) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + dm, err := bootstrapDNSMap(ctx, cand.dnsName, cand.ip, host) + if err != nil { + log.Printf("bootstrapDNS(%q, %q) for %q 
error: %v", cand.dnsName, cand.ip, host, err) + continue + } + if ips := dm[host]; len(ips) > 0 { + log.Printf("bootstrapDNS(%q, %q) for %q = %v", cand.dnsName, cand.ip, host, ips) + return ips, nil + } + } + if err := ctx.Err(); err != nil { + return nil, err + } + return nil, fmt.Errorf("no DNS fallback candidates remain for %q", host) +} + +// serverName and serverIP are those of, say, "derpN.tailscale.com". +// queryName is the name being sought (e.g. "login.tailscale.com"), passed as hint. +func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netaddr.IP, queryName string) (dnsMap, error) { + dialer := netns.NewDialer() + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Proxy = tshttpproxy.ProxyFromEnvironment + tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443")) + } + c := &http.Client{Transport: tr} + req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil) + if err != nil { + return nil, err + } + dm := make(dnsMap) + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != 200 { + return nil, errors.New(res.Status) + } + if err := json.NewDecoder(res.Body).Decode(&dm); err != nil { + return nil, err + } + return dm, nil +} + +// dnsMap is the JSON type returned by the DERP /bootstrap-dns handler: +// https://derp10.tailscale.com/bootstrap-dns +type dnsMap map[string][]netaddr.IP