@ -9,8 +9,11 @@ package dnscache
import (
import (
"context"
"context"
"fmt"
"fmt"
"log"
"net"
"net"
"os"
"runtime"
"runtime"
"strconv"
"sync"
"sync"
"time"
"time"
@ -42,8 +45,6 @@ func preferGoResolver() bool {
// Get returns a caching Resolver singleton.
// Get returns a caching Resolver singleton.
func Get ( ) * Resolver { return single }
func Get ( ) * Resolver { return single }
const fixedTTL = 10 * time . Minute
// Resolver is a minimal DNS caching resolver.
// Resolver is a minimal DNS caching resolver.
//
//
// The TTL is always fixed for now. It's not intended for general use.
// The TTL is always fixed for now. It's not intended for general use.
@ -54,6 +55,15 @@ type Resolver struct {
// If nil, net.DefaultResolver is used.
// If nil, net.DefaultResolver is used.
Forward * net . Resolver
Forward * net . Resolver
// TTL is how long to keep entries cached
//
// If zero, a default (currently 10 minutes) is used.
TTL time . Duration
// UseLastGood controls whether a cached entry older than TTL is used
// if a refresh fails.
UseLastGood bool
sf singleflight . Group
sf singleflight . Group
mu sync . Mutex
mu sync . Mutex
@ -72,16 +82,31 @@ func (r *Resolver) fwd() *net.Resolver {
return net . DefaultResolver
return net . DefaultResolver
}
}
func ( r * Resolver ) ttl ( ) time . Duration {
if r . TTL > 0 {
return r . TTL
}
return 10 * time . Minute
}
var debug , _ = strconv . ParseBool ( os . Getenv ( "TS_DEBUG_DNS_CACHE" ) )
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
func ( r * Resolver ) LookupIP ( ctx context . Context , host string ) ( net . IP , error ) {
func ( r * Resolver ) LookupIP ( ctx context . Context , host string ) ( net . IP , error ) {
if ip := net . ParseIP ( host ) ; ip != nil {
if ip := net . ParseIP ( host ) ; ip != nil {
if ip4 := ip . To4 ( ) ; ip4 != nil {
if ip4 := ip . To4 ( ) ; ip4 != nil {
return ip4 , nil
return ip4 , nil
}
}
if debug {
log . Printf ( "dnscache: %q is an IP" , host )
}
return ip , nil
return ip , nil
}
}
if ip , ok := r . lookupIPCache ( host ) ; ok {
if ip , ok := r . lookupIPCache ( host ) ; ok {
if debug {
log . Printf ( "dnscache: %q = %v (cached)" , host , ip )
}
return ip , nil
return ip , nil
}
}
@ -95,10 +120,24 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
select {
select {
case res := <- ch :
case res := <- ch :
if res . Err != nil {
if res . Err != nil {
if r . UseLastGood {
if ip , ok := r . lookupIPCacheExpired ( host ) ; ok {
if debug {
log . Printf ( "dnscache: %q using %v after error" , host , ip )
}
return ip , nil
}
}
if debug {
log . Printf ( "dnscache: error resolving %q: %v" , host , res . Err )
}
return nil , res . Err
return nil , res . Err
}
}
return res . Val . ( net . IP ) , nil
return res . Val . ( net . IP ) , nil
case <- ctx . Done ( ) :
case <- ctx . Done ( ) :
if debug {
log . Printf ( "dnscache: context done while resolving %q: %v" , host , ctx . Err ( ) )
}
return nil , ctx . Err ( )
return nil , ctx . Err ( )
}
}
}
}
@ -112,12 +151,41 @@ func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) {
return nil , false
return nil , false
}
}
func ( r * Resolver ) lookupIPCacheExpired ( host string ) ( ip net . IP , ok bool ) {
r . mu . Lock ( )
defer r . mu . Unlock ( )
if ent , ok := r . ipCache [ host ] ; ok {
return ent . ip , true
}
return nil , false
}
func ( r * Resolver ) lookupTimeoutForHost ( host string ) time . Duration {
if r . UseLastGood {
if _ , ok := r . lookupIPCacheExpired ( host ) ; ok {
// If we have some previous good value for this host,
// don't give this DNS lookup much time. If we're in a
// situation where the user's DNS server is unreachable
// (e.g. their corp DNS server is behind a subnet router
// that can't come up due to Tailscale needing to
// connect to itself), then we want to fail fast and let
// our caller (who set UseLastGood) fall back to using
// the last-known-good IP address.
return 3 * time . Second
}
}
return 10 * time . Second
}
func ( r * Resolver ) lookupIP ( host string ) ( net . IP , error ) {
func ( r * Resolver ) lookupIP ( host string ) ( net . IP , error ) {
if ip , ok := r . lookupIPCache ( host ) ; ok {
if ip , ok := r . lookupIPCache ( host ) ; ok {
if debug {
log . Printf ( "dnscache: %q found in cache as %v" , host , ip )
}
return ip , nil
return ip , nil
}
}
ctx , cancel := context . WithTimeout ( context . Background ( ) , 10 * time . Second )
ctx , cancel := context . WithTimeout ( context . Background ( ) , r . lookupTimeoutForHost ( host ) )
defer cancel ( )
defer cancel ( )
ips , err := r . fwd ( ) . LookupIPAddr ( ctx , host )
ips , err := r . fwd ( ) . LookupIPAddr ( ctx , host )
if err != nil {
if err != nil {
@ -129,19 +197,26 @@ func (r *Resolver) lookupIP(host string) (net.IP, error) {
for _ , ipa := range ips {
for _ , ipa := range ips {
if ip4 := ipa . IP . To4 ( ) ; ip4 != nil {
if ip4 := ipa . IP . To4 ( ) ; ip4 != nil {
return r . addIPCache ( host , ip4 , fixedTTL ) , nil
return r . addIPCache ( host , ip4 , r. ttl ( ) ) , nil
}
}
}
}
return r . addIPCache ( host , ips [ 0 ] . IP , fixedTTL ) , nil
return r . addIPCache ( host , ips [ 0 ] . IP , r. ttl ( ) ) , nil
}
}
func ( r * Resolver ) addIPCache ( host string , ip net . IP , d time . Duration ) net . IP {
func ( r * Resolver ) addIPCache ( host string , ip net . IP , d time . Duration ) net . IP {
if isPrivateIP ( ip ) {
if isPrivateIP ( ip ) {
// Don't cache obviously wrong entries from captive portals.
// Don't cache obviously wrong entries from captive portals.
// TODO: use DoH or DoT for the forwarding resolver?
// TODO: use DoH or DoT for the forwarding resolver?
if debug {
log . Printf ( "dnscache: %q resolved to private IP %v; using but not caching" , host , ip )
}
return ip
return ip
}
}
if debug {
log . Printf ( "dnscache: %q resolved to IP %v; caching" , host , ip )
}
r . mu . Lock ( )
r . mu . Lock ( )
defer r . mu . Unlock ( )
defer r . mu . Unlock ( )
if r . ipCache == nil {
if r . ipCache == nil {
@ -168,3 +243,26 @@ var (
private2 = mustCIDR ( "172.16.0.0/12" )
private2 = mustCIDR ( "172.16.0.0/12" )
private3 = mustCIDR ( "192.168.0.0/16" )
private3 = mustCIDR ( "192.168.0.0/16" )
)
)
type DialContextFunc func ( ctx context . Context , network , address string ) ( net . Conn , error )
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer ( fwd DialContextFunc , dnsCache * Resolver ) DialContextFunc {
return func ( ctx context . Context , network , address string ) ( net . Conn , error ) {
host , port , err := net . SplitHostPort ( address )
if err != nil {
// Bogus. But just let the real dialer return an error rather than
// inventing a similar one.
return fwd ( ctx , network , address )
}
ip , err := dnsCache . LookupIP ( ctx , host )
if err != nil {
return nil , fmt . Errorf ( "failed to resolve %q: %w" , host , err )
}
dst := net . JoinHostPort ( ip . String ( ) , port )
if debug {
log . Printf ( "dnscache: dialing %s, %s for %s" , network , dst , address )
}
return fwd ( ctx , network , dst )
}
}