From e94ec448a719f55a3f881a03049d54900158872e Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 25 Jul 2021 15:43:49 -0700 Subject: [PATCH] net/dns/resolver: add forwardQuery type as race work prep Add a place to hang state in a future change for #2436. For now this just simplifies the send signature without any functional change. Updates #2436 Signed-off-by: Brad Fitzpatrick --- net/dns/resolver/forwarder.go | 52 +++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index b690285ba..531c2d130 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -316,17 +316,12 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, // // send expects the reply to have the same txid as txidOut. // -// The provided closeOnCtxDone lets send register values to Close if -// the caller's ctx expires. This avoids send from allocating its own -// waiting goroutine to interrupt the ReadFrom, as memory is tight on -// iOS and we want the number of pending DNS lookups to be bursty -// without too much associated goroutine/memory cost. -func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) { +func (f *forwarder) send(ctx context.Context, fq *forwardQuery, dst netaddr.IPPort) ([]byte, error) { ip := dst.IP() // Upgrade known DNS IPs to DoH (DNS-over-HTTPs). if urlBase, dc, ok := f.getDoHClient(ip); ok { - res, err := f.sendDoH(ctx, urlBase, dc, packet) + res, err := f.sendDoH(ctx, urlBase, dc, fq.packet) if err == nil || ctx.Err() != nil { return res, err } @@ -344,10 +339,10 @@ func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *clos } defer conn.Close() - closeOnCtxDone.Add(conn) - defer closeOnCtxDone.Remove(conn) + fq.closeOnCtxDone.Add(conn) + defer fq.closeOnCtxDone.Remove(conn) - if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil { + if _, err := conn.WriteTo(fq.packet, dst.UDPAddr()); err != nil { if err := ctx.Err(); err != nil { return nil, err } @@ -376,7 +371,7 @@ func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *clos } out = out[:n] txid := getTxID(out) - if txid != txidOut { + if txid != fq.txid { return nil, errors.New("txid doesn't match") } @@ -411,6 +406,30 @@ func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort { return nil } +// forwardQuery is information and state about a forwarded DNS query that's +// being sent to 1 or more upstreams. +// +// In the case of racing against multiple equivalent upstreams +// (e.g. Google or CloudFlare's 4 DNS IPs: 2 IPv4 + 2 IPv6), this type +// handles racing them more intelligently than just blasting away 4 +// queries at once. +type forwardQuery struct { + txid txid + packet []byte + + // closeOnCtxDone lets send register values to Close if the + // caller's ctx expires. This avoids send from allocating its + // own waiting goroutine to interrupt the ReadFrom, as memory + // is tight on iOS and we want the number of pending DNS + // lookups to be bursty without too much associated + // goroutine/memory cost. + closeOnCtxDone *closePool + + // TODO(bradfitz): add race delay state: + // mu sync.Mutex + // ... +} + // forward forwards the query to all upstream nameservers and returns the first response. func (f *forwarder) forward(query packet) error { domain, err := nameFromQuery(query.bs) @@ -418,7 +437,6 @@ func (f *forwarder) forward(query packet) error { return err } - txid := getTxID(query.bs) clampEDNSSize(query.bs, maxResponseBytes) resolvers := f.resolvers(domain) @@ -426,8 +444,12 @@ func (f *forwarder) forward(query packet) error { return errNoUpstreams } - closeOnCtxDone := new(closePool) - defer closeOnCtxDone.Close() + fq := &forwardQuery{ + txid: getTxID(query.bs), + packet: query.bs, + closeOnCtxDone: new(closePool), + } + defer fq.closeOnCtxDone.Close() ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) defer cancel() @@ -440,7 +462,7 @@ func (f *forwarder) forward(query packet) error { for _, ipp := range resolvers { go func(ipp netaddr.IPPort) { - resb, err := f.send(ctx, txid, closeOnCtxDone, query.bs, ipp) + resb, err := f.send(ctx, fq, ipp) if err != nil { mu.Lock() defer mu.Unlock()