net/dns/resolver: add forwardQuery type as race work prep

Add a place to hang state in a future change for #2436.
For now this just simplifies the send signature without
any functional change.

Updates #2436

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/2513/head
Brad Fitzpatrick 3 years ago
parent 064b916b1a
commit e94ec448a7

@ -316,17 +316,12 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
// //
// send expects the reply to have the same txid as txidOut. // send expects the reply to have the same txid as txidOut.
// //
// The provided closeOnCtxDone lets send register values to Close if func (f *forwarder) send(ctx context.Context, fq *forwardQuery, dst netaddr.IPPort) ([]byte, error) {
// the caller's ctx expires. This avoids send from allocating its own
// waiting goroutine to interrupt the ReadFrom, as memory is tight on
// iOS and we want the number of pending DNS lookups to be bursty
// without too much associated goroutine/memory cost.
func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) {
ip := dst.IP() ip := dst.IP()
// Upgrade known DNS IPs to DoH (DNS-over-HTTPs). // Upgrade known DNS IPs to DoH (DNS-over-HTTPs).
if urlBase, dc, ok := f.getDoHClient(ip); ok { if urlBase, dc, ok := f.getDoHClient(ip); ok {
res, err := f.sendDoH(ctx, urlBase, dc, packet) res, err := f.sendDoH(ctx, urlBase, dc, fq.packet)
if err == nil || ctx.Err() != nil { if err == nil || ctx.Err() != nil {
return res, err return res, err
} }
@ -344,10 +339,10 @@ func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *clos
} }
defer conn.Close() defer conn.Close()
closeOnCtxDone.Add(conn) fq.closeOnCtxDone.Add(conn)
defer closeOnCtxDone.Remove(conn) defer fq.closeOnCtxDone.Remove(conn)
if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil { if _, err := conn.WriteTo(fq.packet, dst.UDPAddr()); err != nil {
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return nil, err return nil, err
} }
@ -376,7 +371,7 @@ func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *clos
} }
out = out[:n] out = out[:n]
txid := getTxID(out) txid := getTxID(out)
if txid != txidOut { if txid != fq.txid {
return nil, errors.New("txid doesn't match") return nil, errors.New("txid doesn't match")
} }
@ -411,6 +406,30 @@ func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort {
return nil return nil
} }
// forwardQuery is information and state about a forwarded DNS query that's
// being sent to 1 or more upstreams.
//
// In the case of racing against multiple equivalent upstreams
// (e.g. Google or CloudFlare's 4 DNS IPs: 2 IPv4 + 2 IPv6), this type
// handles racing them more intelligently than just blasting away 4
// queries at once.
type forwardQuery struct {
txid txid
packet []byte
// closeOnCtxDone lets send register values to Close if the
// caller's ctx expires. This avoids send from allocating its
// own waiting goroutine to interrupt the ReadFrom, as memory
// is tight on iOS and we want the number of pending DNS
// lookups to be bursty without too much associated
// goroutine/memory cost.
closeOnCtxDone *closePool
// TODO(bradfitz): add race delay state:
// mu sync.Mutex
// ...
}
// forward forwards the query to all upstream nameservers and returns the first response. // forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query packet) error { func (f *forwarder) forward(query packet) error {
domain, err := nameFromQuery(query.bs) domain, err := nameFromQuery(query.bs)
@ -418,7 +437,6 @@ func (f *forwarder) forward(query packet) error {
return err return err
} }
txid := getTxID(query.bs)
clampEDNSSize(query.bs, maxResponseBytes) clampEDNSSize(query.bs, maxResponseBytes)
resolvers := f.resolvers(domain) resolvers := f.resolvers(domain)
@ -426,8 +444,12 @@ func (f *forwarder) forward(query packet) error {
return errNoUpstreams return errNoUpstreams
} }
closeOnCtxDone := new(closePool) fq := &forwardQuery{
defer closeOnCtxDone.Close() txid: getTxID(query.bs),
packet: query.bs,
closeOnCtxDone: new(closePool),
}
defer fq.closeOnCtxDone.Close()
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
defer cancel() defer cancel()
@ -440,7 +462,7 @@ func (f *forwarder) forward(query packet) error {
for _, ipp := range resolvers { for _, ipp := range resolvers {
go func(ipp netaddr.IPPort) { go func(ipp netaddr.IPPort) {
resb, err := f.send(ctx, txid, closeOnCtxDone, query.bs, ipp) resb, err := f.send(ctx, fq, ipp)
if err != nil { if err != nil {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()

Loading…
Cancel
Save