net/dnscache: refactor from func-y closure-y state to types & methods

No behavior changes (intended, at least).

This is in prep for future changes to this package, which would get
too complicated in the current style.

Change-Id: Ic260f8e34ae2f64f34819d4a56e38bee8d8ac5ce
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/3938/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 8267ea0f80
commit 903988b392

@ -278,53 +278,78 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Dialer returns a wrapped DialContext func that uses the provided dnsCache. // Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
return func(ctx context.Context, network, address string) (retConn net.Conn, ret error) { d := &dialer{
host, port, err := net.SplitHostPort(address) fwd: fwd,
if err != nil { dnsCache: dnsCache,
// Bogus. But just let the real dialer return an error rather than }
// inventing a similar one. return d.DialContext
return fwd(ctx, network, address) }
}
defer func() {
// On any failure, assume our DNS is wrong and try our fallback, if any.
if ret == nil || dnsCache.LookupIPFallback == nil {
return
}
ips, err := dnsCache.LookupIPFallback(ctx, host)
if err != nil {
// Return with original error
return
}
if c, err := raceDial(ctx, fwd, network, ips, port); err == nil {
retConn = c
ret = nil
return
}
}()
ip, ip6, allIPs, err := dnsCache.LookupIP(ctx, host) // dialer is the config and accumulated state for a dial func returned by Dialer.
type dialer struct {
fwd DialContextFunc
dnsCache *Resolver
}
func (d *dialer) DialContext(ctx context.Context, network, address string) (retConn net.Conn, ret error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
// Bogus. But just let the real dialer return an error rather than
// inventing a similar one.
return d.fwd(ctx, network, address)
}
dc := &dialCall{
d: d,
network: network,
address: address,
host: host,
port: port,
}
defer func() {
// On any failure, assume our DNS is wrong and try our fallback, if any.
if ret == nil || d.dnsCache.LookupIPFallback == nil {
return
}
ips, err := d.dnsCache.LookupIPFallback(ctx, host)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to resolve %q: %w", host, err) // Return with original error
return
} }
i4s := v4addrs(allIPs) if c, err := dc.raceDial(ctx, ips); err == nil {
if len(i4s) < 2 { retConn = c
dst := net.JoinHostPort(ip.String(), port) ret = nil
if debug { return
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
}
c, err := fwd(ctx, network, dst)
if err == nil || ctx.Err() != nil || ip6 == nil {
return c, err
}
// Fall back to trying IPv6.
dst = net.JoinHostPort(ip6.String(), port)
return fwd(ctx, network, dst)
} }
}()
// Multiple IPv4 candidates, and 0+ IPv6. ip, ip6, allIPs, err := d.dnsCache.LookupIP(ctx, host)
ipsToTry := append(i4s, v6addrs(allIPs)...) if err != nil {
return raceDial(ctx, fwd, network, ipsToTry, port) return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
}
i4s := v4addrs(allIPs)
if len(i4s) < 2 {
dst := net.JoinHostPort(ip.String(), port)
if debug {
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
}
c, err := d.fwd(ctx, network, dst)
if err == nil || ctx.Err() != nil || ip6 == nil {
return c, err
}
// Fall back to trying IPv6.
dst = net.JoinHostPort(ip6.String(), port)
return d.fwd(ctx, network, dst)
} }
// Multiple IPv4 candidates, and 0+ IPv6.
ipsToTry := append(i4s, v6addrs(allIPs)...)
return dc.raceDial(ctx, ipsToTry)
}
// dialCall is the state around a single call to dial.
type dialCall struct {
d *dialer
network, address, host, port string
} }
// fallbackDelay is how long to wait between trying subsequent // fallbackDelay is how long to wait between trying subsequent
@ -334,7 +359,12 @@ const fallbackDelay = 300 * time.Millisecond
// raceDial tries to dial port on each ip in ips, starting a new race // raceDial tries to dial port on each ip in ips, starting a new race
// dial every fallbackDelay apart, returning whichever completes first. // dial every fallbackDelay apart, returning whichever completes first.
func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []netaddr.IP, port string) (net.Conn, error) { func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, error) {
var (
fwd = dc.d.fwd
network = dc.network
port = dc.port
)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()

Loading…
Cancel
Save