@ -278,53 +278,78 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer ( fwd DialContextFunc , dnsCache * Resolver ) DialContextFunc {
func Dialer ( fwd DialContextFunc , dnsCache * Resolver ) DialContextFunc {
return func ( ctx context . Context , network , address string ) ( retConn net . Conn , ret error ) {
d := & dialer {
host , port , err := net . SplitHostPort ( address )
fwd : fwd ,
if err != nil {
dnsCache : dnsCache ,
// Bogus. But just let the real dialer return an error rather than
}
// inventing a similar one.
return d . DialContext
return fwd ( ctx , network , address )
}
}
defer func ( ) {
// On any failure, assume our DNS is wrong and try our fallback, if any.
if ret == nil || dnsCache . LookupIPFallback == nil {
return
}
ips , err := dnsCache . LookupIPFallback ( ctx , host )
if err != nil {
// Return with original error
return
}
if c , err := raceDial ( ctx , fwd , network , ips , port ) ; err == nil {
retConn = c
ret = nil
return
}
} ( )
ip , ip6 , allIPs , err := dnsCache . LookupIP ( ctx , host )
// dialer is the config and accumulated state for a dial func returned by Dialer.
type dialer struct {
fwd DialContextFunc
dnsCache * Resolver
}
func ( d * dialer ) DialContext ( ctx context . Context , network , address string ) ( retConn net . Conn , ret error ) {
host , port , err := net . SplitHostPort ( address )
if err != nil {
// Bogus. But just let the real dialer return an error rather than
// inventing a similar one.
return d . fwd ( ctx , network , address )
}
dc := & dialCall {
d : d ,
network : network ,
address : address ,
host : host ,
port : port ,
}
defer func ( ) {
// On any failure, assume our DNS is wrong and try our fallback, if any.
if ret == nil || d . dnsCache . LookupIPFallback == nil {
return
}
ips , err := d . dnsCache . LookupIPFallback ( ctx , host )
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "failed to resolve %q: %w" , host , err )
// Return with original error
return
}
}
i4s := v4addrs ( allIPs )
if c , err := dc . raceDial ( ctx , ips ) ; err == nil {
if len ( i4s ) < 2 {
retConn = c
dst := net . JoinHostPort ( ip . String ( ) , port )
ret = nil
if debug {
return
log . Printf ( "dnscache: dialing %s, %s for %s" , network , dst , address )
}
c , err := fwd ( ctx , network , dst )
if err == nil || ctx . Err ( ) != nil || ip6 == nil {
return c , err
}
// Fall back to trying IPv6.
dst = net . JoinHostPort ( ip6 . String ( ) , port )
return fwd ( ctx , network , dst )
}
}
} ( )
// Multiple IPv4 candidates, and 0+ IPv6.
ip , ip6 , allIPs , err := d . dnsCache . LookupIP ( ctx , host )
ipsToTry := append ( i4s , v6addrs ( allIPs ) ... )
if err != nil {
return raceDial ( ctx , fwd , network , ipsToTry , port )
return nil , fmt . Errorf ( "failed to resolve %q: %w" , host , err )
}
i4s := v4addrs ( allIPs )
if len ( i4s ) < 2 {
dst := net . JoinHostPort ( ip . String ( ) , port )
if debug {
log . Printf ( "dnscache: dialing %s, %s for %s" , network , dst , address )
}
c , err := d . fwd ( ctx , network , dst )
if err == nil || ctx . Err ( ) != nil || ip6 == nil {
return c , err
}
// Fall back to trying IPv6.
dst = net . JoinHostPort ( ip6 . String ( ) , port )
return d . fwd ( ctx , network , dst )
}
}
// Multiple IPv4 candidates, and 0+ IPv6.
ipsToTry := append ( i4s , v6addrs ( allIPs ) ... )
return dc . raceDial ( ctx , ipsToTry )
}
// dialCall is the state around a single call to dial.
type dialCall struct {
d * dialer
network , address , host , port string
}
}
// fallbackDelay is how long to wait between trying subsequent
// fallbackDelay is how long to wait between trying subsequent
@ -334,7 +359,12 @@ const fallbackDelay = 300 * time.Millisecond
// raceDial tries to dial port on each ip in ips, starting a new race
// raceDial tries to dial port on each ip in ips, starting a new race
// dial every fallbackDelay apart, returning whichever completes first.
// dial every fallbackDelay apart, returning whichever completes first.
func raceDial ( ctx context . Context , fwd DialContextFunc , network string , ips [ ] netaddr . IP , port string ) ( net . Conn , error ) {
func ( dc * dialCall ) raceDial ( ctx context . Context , ips [ ] netaddr . IP ) ( net . Conn , error ) {
var (
fwd = dc . d . fwd
network = dc . network
port = dc . port
)
ctx , cancel := context . WithCancel ( ctx )
ctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
defer cancel ( )