diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 2ea05a6b7..14a8c3e3c 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -361,7 +361,7 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC defer func() { // On failure, consider that our DNS might be wrong and ask the DNS fallback mechanism for // some other IPs to try. - if ret == nil || ctx.Err() != nil || d.dnsCache.LookupIPFallback == nil || dc.dnsWasTrustworthy() { + if !d.shouldTryBootstrap(ctx, ret, dc) { return } ips, err := d.dnsCache.LookupIPFallback(ctx, host) @@ -398,6 +398,40 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC return dc.raceDial(ctx, ipsToTry) } +func (d *dialer) shouldTryBootstrap(ctx context.Context, err error, dc *dialCall) bool { + // No need to do anything when we succeeded. + if err == nil { + return false + } + + // Can't try bootstrap DNS if we don't have a fallback function + if d.dnsCache.LookupIPFallback == nil { + if debug { + log.Printf("dnscache: not using bootstrap DNS: no fallback") + } + return false + } + + // We can't retry if the context is canceled, since any further + // operations with this context will fail. + if ctxErr := ctx.Err(); ctxErr != nil { + if debug { + log.Printf("dnscache: not using bootstrap DNS: context error: %v", ctxErr) + } + return false + } + + wasTrustworthy := dc.dnsWasTrustworthy() + if wasTrustworthy { + if debug { + log.Printf("dnscache: not using bootstrap DNS: DNS was trustworthy") + } + return false + } + + return true +} + // dialCall is the state around a single call to dial. type dialCall struct { d *dialer diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index b99992148..2e64a87e2 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -164,3 +164,104 @@ func TestInterleaveSlices(t *testing.T) { }) } } + +func TestShouldTryBootstrap(t *testing.T) { + oldDebug := debug + t.Cleanup(func() { + debug = oldDebug + }) + debug = true + + type step struct { + ip netip.Addr // IP we pretended to dial + err error // the dial error or nil for success + } + + canceled, cancel := context.WithCancel(context.Background()) + cancel() + + deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + + ctx := context.Background() + errFailed := errors.New("some failure") + + cacheWithFallback := &Resolver{ + LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { + panic("unimplemented") + }, + } + cacheNoFallback := &Resolver{} + + testCases := []struct { + name string + steps []step + ctx context.Context + err error + noFallback bool + want bool + }{ + { + name: "no-error", + ctx: ctx, + err: nil, + want: false, + }, + { + name: "canceled", + ctx: canceled, + err: errFailed, + want: false, + }, + { + name: "deadline-exceeded", + ctx: deadlineExceeded, + err: errFailed, + want: false, + }, + { + name: "no-fallback", + ctx: ctx, + err: errFailed, + noFallback: true, + want: false, + }, + { + name: "dns-was-trustworthy", + ctx: ctx, + err: errFailed, + steps: []step{ + {netip.MustParseAddr("2003::1"), nil}, + {netip.MustParseAddr("2003::1"), errFailed}, + }, + want: false, + }, + { + name: "should-bootstrap", + ctx: ctx, + err: errFailed, + want: true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netip.Addr]time.Time{}, + } + if tt.noFallback { + d.dnsCache = cacheNoFallback + } else { + d.dnsCache = cacheWithFallback + } + dc := &dialCall{d: d} + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +}