net/dnscache: make Dialer try all resolved IPs

Tested manually with:

$ go test -v ./net/dnscache/ -dial-test=bogusplane.dev.tailscale.com:80

Where bogusplane has three A records, only one of which works.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/2523/head
Brad Fitzpatrick 3 years ago committed by Denton Gentry
parent dfa5e38fad
commit 281d503626

@ -314,20 +314,19 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
// Return with original error // Return with original error
return return
} }
for _, ip := range ips { if c, err := raceDial(ctx, fwd, network, ips, port); err == nil {
dst := net.JoinHostPort(ip.String(), port)
if c, err := fwd(ctx, network, dst); err == nil {
retConn = c retConn = c
ret = nil ret = nil
return return
} }
}
}() }()
ip, ip6, _, err := dnsCache.LookupIP(ctx, host) ip, ip6, allIPs, err := dnsCache.LookupIP(ctx, host)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to resolve %q: %w", host, err) return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
} }
i4s := v4addrs(allIPs)
if len(i4s) < 2 {
dst := net.JoinHostPort(ip.String(), port) dst := net.JoinHostPort(ip.String(), port)
if debug { if debug {
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
@ -337,16 +336,107 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
return c, err return c, err
} }
// Fall back to trying IPv6. // Fall back to trying IPv6.
// TODO(bradfitz): this is a primarily for IPv6-only
// hosts; it's not supposed to be a real Happy
// Eyeballs implementation. We should use the net
// package's implementation of that by plumbing this
// dnscache impl into net.Dialer.Resolver.Dial and
// unmarshal/marshal DNS queries/responses to the net
// package. This works for v6-only hosts for now.
dst = net.JoinHostPort(ip6.String(), port) dst = net.JoinHostPort(ip6.String(), port)
return fwd(ctx, network, dst) return fwd(ctx, network, dst)
} }
// Multiple IPv4 candidates, and 0+ IPv6.
ipsToTry := append(i4s, v6addrs(allIPs)...)
return raceDial(ctx, fwd, network, ipsToTry, port)
}
}
// fallbackDelay is how long to wait between trying subsequent
// addresses when multiple options are available.
// 300ms is the same as Go's Happy Eyeballs fallbackDelay value.
const fallbackDelay = 300 * time.Millisecond
// raceDial tries to dial port on each ip in ips, starting a new race
// dial every 300ms apart, returning whichever completes first.
func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []netaddr.IP, port string) (net.Conn, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
type res struct {
c net.Conn
err error
}
resc := make(chan res) // must be unbuffered
failBoost := make(chan struct{}) // best effort send on dial failure
go func() {
for i, ip := range ips {
if i != 0 {
timer := time.NewTimer(fallbackDelay)
select {
case <-timer.C:
case <-failBoost:
timer.Stop()
case <-ctx.Done():
timer.Stop()
return
}
}
go func(ip netaddr.IP) {
c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port))
if err != nil {
// Best effort wake-up a pending dial.
// e.g. IPv4 dials failing quickly on an IPv6-only system.
// In that case we don't want to wait 300ms per IPv4 before
// we get to the IPv6 addresses.
select {
case failBoost <- struct{}{}:
default:
}
}
select {
case resc <- res{c, err}:
case <-ctx.Done():
if c != nil {
c.Close()
}
}
}(ip)
}
}()
var firstErr error
var fails int
for {
select {
case r := <-resc:
if r.c != nil {
return r.c, nil
}
fails++
if firstErr == nil {
firstErr = r.err
}
if fails == len(ips) {
return nil, firstErr
}
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
func v4addrs(aa []net.IPAddr) (ret []netaddr.IP) {
for _, a := range aa {
if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is4() {
ret = append(ret, ip)
}
}
return ret
}
func v6addrs(aa []net.IPAddr) (ret []netaddr.IP) {
for _, a := range aa {
if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is6() {
ret = append(ret, ip)
}
}
return ret
} }
var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake") var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake")

@ -5,10 +5,15 @@
package dnscache package dnscache
import ( import (
"context"
"flag"
"net" "net"
"testing" "testing"
"time"
) )
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
func TestIsPrivateIP(t *testing.T) { func TestIsPrivateIP(t *testing.T) {
tests := []struct { tests := []struct {
ip string ip string
@ -26,3 +31,21 @@ func TestIsPrivateIP(t *testing.T) {
} }
} }
} }
func TestDialer(t *testing.T) {
if *dialTest == "" {
t.Skip("skipping; --dial-test is blank")
}
r := new(Resolver)
var std net.Dialer
dialer := Dialer(std.DialContext, r)
t0 := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
c, err := dialer(ctx, "tcp", *dialTest)
if err != nil {
t.Fatal(err)
}
t.Logf("dialed in %v", time.Since(t0))
c.Close()
}

Loading…
Cancel
Save