net/dnscache: work on IPv6-only hosts (again)

This fixes the regression where we had stopped working on IPv6-only
hosts.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/957/head
Brad Fitzpatrick 3 years ago
parent 560da4884f
commit 66be052a70

@ -358,7 +358,7 @@ func (c *Client) dialURL(ctx context.Context) (net.Conn, error) {
dialer := netns.NewDialer() dialer := netns.NewDialer()
if c.DNSCache != nil { if c.DNSCache != nil {
ip, err := c.DNSCache.LookupIP(ctx, host) ip, _, err := c.DNSCache.LookupIP(ctx, host)
if err == nil { if err == nil {
hostOrIP = ip.String() hostOrIP = ip.String()
} }

@ -71,7 +71,8 @@ type Resolver struct {
} }
type ipCacheEntry struct { type ipCacheEntry struct {
ip net.IP ip net.IP // either v4 or v6
ip6 net.IP // nil if no v4 or no v6
expires time.Time expires time.Time
} }
@ -91,78 +92,87 @@ func (r *Resolver) ttl() time.Duration {
var debug, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_DNS_CACHE")) var debug, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_DNS_CACHE"))
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address. // LookupIP returns the host's primary IP address (either IPv4 or
func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) { // IPv6, but preferring IPv4) and optionally its IPv6 address, if
// there is both IPv4 and IPv6.
//
// If err is nil, ip will be non-nil. The v6 address may be nil even
// with a nil error.
func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, err error) {
if ip := net.ParseIP(host); ip != nil { if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil { if ip4 := ip.To4(); ip4 != nil {
return ip4, nil return ip4, nil, nil
} }
if debug { if debug {
log.Printf("dnscache: %q is an IP", host) log.Printf("dnscache: %q is an IP", host)
} }
return ip, nil return ip, nil, nil
} }
if ip, ok := r.lookupIPCache(host); ok { if ip, ip6, ok := r.lookupIPCache(host); ok {
if debug { if debug {
log.Printf("dnscache: %q = %v (cached)", host, ip) log.Printf("dnscache: %q = %v (cached)", host, ip)
} }
return ip, nil return ip, ip6, nil
} }
type ipPair struct {
ip, ip6 net.IP
}
ch := r.sf.DoChan(host, func() (interface{}, error) { ch := r.sf.DoChan(host, func() (interface{}, error) {
ip, err := r.lookupIP(host) ip, ip6, err := r.lookupIP(host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ip, nil return ipPair{ip, ip6}, nil
}) })
select { select {
case res := <-ch: case res := <-ch:
if res.Err != nil { if res.Err != nil {
if r.UseLastGood { if r.UseLastGood {
if ip, ok := r.lookupIPCacheExpired(host); ok { if ip, ip6, ok := r.lookupIPCacheExpired(host); ok {
if debug { if debug {
log.Printf("dnscache: %q using %v after error", host, ip) log.Printf("dnscache: %q using %v after error", host, ip)
} }
return ip, nil return ip, ip6, nil
} }
} }
if debug { if debug {
log.Printf("dnscache: error resolving %q: %v", host, res.Err) log.Printf("dnscache: error resolving %q: %v", host, res.Err)
} }
return nil, res.Err return nil, nil, res.Err
} }
return res.Val.(net.IP), nil pair := res.Val.(ipPair)
return pair.ip, pair.ip6, nil
case <-ctx.Done(): case <-ctx.Done():
if debug { if debug {
log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err()) log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err())
} }
return nil, ctx.Err() return nil, nil, ctx.Err()
} }
} }
func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) { func (r *Resolver) lookupIPCache(host string) (ip, ip6 net.IP, ok bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) { if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) {
return ent.ip, true return ent.ip, ent.ip6, true
} }
return nil, false return nil, nil, false
} }
func (r *Resolver) lookupIPCacheExpired(host string) (ip net.IP, ok bool) { func (r *Resolver) lookupIPCacheExpired(host string) (ip, ip6 net.IP, ok bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if ent, ok := r.ipCache[host]; ok { if ent, ok := r.ipCache[host]; ok {
return ent.ip, true return ent.ip, ent.ip6, true
} }
return nil, false return nil, nil, false
} }
func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { func (r *Resolver) lookupTimeoutForHost(host string) time.Duration {
if r.UseLastGood { if r.UseLastGood {
if _, ok := r.lookupIPCacheExpired(host); ok { if _, _, ok := r.lookupIPCacheExpired(host); ok {
// If we have some previous good value for this host, // If we have some previous good value for this host,
// don't give this DNS lookup much time. If we're in a // don't give this DNS lookup much time. If we're in a
// situation where the user's DNS server is unreachable // situation where the user's DNS server is unreachable
@ -177,40 +187,52 @@ func (r *Resolver) lookupTimeoutForHost(host string) time.Duration {
return 10 * time.Second return 10 * time.Second
} }
func (r *Resolver) lookupIP(host string) (net.IP, error) { func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, err error) {
if ip, ok := r.lookupIPCache(host); ok { if ip, ip6, ok := r.lookupIPCache(host); ok {
if debug { if debug {
log.Printf("dnscache: %q found in cache as %v", host, ip) log.Printf("dnscache: %q found in cache as %v", host, ip)
} }
return ip, nil return ip, ip6, nil
} }
ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host)) ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host))
defer cancel() defer cancel()
ips, err := r.fwd().LookupIPAddr(ctx, host) ips, err := r.fwd().LookupIPAddr(ctx, host)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if len(ips) == 0 { if len(ips) == 0 {
return nil, fmt.Errorf("no IPs for %q found", host) return nil, nil, fmt.Errorf("no IPs for %q found", host)
} }
have4 := false
for _, ipa := range ips { for _, ipa := range ips {
if ip4 := ipa.IP.To4(); ip4 != nil { if ip4 := ipa.IP.To4(); ip4 != nil {
return r.addIPCache(host, ip4, r.ttl()), nil if !have4 {
ip6 = ip
ip = ip4
have4 = true
}
} else {
if have4 {
ip6 = ipa.IP
} else {
ip = ipa.IP
}
} }
} }
return r.addIPCache(host, ips[0].IP, r.ttl()), nil r.addIPCache(host, ip, ip6, r.ttl())
return ip, ip6, nil
} }
func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP { func (r *Resolver) addIPCache(host string, ip, ip6 net.IP, d time.Duration) {
if isPrivateIP(ip) { if isPrivateIP(ip) {
// Don't cache obviously wrong entries from captive portals. // Don't cache obviously wrong entries from captive portals.
// TODO: use DoH or DoT for the forwarding resolver? // TODO: use DoH or DoT for the forwarding resolver?
if debug { if debug {
log.Printf("dnscache: %q resolved to private IP %v; using but not caching", host, ip) log.Printf("dnscache: %q resolved to private IP %v; using but not caching", host, ip)
} }
return ip return
} }
if debug { if debug {
@ -222,8 +244,7 @@ func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP {
if r.ipCache == nil { if r.ipCache == nil {
r.ipCache = make(map[string]ipCacheEntry) r.ipCache = make(map[string]ipCacheEntry)
} }
r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)} r.ipCache[host] = ipCacheEntry{ip: ip, ip6: ip6, expires: time.Now().Add(d)}
return ip
} }
func mustCIDR(s string) *net.IPNet { func mustCIDR(s string) *net.IPNet {
@ -255,7 +276,7 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
// inventing a similar one. // inventing a similar one.
return fwd(ctx, network, address) return fwd(ctx, network, address)
} }
ip, err := dnsCache.LookupIP(ctx, host) ip, ip6, err := dnsCache.LookupIP(ctx, host)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to resolve %q: %w", host, err) return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
} }
@ -263,6 +284,19 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
if debug { if debug {
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
} }
c, err := fwd(ctx, network, dst)
if err == nil || ctx.Err() != nil || ip6 == nil {
return c, err
}
// Fall back to trying IPv6.
// TODO(bradfitz): this is a primarily for IPv6-only
// hosts; it's not supposed to be a real Happy
// Eyeballs implementation. We should use the net
// package's implementation of that by plumbing this
// dnscache impl into net.Dialer.Resolver.Dial and
// unmarshal/marshal DNS queries/responses to the net
// package. This works for v6-only hosts for now.
dst = net.JoinHostPort(ip6.String(), port)
return fwd(ctx, network, dst) return fwd(ctx, network, dst)
} }
} }

Loading…
Cancel
Save