diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 895278662..be219a5b0 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -53,7 +53,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/log/logheap from tailscale.com/control/controlclient tailscale.com/logtail/backoff from tailscale.com/control/controlclient+ tailscale.com/metrics from tailscale.com/derp - tailscale.com/net/dnscache from tailscale.com/cmd/tailscale/cli+ + tailscale.com/net/dnscache from tailscale.com/control/controlclient+ 💣 tailscale.com/net/interfaces from tailscale.com/cmd/tailscale/cli+ tailscale.com/net/netcheck from tailscale.com/cmd/tailscale/cli+ 💣 tailscale.com/net/netns from tailscale.com/control/controlclient+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 43258fd04..ebe927e35 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -58,7 +58,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/logtail/backoff from tailscale.com/control/controlclient+ tailscale.com/logtail/filch from tailscale.com/logpolicy tailscale.com/metrics from tailscale.com/derp - tailscale.com/net/dnscache from tailscale.com/derp/derphttp+ + tailscale.com/net/dnscache from tailscale.com/control/controlclient+ 💣 tailscale.com/net/interfaces from tailscale.com/ipn+ tailscale.com/net/netcheck from tailscale.com/wgengine/magicsock 💣 tailscale.com/net/netns from tailscale.com/control/controlclient+ diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 3bd7208b8..7dce00901 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -36,6 +36,7 @@ import ( "golang.org/x/oauth2" "inet.af/netaddr" "tailscale.com/log/logheap" + "tailscale.com/net/dnscache" "tailscale.com/net/netns" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" @@ -172,11 +173,15 @@ func NewDirect(opts Options) (*Direct, error) { httpc := opts.HTTPTestClient if httpc == nil { + dnsCache := &dnscache.Resolver{ + Forward: dnscache.Get().Forward, // use default cache's forwarder + UseLastGood: true, + } dialer := netns.NewDialer() tr := http.DefaultTransport.(*http.Transport).Clone() tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) - tr.DialContext = dialer.DialContext + tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache) tr.ForceAttemptHTTP2 = true tr.TLSClientConfig = tlsdial.Config(serverURL.Host, tr.TLSClientConfig) httpc = &http.Client{Transport: tr} diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index fe9fc99ea..15497418c 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -9,8 +9,11 @@ package dnscache import ( "context" "fmt" + "log" "net" + "os" "runtime" + "strconv" "sync" "time" @@ -42,8 +45,6 @@ func preferGoResolver() bool { // Get returns a caching Resolver singleton. func Get() *Resolver { return single } -const fixedTTL = 10 * time.Minute - // Resolver is a minimal DNS caching resolver. // // The TTL is always fixed for now. It's not intended for general use. @@ -54,6 +55,15 @@ type Resolver struct { // If nil, net.DefaultResolver is used. Forward *net.Resolver + // TTL is how long to keep entries cached + // + // If zero, a default (currently 10 minutes) is used. + TTL time.Duration + + // UseLastGood controls whether a cached entry older than TTL is used + // if a refresh fails. + UseLastGood bool + sf singleflight.Group mu sync.Mutex @@ -72,16 +82,31 @@ func (r *Resolver) fwd() *net.Resolver { return net.DefaultResolver } +func (r *Resolver) ttl() time.Duration { + if r.TTL > 0 { + return r.TTL + } + return 10 * time.Minute +} + +var debug, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_DNS_CACHE")) + // LookupIP returns the first IPv4 address found, otherwise the first IPv6 address. func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) { if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { return ip4, nil } + if debug { + log.Printf("dnscache: %q is an IP", host) + } return ip, nil } if ip, ok := r.lookupIPCache(host); ok { + if debug { + log.Printf("dnscache: %q = %v (cached)", host, ip) + } return ip, nil } @@ -95,10 +120,24 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) { select { case res := <-ch: if res.Err != nil { + if r.UseLastGood { + if ip, ok := r.lookupIPCacheExpired(host); ok { + if debug { + log.Printf("dnscache: %q using %v after error", host, ip) + } + return ip, nil + } + } + if debug { + log.Printf("dnscache: error resolving %q: %v", host, res.Err) + } return nil, res.Err } return res.Val.(net.IP), nil case <-ctx.Done(): + if debug { + log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err()) + } return nil, ctx.Err() } } @@ -112,12 +151,41 @@ func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) { return nil, false } +func (r *Resolver) lookupIPCacheExpired(host string) (ip net.IP, ok bool) { + r.mu.Lock() + defer r.mu.Unlock() + if ent, ok := r.ipCache[host]; ok { + return ent.ip, true + } + return nil, false +} + +func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { + if r.UseLastGood { + if _, ok := r.lookupIPCacheExpired(host); ok { + // If we have some previous good value for this host, + // don't give this DNS lookup much time. If we're in a + // situation where the user's DNS server is unreachable + // (e.g. their corp DNS server is behind a subnet router + // that can't come up due to Tailscale needing to + // connect to itself), then we want to fail fast and let + // our caller (who set UseLastGood) fall back to using + // the last-known-good IP address. + return 3 * time.Second + } + } + return 10 * time.Second +} + func (r *Resolver) lookupIP(host string) (net.IP, error) { if ip, ok := r.lookupIPCache(host); ok { + if debug { + log.Printf("dnscache: %q found in cache as %v", host, ip) + } return ip, nil } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host)) defer cancel() ips, err := r.fwd().LookupIPAddr(ctx, host) if err != nil { @@ -129,19 +197,26 @@ func (r *Resolver) lookupIP(host string) (net.IP, error) { for _, ipa := range ips { if ip4 := ipa.IP.To4(); ip4 != nil { - return r.addIPCache(host, ip4, fixedTTL), nil + return r.addIPCache(host, ip4, r.ttl()), nil } } - return r.addIPCache(host, ips[0].IP, fixedTTL), nil + return r.addIPCache(host, ips[0].IP, r.ttl()), nil } func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP { if isPrivateIP(ip) { // Don't cache obviously wrong entries from captive portals. // TODO: use DoH or DoT for the forwarding resolver? + if debug { + log.Printf("dnscache: %q resolved to private IP %v; using but not caching", host, ip) + } return ip } + if debug { + log.Printf("dnscache: %q resolved to IP %v; caching", host, ip) + } + r.mu.Lock() defer r.mu.Unlock() if r.ipCache == nil { @@ -168,3 +243,26 @@ var ( private2 = mustCIDR("172.16.0.0/12") private3 = mustCIDR("192.168.0.0/16") ) + +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// Dialer returns a wrapped DialContext func that uses the provided dnsCache. +func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { + return func(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + // Bogus. But just let the real dialer return an error rather than + // inventing a similar one. + return fwd(ctx, network, address) + } + ip, err := dnsCache.LookupIP(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve %q: %w", host, err) + } + dst := net.JoinHostPort(ip.String(), port) + if debug { + log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) + } + return fwd(ctx, network, dst) + } +}