@ -278,31 +278,51 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer ( fwd DialContextFunc , dnsCache * Resolver ) DialContextFunc {
func Dialer ( fwd DialContextFunc , dnsCache * Resolver ) DialContextFunc {
return func ( ctx context . Context , network , address string ) ( retConn net . Conn , ret error ) {
d := & dialer {
fwd : fwd ,
dnsCache : dnsCache ,
}
return d . DialContext
}
// dialer is the config and accumulated state for a dial func returned by Dialer.
type dialer struct {
fwd DialContextFunc
dnsCache * Resolver
}
func ( d * dialer ) DialContext ( ctx context . Context , network , address string ) ( retConn net . Conn , ret error ) {
host , port , err := net . SplitHostPort ( address )
host , port , err := net . SplitHostPort ( address )
if err != nil {
if err != nil {
// Bogus. But just let the real dialer return an error rather than
// Bogus. But just let the real dialer return an error rather than
// inventing a similar one.
// inventing a similar one.
return fwd ( ctx , network , address )
return d . fwd ( ctx , network , address )
}
dc := & dialCall {
d : d ,
network : network ,
address : address ,
host : host ,
port : port ,
}
}
defer func ( ) {
defer func ( ) {
// On any failure, assume our DNS is wrong and try our fallback, if any.
// On any failure, assume our DNS is wrong and try our fallback, if any.
if ret == nil || dnsCache . LookupIPFallback == nil {
if ret == nil || d . dnsCache . LookupIPFallback == nil {
return
return
}
}
ips , err := dnsCache . LookupIPFallback ( ctx , host )
ips , err := d . dnsCache . LookupIPFallback ( ctx , host )
if err != nil {
if err != nil {
// Return with original error
// Return with original error
return
return
}
}
if c , err := raceDial ( ctx , fwd, network , ips, port ) ; err == nil {
if c , err := dc . raceDial ( ctx , ips) ; err == nil {
retConn = c
retConn = c
ret = nil
ret = nil
return
return
}
}
} ( )
} ( )
ip , ip6 , allIPs , err := dnsCache . LookupIP ( ctx , host )
ip , ip6 , allIPs , err := d . dnsCache . LookupIP ( ctx , host )
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "failed to resolve %q: %w" , host , err )
return nil , fmt . Errorf ( "failed to resolve %q: %w" , host , err )
}
}
@ -312,19 +332,24 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
if debug {
if debug {
log . Printf ( "dnscache: dialing %s, %s for %s" , network , dst , address )
log . Printf ( "dnscache: dialing %s, %s for %s" , network , dst , address )
}
}
c , err := fwd ( ctx , network , dst )
c , err := d . fwd ( ctx , network , dst )
if err == nil || ctx . Err ( ) != nil || ip6 == nil {
if err == nil || ctx . Err ( ) != nil || ip6 == nil {
return c , err
return c , err
}
}
// Fall back to trying IPv6.
// Fall back to trying IPv6.
dst = net . JoinHostPort ( ip6 . String ( ) , port )
dst = net . JoinHostPort ( ip6 . String ( ) , port )
return fwd ( ctx , network , dst )
return d . fwd ( ctx , network , dst )
}
}
// Multiple IPv4 candidates, and 0+ IPv6.
// Multiple IPv4 candidates, and 0+ IPv6.
ipsToTry := append ( i4s , v6addrs ( allIPs ) ... )
ipsToTry := append ( i4s , v6addrs ( allIPs ) ... )
return raceDial ( ctx , fwd, network , ipsToTry, port )
return dc . raceDial ( ctx , ipsToTry)
}
}
// dialCall is the state around a single call to dial.
type dialCall struct {
d * dialer
network , address , host , port string
}
}
// fallbackDelay is how long to wait between trying subsequent
// fallbackDelay is how long to wait between trying subsequent
@ -334,7 +359,12 @@ const fallbackDelay = 300 * time.Millisecond
// raceDial tries to dial port on each ip in ips, starting a new race
// raceDial tries to dial port on each ip in ips, starting a new race
// dial every fallbackDelay apart, returning whichever completes first.
// dial every fallbackDelay apart, returning whichever completes first.
func raceDial ( ctx context . Context , fwd DialContextFunc , network string , ips [ ] netaddr . IP , port string ) ( net . Conn , error ) {
func ( dc * dialCall ) raceDial ( ctx context . Context , ips [ ] netaddr . IP ) ( net . Conn , error ) {
var (
fwd = dc . d . fwd
network = dc . network
port = dc . port
)
ctx , cancel := context . WithCancel ( ctx )
ctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
defer cancel ( )