diff --git a/net/dns/manager.go b/net/dns/manager.go index 59bcdcf17..82dd5d47b 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -331,6 +331,9 @@ func (m *Manager) NextPacket() ([]byte, error) { return buf, nil } +// Query executes a DNS query recieved from the given address. The query is +// provided in bs as a wire-encoded DNS query without any transport header. +// This method is called for requests arriving over UDP and TCP. func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) { select { case <-m.ctx.Done(): @@ -460,7 +463,7 @@ func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netaddr.IPPort) { responses: make(chan []byte), readClosing: make(chan struct{}), } - s.ctx, s.closeCtx = context.WithCancel(context.Background()) + s.ctx, s.closeCtx = context.WithCancel(m.ctx) go s.handleReads() s.handleWrites() } diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 27aa87e04..b1ee65142 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -256,6 +256,12 @@ func (r *Resolver) Close() { r.forwarder.Close() } +// dnsQueryTimeout is not intended to be user-visible (the users +// DNS resolver will retry well before that), just put an upper +// bound on per-query resource usage. +const dnsQueryTimeout = 10 * time.Second + + func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) { metricDNSQueryLocal.Add(1) select { @@ -268,7 +274,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([ out, err := r.respond(bs) if err == errNotOurName { responses := make(chan packet, 1) - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout) defer close(responses) defer cancel() err = r.forwarder.forwardWithDestChan(ctx, packet{bs, from}, responses)