From 281d5036261c77fddf54d89479044493b8075547 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 26 Jul 2021 14:36:21 -0700 Subject: [PATCH] net/dnscache: make Dialer try all resolved IPs Tested manually with: $ go test -v ./net/dnscache/ -dial-test=bogusplane.dev.tailscale.com:80 Where bogusplane has three A records, only one of which works. Signed-off-by: Brad Fitzpatrick --- net/dnscache/dnscache.go | 138 ++++++++++++++++++++++++++++------ net/dnscache/dnscache_test.go | 23 ++++++ 2 files changed, 137 insertions(+), 24 deletions(-) diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index e6a5b60ef..072d6fd7a 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -314,39 +314,129 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { // Return with original error return } - for _, ip := range ips { - dst := net.JoinHostPort(ip.String(), port) - if c, err := fwd(ctx, network, dst); err == nil { - retConn = c - ret = nil - return - } + if c, err := raceDial(ctx, fwd, network, ips, port); err == nil { + retConn = c + ret = nil + return } }() - ip, ip6, _, err := dnsCache.LookupIP(ctx, host) + ip, ip6, allIPs, err := dnsCache.LookupIP(ctx, host) if err != nil { return nil, fmt.Errorf("failed to resolve %q: %w", host, err) } - dst := net.JoinHostPort(ip.String(), port) - if debug { - log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) + i4s := v4addrs(allIPs) + if len(i4s) < 2 { + dst := net.JoinHostPort(ip.String(), port) + if debug { + log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) + } + c, err := fwd(ctx, network, dst) + if err == nil || ctx.Err() != nil || ip6 == nil { + return c, err + } + // Fall back to trying IPv6. + dst = net.JoinHostPort(ip6.String(), port) + return fwd(ctx, network, dst) + } + + // Multiple IPv4 candidates, and 0+ IPv6. + ipsToTry := append(i4s, v6addrs(allIPs)...) 
+ return raceDial(ctx, fwd, network, ipsToTry, port) + } +} + +// fallbackDelay is how long to wait between trying subsequent +// addresses when multiple options are available. +// 300ms is the same as Go's Happy Eyeballs fallbackDelay value. +const fallbackDelay = 300 * time.Millisecond + +// raceDial tries to dial port on each ip in ips, starting a new dial every +// fallbackDelay (sooner if one fails), returning the first that succeeds. +func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []netaddr.IP, port string) (net.Conn, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + type res struct { + c net.Conn + err error + } + resc := make(chan res) // must be unbuffered + failBoost := make(chan struct{}) // best effort send on dial failure + + go func() { + for i, ip := range ips { + if i != 0 { + timer := time.NewTimer(fallbackDelay) + select { + case <-timer.C: + case <-failBoost: + timer.Stop() + case <-ctx.Done(): + timer.Stop() + return + } + } + go func(ip netaddr.IP) { + c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port)) + if err != nil { + // Best effort wake-up a pending dial. + // e.g. IPv4 dials failing quickly on an IPv6-only system. + // In that case we don't want to wait 300ms per IPv4 before + // we get to the IPv6 addresses. 
+ select { + case failBoost <- struct{}{}: + default: + } + } + select { + case resc <- res{c, err}: + case <-ctx.Done(): + if c != nil { + c.Close() + } + } + }(ip) } - c, err := fwd(ctx, network, dst) - if err == nil || ctx.Err() != nil || ip6 == nil { - return c, err + }() + + var firstErr error + var fails int + for { + select { + case r := <-resc: + if r.c != nil { + return r.c, nil + } + fails++ + if firstErr == nil { + firstErr = r.err + } + if fails == len(ips) { + return nil, firstErr + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func v4addrs(aa []net.IPAddr) (ret []netaddr.IP) { + for _, a := range aa { + if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is4() { + ret = append(ret, ip) + } + } + return ret +} + +func v6addrs(aa []net.IPAddr) (ret []netaddr.IP) { + for _, a := range aa { + if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is6() { + ret = append(ret, ip) } - // Fall back to trying IPv6. - // TODO(bradfitz): this is a primarily for IPv6-only - // hosts; it's not supposed to be a real Happy - // Eyeballs implementation. We should use the net - // package's implementation of that by plumbing this - // dnscache impl into net.Dialer.Resolver.Dial and - // unmarshal/marshal DNS queries/responses to the net - // package. This works for v6-only hosts for now. 
- dst = net.JoinHostPort(ip6.String(), port) - return fwd(ctx, network, dst) } + return ret } var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake") diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 96e1cc8a5..10d986da7 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -5,10 +5,15 @@ package dnscache import ( + "context" + "flag" "net" "testing" + "time" ) +var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") + func TestIsPrivateIP(t *testing.T) { tests := []struct { ip string @@ -26,3 +31,21 @@ func TestIsPrivateIP(t *testing.T) { } } } + +func TestDialer(t *testing.T) { + if *dialTest == "" { + t.Skip("skipping; --dial-test is blank") + } + r := new(Resolver) + var std net.Dialer + dialer := Dialer(std.DialContext, r) + t0 := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + c, err := dialer(ctx, "tcp", *dialTest) + if err != nil { + t.Fatal(err) + } + t.Logf("dialed in %v", time.Since(t0)) + c.Close() +}