From 903988b39204837066898b5b8d06a428c3c1a9f7 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 14 Feb 2022 09:38:23 -0800 Subject: [PATCH] net/dnscache: refactor from func-y closure-y state to types & methods No behavior changes (intended, at least). This is in prep for future changes to this package, which would get too complicated in the current style. Change-Id: Ic260f8e34ae2f64f34819d4a56e38bee8d8ac5ce Signed-off-by: Brad Fitzpatrick --- net/dnscache/dnscache.go | 114 ++++++++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 42 deletions(-) diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 508769753..835158de7 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -278,53 +278,78 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Dialer returns a wrapped DialContext func that uses the provided dnsCache. func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { - return func(ctx context.Context, network, address string) (retConn net.Conn, ret error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - // Bogus. But just let the real dialer return an error rather than - // inventing a similar one. - return fwd(ctx, network, address) - } - defer func() { - // On any failure, assume our DNS is wrong and try our fallback, if any. - if ret == nil || dnsCache.LookupIPFallback == nil { - return - } - ips, err := dnsCache.LookupIPFallback(ctx, host) - if err != nil { - // Return with original error - return - } - if c, err := raceDial(ctx, fwd, network, ips, port); err == nil { - retConn = c - ret = nil - return - } - }() + d := &dialer{ + fwd: fwd, + dnsCache: dnsCache, + } + return d.DialContext +} - ip, ip6, allIPs, err := dnsCache.LookupIP(ctx, host) +// dialer is the config and accumulated state for a dial func returned by Dialer. +type dialer struct { + fwd DialContextFunc + dnsCache *Resolver +} + +func (d *dialer) DialContext(ctx context.Context, network, address string) (retConn net.Conn, ret error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + // Bogus. But just let the real dialer return an error rather than + // inventing a similar one. + return d.fwd(ctx, network, address) + } + dc := &dialCall{ + d: d, + network: network, + address: address, + host: host, + port: port, + } + defer func() { + // On any failure, assume our DNS is wrong and try our fallback, if any. + if ret == nil || d.dnsCache.LookupIPFallback == nil { + return + } + ips, err := d.dnsCache.LookupIPFallback(ctx, host) if err != nil { - return nil, fmt.Errorf("failed to resolve %q: %w", host, err) + // Return with original error + return } - i4s := v4addrs(allIPs) - if len(i4s) < 2 { - dst := net.JoinHostPort(ip.String(), port) - if debug { - log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) - } - c, err := fwd(ctx, network, dst) - if err == nil || ctx.Err() != nil || ip6 == nil { - return c, err - } - // Fall back to trying IPv6. - dst = net.JoinHostPort(ip6.String(), port) - return fwd(ctx, network, dst) + if c, err := dc.raceDial(ctx, ips); err == nil { + retConn = c + ret = nil + return } + }() - // Multiple IPv4 candidates, and 0+ IPv6. - ipsToTry := append(i4s, v6addrs(allIPs)...) - return raceDial(ctx, fwd, network, ipsToTry, port) + ip, ip6, allIPs, err := d.dnsCache.LookupIP(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve %q: %w", host, err) + } + i4s := v4addrs(allIPs) + if len(i4s) < 2 { + dst := net.JoinHostPort(ip.String(), port) + if debug { + log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) + } + c, err := d.fwd(ctx, network, dst) + if err == nil || ctx.Err() != nil || ip6 == nil { + return c, err + } + // Fall back to trying IPv6. + dst = net.JoinHostPort(ip6.String(), port) + return d.fwd(ctx, network, dst) } + + // Multiple IPv4 candidates, and 0+ IPv6. + ipsToTry := append(i4s, v6addrs(allIPs)...) + return dc.raceDial(ctx, ipsToTry) +} + +// dialCall is the state around a single call to dial. +type dialCall struct { + d *dialer + network, address, host, port string } // fallbackDelay is how long to wait between trying subsequent @@ -334,7 +359,12 @@ const fallbackDelay = 300 * time.Millisecond // raceDial tries to dial port on each ip in ips, starting a new race // dial every fallbackDelay apart, returning whichever completes first. -func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []netaddr.IP, port string) (net.Conn, error) { +func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, error) { + var ( + fwd = dc.d.fwd + network = dc.network + port = dc.port + ) ctx, cancel := context.WithCancel(ctx) defer cancel()