diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index b23c7d516..3d061e207 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -213,8 +213,8 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr return ip, ip6, allIPs, nil } - ch := r.sf.DoChan(host, func() (ret ipRes, _ error) { - ip, ip6, allIPs, err := r.lookupIP(host) + ch := r.sf.DoChanContext(ctx, host, func(ctx context.Context) (ret ipRes, _ error) { + ip, ip6, allIPs, err := r.lookupIP(ctx, host) if err != nil { return ret, err } @@ -275,30 +275,30 @@ func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { return 10 * time.Second } -func (r *Resolver) lookupIP(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, err error) { +func (r *Resolver) lookupIP(ctx context.Context, host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, err error) { if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok { r.dlogf("%q found in cache as %v", host, ip) return ip, ip6, allIPs, nil } - ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host)) - defer cancel() - ips, err := r.fwd().LookupNetIP(ctx, "ip", host) + lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host)) + defer lookupCancel() + ips, err := r.fwd().LookupNetIP(lookupCtx, "ip", host) if err != nil || len(ips) == 0 { if resolver, ok := r.cloudHostResolver(); ok { r.dlogf("resolving %q via cloud resolver", host) - ips, err = resolver.LookupNetIP(ctx, "ip", host) + ips, err = resolver.LookupNetIP(lookupCtx, "ip", host) } } if (err != nil || len(ips) == 0) && r.LookupIPFallback != nil { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() + lookupCtx, lookupCancel := context.WithTimeout(ctx, 30*time.Second) + defer lookupCancel() if err != nil { r.dlogf("resolving %q using fallback resolver due to error", host) } else { r.dlogf("resolving %q using fallback resolver due to no returned IPs", host) } - ips, err = r.LookupIPFallback(ctx, host) + ips, err = r.LookupIPFallback(lookupCtx, host) } if err != nil { return netip.Addr{}, netip.Addr{}, nil, err