diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 95f005653..87f11eca6 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -358,7 +358,7 @@ func (c *Client) dialURL(ctx context.Context) (net.Conn, error) { dialer := netns.NewDialer() if c.DNSCache != nil { - ip, err := c.DNSCache.LookupIP(ctx, host) + ip, _, err := c.DNSCache.LookupIP(ctx, host) if err == nil { hostOrIP = ip.String() } diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 15497418c..a1a52107d 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -71,7 +71,8 @@ type Resolver struct { } type ipCacheEntry struct { - ip net.IP + ip net.IP // either v4 or v6 + ip6 net.IP // nil if no v4 or no v6 expires time.Time } @@ -91,78 +92,87 @@ func (r *Resolver) ttl() time.Duration { var debug, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_DNS_CACHE")) -// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address. -func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) { +// LookupIP returns the host's primary IP address (either IPv4 or +// IPv6, but preferring IPv4) and optionally its IPv6 address, if +// there is both IPv4 and IPv6. +// +// If err is nil, ip will be non-nil. The v6 address may be nil even +// with a nil error. +func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, err error) { if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { - return ip4, nil + return ip4, nil, nil } if debug { log.Printf("dnscache: %q is an IP", host) } - return ip, nil + return ip, nil, nil } - if ip, ok := r.lookupIPCache(host); ok { + if ip, ip6, ok := r.lookupIPCache(host); ok { if debug { log.Printf("dnscache: %q = %v (cached)", host, ip) } - return ip, nil + return ip, ip6, nil } + type ipPair struct { + ip, ip6 net.IP + } ch := r.sf.DoChan(host, func() (interface{}, error) { - ip, err := r.lookupIP(host) + ip, ip6, err := r.lookupIP(host) if err != nil { return nil, err } - return ip, nil + return ipPair{ip, ip6}, nil }) select { case res := <-ch: if res.Err != nil { if r.UseLastGood { - if ip, ok := r.lookupIPCacheExpired(host); ok { + if ip, ip6, ok := r.lookupIPCacheExpired(host); ok { if debug { log.Printf("dnscache: %q using %v after error", host, ip) } - return ip, nil + return ip, ip6, nil } } if debug { log.Printf("dnscache: error resolving %q: %v", host, res.Err) } - return nil, res.Err + return nil, nil, res.Err } - return res.Val.(net.IP), nil + pair := res.Val.(ipPair) + return pair.ip, pair.ip6, nil case <-ctx.Done(): if debug { log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err()) } - return nil, ctx.Err() + return nil, nil, ctx.Err() } } -func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) { +func (r *Resolver) lookupIPCache(host string) (ip, ip6 net.IP, ok bool) { r.mu.Lock() defer r.mu.Unlock() if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) { - return ent.ip, true + return ent.ip, ent.ip6, true } - return nil, false + return nil, nil, false } -func (r *Resolver) lookupIPCacheExpired(host string) (ip net.IP, ok bool) { +func (r *Resolver) lookupIPCacheExpired(host string) (ip, ip6 net.IP, ok bool) { r.mu.Lock() defer r.mu.Unlock() if ent, ok := r.ipCache[host]; ok { - return ent.ip, true + return ent.ip, ent.ip6, true } - return nil, false + return nil, nil, false } func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { if r.UseLastGood { - if _, ok := r.lookupIPCacheExpired(host); ok { + if _, _, ok := r.lookupIPCacheExpired(host); ok { // If we have some previous good value for this host, // don't give this DNS lookup much time. If we're in a // situation where the user's DNS server is unreachable @@ -177,40 +187,52 @@ func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { return 10 * time.Second } -func (r *Resolver) lookupIP(host string) (net.IP, error) { - if ip, ok := r.lookupIPCache(host); ok { +func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, err error) { + if ip, ip6, ok := r.lookupIPCache(host); ok { if debug { log.Printf("dnscache: %q found in cache as %v", host, ip) } - return ip, nil + return ip, ip6, nil } ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host)) defer cancel() ips, err := r.fwd().LookupIPAddr(ctx, host) if err != nil { - return nil, err + return nil, nil, err } if len(ips) == 0 { - return nil, fmt.Errorf("no IPs for %q found", host) + return nil, nil, fmt.Errorf("no IPs for %q found", host) } + have4 := false for _, ipa := range ips { if ip4 := ipa.IP.To4(); ip4 != nil { - return r.addIPCache(host, ip4, r.ttl()), nil + if !have4 { + ip6 = ip + ip = ip4 + have4 = true + } + } else { + if have4 { + ip6 = ipa.IP + } else { + ip = ipa.IP + } } } - return r.addIPCache(host, ips[0].IP, r.ttl()), nil + r.addIPCache(host, ip, ip6, r.ttl()) + return ip, ip6, nil } -func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP { +func (r *Resolver) addIPCache(host string, ip, ip6 net.IP, d time.Duration) { if isPrivateIP(ip) { // Don't cache obviously wrong entries from captive portals. // TODO: use DoH or DoT for the forwarding resolver? if debug { log.Printf("dnscache: %q resolved to private IP %v; using but not caching", host, ip) } - return ip + return } if debug { @@ -222,8 +244,7 @@ func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP { if r.ipCache == nil { r.ipCache = make(map[string]ipCacheEntry) } - r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)} - return ip + r.ipCache[host] = ipCacheEntry{ip: ip, ip6: ip6, expires: time.Now().Add(d)} } func mustCIDR(s string) *net.IPNet { @@ -255,7 +276,7 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { // inventing a similar one. return fwd(ctx, network, address) } - ip, err := dnsCache.LookupIP(ctx, host) + ip, ip6, err := dnsCache.LookupIP(ctx, host) if err != nil { return nil, fmt.Errorf("failed to resolve %q: %w", host, err) } @@ -263,6 +284,19 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { if debug { log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) } + c, err := fwd(ctx, network, dst) + if err == nil || ctx.Err() != nil || ip6 == nil { + return c, err + } + // Fall back to trying IPv6. + // TODO(bradfitz): this is a primarily for IPv6-only + // hosts; it's not supposed to be a real Happy + // Eyeballs implementation. We should use the net + // package's implementation of that by plumbing this + // dnscache impl into net.Dialer.Resolver.Dial and + // unmarshal/marshal DNS queries/responses to the net + // package. This works for v6-only hosts for now. + dst = net.JoinHostPort(ip6.String(), port) return fwd(ctx, network, dst) } }