// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package resolver import ( "bytes" "context" "encoding/binary" "errors" "fmt" "io" "net" "net/http" "net/netip" "net/url" "sort" "strings" "sync" "sync/atomic" "time" dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/net/dns/publicdns" "tailscale.com/net/dnscache" "tailscale.com/net/neterror" "tailscale.com/net/netmon" "tailscale.com/net/sockstats" "tailscale.com/net/tsdial" "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/types/nettype" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" "tailscale.com/util/race" "tailscale.com/version" ) // headerBytes is the number of bytes in a DNS message header. const headerBytes = 12 // dnsFlagTruncated is set in the flags word when the packet is truncated. const dnsFlagTruncated = 0x200 // truncatedFlagSet returns true if the DNS packet signals that it has // been truncated. False is also returned if the packet was too small // to be valid. func truncatedFlagSet(pkt []byte) bool { if len(pkt) < headerBytes { return false } return (binary.BigEndian.Uint16(pkt[2:4]) & dnsFlagTruncated) != 0 } const ( // dohTransportTimeout is how long to keep idle HTTP // connections open to DNS-over-HTTPs servers. This is pretty // arbitrary. dohTransportTimeout = 30 * time.Second // dohTransportTimeout is how much of a head start to give a DoH query // that was upgraded from a well-known public DNS provider's IP before // normal UDP mode is attempted as a fallback. dohHeadStart = 500 * time.Millisecond // wellKnownHostBackupDelay is how long to artificially delay upstream // DNS queries to the "fallback" DNS server IP for a known provider // (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8). wellKnownHostBackupDelay = 200 * time.Millisecond // udpRaceTimeout is the timeout after which we will start a DNS query // over TCP while waiting for the UDP query to complete. udpRaceTimeout = 2 * time.Second // tcpQueryTimeout is the timeout for a DNS query performed over TCP. // It matches the default 5sec timeout of the 'dig' utility. tcpQueryTimeout = 5 * time.Second ) // txid identifies a DNS transaction. // // As the standard DNS Request ID is only 16 bits, we extend it: // the lower 32 bits are the zero-extended bits of the DNS Request ID; // the upper 32 bits are the CRC32 checksum of the first question in the request. // This makes probability of txid collision negligible. type txid uint64 // getTxID computes the txid of the given DNS packet. func getTxID(packet []byte) txid { if len(packet) < headerBytes { return 0 } dnsid := binary.BigEndian.Uint16(packet[0:2]) // Previously, we hashed the question and combined it with the original txid // which was useful when concurrent queries were multiplexed on a single // local source port. We encountered some situations where the DNS server // canonicalizes the question in the response (uppercase converted to // lowercase in this case), which resulted in responses that we couldn't // match to the original request due to hash mismatches. return txid(dnsid) } func getRCode(packet []byte) dns.RCode { if len(packet) < headerBytes { // treat invalid packets as a refusal return dns.RCode(5) } // get bottom 4 bits of 3rd byte return dns.RCode(packet[3] & 0x0F) } // clampEDNSSize attempts to limit the maximum EDNS response size. This is not // an exhaustive solution, instead only easy cases are currently handled in the // interest of speed and reduced complexity. Only OPT records at the very end of // the message with no option codes are addressed. // TODO: handle more situations if we discover that they happen often func clampEDNSSize(packet []byte, maxSize uint16) { // optFixedBytes is the size of an OPT record with no option codes. const optFixedBytes = 11 const edns0Version = 0 if len(packet) < headerBytes+optFixedBytes { return } arCount := binary.BigEndian.Uint16(packet[10:12]) if arCount == 0 { // OPT shows up in an AR, so there must be no OPT return } // https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.2 opt := packet[len(packet)-optFixedBytes:] if opt[0] != 0 { // OPT NAME must be 0 (root domain) return } if dns.Type(binary.BigEndian.Uint16(opt[1:3])) != dns.TypeOPT { // Not an OPT record return } requestedSize := binary.BigEndian.Uint16(opt[3:5]) // Ignore extended RCODE in opt[5] if opt[6] != edns0Version { // Be conservative and don't touch unknown versions. return } // Ignore flags in opt[6:9] if binary.BigEndian.Uint16(opt[9:11]) != 0 { // RDLEN must be 0 (no variable length data). We're at the end of the // packet so this should be 0 anyway).. return } if requestedSize <= maxSize { return } // Clamp the maximum size binary.BigEndian.PutUint16(opt[3:5], maxSize) } type route struct { Suffix dnsname.FQDN Resolvers []resolverAndDelay } // resolverAndDelay is an upstream DNS resolver and a delay for how // long to wait before querying it. type resolverAndDelay struct { // name is the upstream resolver. name *dnstype.Resolver // startDelay is an amount to delay this resolver at // start. It's used when, say, there are four Google or // Cloudflare DNS IPs (two IPv4 + two IPv6) and we don't want // to race all four at once. startDelay time.Duration } // forwarder forwards DNS packets to a number of upstream nameservers. type forwarder struct { logf logger.Logf netMon *netmon.Monitor // always non-nil linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it dialer *tsdial.Dialer controlKnobs *controlknobs.Knobs // or nil ctx context.Context // good until Close ctxCancel context.CancelFunc // closes ctx mu sync.Mutex // guards following dohClient map[string]*http.Client // urlBase -> client // routes are per-suffix resolvers to use, with // the most specific routes first. routes []route // cloudHostFallback are last resort resolvers to use if no per-suffix // resolver matches. These are only populated on cloud hosts where the // platform provides a well-known recursive resolver. // // That is, if we're running on GCP or AWS where there's always a well-known // IP of a recursive resolver, return that rather than having callers return // SERVFAIL. This fixes both normal 100.100.100.100 resolution when // /etc/resolv.conf is missing/corrupt, and the peerapi ExitDNS stub // resolver lookup. cloudHostFallback []resolverAndDelay } func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, knobs *controlknobs.Knobs) *forwarder { if netMon == nil { panic("nil netMon") } f := &forwarder{ logf: logger.WithPrefix(logf, "forward: "), netMon: netMon, linkSel: linkSel, dialer: dialer, controlKnobs: knobs, } f.ctx, f.ctxCancel = context.WithCancel(context.Background()) return f } func (f *forwarder) Close() error { f.ctxCancel() return nil } // resolversWithDelays maps from a set of DNS server names to a slice of a type // that included a startDelay, upgrading any well-known DoH (DNS-over-HTTP) // servers in the process, insert a DoH lookup first before UDP fallbacks. func resolversWithDelays(resolvers []*dnstype.Resolver) []resolverAndDelay { rr := make([]resolverAndDelay, 0, len(resolvers)+2) type dohState uint8 const addedDoH = dohState(1) const addedDoHAndDontAddUDP = dohState(2) // Add the known DoH ones first, starting immediately. didDoH := map[string]dohState{} for _, r := range resolvers { ipp, ok := r.IPPort() if !ok { continue } dohBase, dohOnly, ok := publicdns.DoHEndpointFromIP(ipp.Addr()) if !ok || didDoH[dohBase] != 0 { continue } if dohOnly { didDoH[dohBase] = addedDoHAndDontAddUDP } else { didDoH[dohBase] = addedDoH } rr = append(rr, resolverAndDelay{name: &dnstype.Resolver{Addr: dohBase}}) } type hostAndFam struct { host string // some arbitrary string representing DNS host (currently the DoH base) bits uint8 // either 32 or 128 for IPv4 vs IPv6s address family } done := map[hostAndFam]int{} for _, r := range resolvers { ipp, ok := r.IPPort() if !ok { // Pass non-IP ones through unchanged, without delay. // (e.g. DNS-over-ExitDNS when using an exit node) rr = append(rr, resolverAndDelay{name: r}) continue } ip := ipp.Addr() var startDelay time.Duration if host, _, ok := publicdns.DoHEndpointFromIP(ip); ok { if didDoH[host] == addedDoHAndDontAddUDP { continue } // We already did the DoH query early. These // are for normal dns53 UDP queries. startDelay = dohHeadStart key := hostAndFam{host, uint8(ip.BitLen())} if done[key] > 0 { startDelay += wellKnownHostBackupDelay } done[key]++ } rr = append(rr, resolverAndDelay{ name: r, startDelay: startDelay, }) } return rr } var ( cloudResolversOnce sync.Once cloudResolversLazy []resolverAndDelay ) func cloudResolvers() []resolverAndDelay { cloudResolversOnce.Do(func() { if ip := cloudenv.Get().ResolverIP(); ip != "" { cloudResolver := []*dnstype.Resolver{{Addr: ip}} cloudResolversLazy = resolversWithDelays(cloudResolver) } }) return cloudResolversLazy } // setRoutes sets the routes to use for DNS forwarding. It's called by // Resolver.SetConfig on reconfig. // // The memory referenced by routesBySuffix should not be modified. func (f *forwarder) setRoutes(routesBySuffix map[dnsname.FQDN][]*dnstype.Resolver) { routes := make([]route, 0, len(routesBySuffix)) cloudHostFallback := cloudResolvers() for suffix, rs := range routesBySuffix { if suffix == "." && len(rs) == 0 && len(cloudHostFallback) > 0 { routes = append(routes, route{ Suffix: suffix, Resolvers: cloudHostFallback, }) } else { routes = append(routes, route{ Suffix: suffix, Resolvers: resolversWithDelays(rs), }) } } if cloudenv.Get().HasInternalTLD() && len(cloudHostFallback) > 0 { if _, ok := routesBySuffix["internal."]; !ok { routes = append(routes, route{ Suffix: "internal.", Resolvers: cloudHostFallback, }) } } // Sort from longest prefix to shortest. sort.Slice(routes, func(i, j int) bool { return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels() }) f.mu.Lock() defer f.mu.Unlock() f.routes = routes f.cloudHostFallback = cloudHostFallback } var stdNetPacketListener nettype.PacketListenerWithNetIP = nettype.MakePacketListenerWithNetIP(new(net.ListenConfig)) func (f *forwarder) packetListener(ip netip.Addr) (nettype.PacketListenerWithNetIP, error) { if f.linkSel == nil || initListenConfig == nil { return stdNetPacketListener, nil } linkName := f.linkSel.PickLink(ip) if linkName == "" { return stdNetPacketListener, nil } lc := new(net.ListenConfig) if err := initListenConfig(lc, f.netMon, linkName); err != nil { return nil, err } return nettype.MakePacketListenerWithNetIP(lc), nil } // getKnownDoHClientForProvider returns an HTTP client for a specific DoH // provider named by its DoH base URL (like "https://dns.google/dns-query"). // // The returned client race/Happy Eyeballs dials all IPs for urlBase (usually // 4), as statically known by the publicdns package. func (f *forwarder) getKnownDoHClientForProvider(urlBase string) (c *http.Client, ok bool) { f.mu.Lock() defer f.mu.Unlock() if c, ok := f.dohClient[urlBase]; ok { return c, true } allIPs := publicdns.DoHIPsOfBase(urlBase) if len(allIPs) == 0 { return nil, false } dohURL, err := url.Parse(urlBase) if err != nil { return nil, false } // NOTE: use f.dialer.SystemDial so we close connections on a link // change; on mobile devices when switching between WiFi and cellular, // we need to ensure we don't retain a connection on the old interface // or we can block DNS resolution. // // NOTE: if we ever support arbitrary user-defined DoH providers, this // isn't sufficient; we'd need a dialer that dial a DoH server on the // internet, without going through Tailscale (as SystemDial does), but // also can dial a node on the tailnet (e.g. a PiHole). // // As of the time of writing (2024-02-11), this isn't a problem because // we only support a restricted set of public DoH providers that aren't // on a user's tailnet. dialer := dnscache.Dialer(f.dialer.SystemDial, &dnscache.Resolver{ SingleHost: dohURL.Hostname(), SingleHostStaticResult: allIPs, Logf: f.logf, }) c = &http.Client{ Transport: &http.Transport{ ForceAttemptHTTP2: true, IdleConnTimeout: dohTransportTimeout, // On mobile platforms TCP KeepAlive is disabled in the dialer, // ensure that we timeout if the connection appears to be hung. ResponseHeaderTimeout: 10 * time.Second, DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) { if !strings.HasPrefix(netw, "tcp") { return nil, fmt.Errorf("unexpected network %q", netw) } return dialer(ctx, netw, addr) }, }, } if f.dohClient == nil { f.dohClient = map[string]*http.Client{} } f.dohClient[urlBase] = c return c, true } const dohType = "application/dns-message" func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, packet []byte) ([]byte, error) { ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderDoH, f.logf) metricDNSFwdDoH.Add(1) req, err := http.NewRequestWithContext(ctx, "POST", urlBase, bytes.NewReader(packet)) if err != nil { return nil, err } req.Header.Set("Content-Type", dohType) req.Header.Set("Accept", dohType) req.Header.Set("User-Agent", "tailscaled/"+version.Long()) hres, err := c.Do(req) if err != nil { metricDNSFwdDoHErrorTransport.Add(1) return nil, err } defer hres.Body.Close() if hres.StatusCode != 200 { metricDNSFwdDoHErrorStatus.Add(1) return nil, errors.New(hres.Status) } if ct := hres.Header.Get("Content-Type"); ct != dohType { metricDNSFwdDoHErrorCT.Add(1) return nil, fmt.Errorf("unexpected response Content-Type %q", ct) } res, err := io.ReadAll(hres.Body) if err != nil { metricDNSFwdDoHErrorBody.Add(1) } if truncatedFlagSet(res) { metricDNSFwdTruncated.Add(1) } return res, err } var ( verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND") skipTCPRetry = envknob.RegisterBool("TS_DNS_FORWARD_SKIP_TCP_RETRY") // For correlating log messages in the send() function; only used when // verboseDNSForward() is true. forwarderCount atomic.Uint64 ) // send sends packet to dst. It is best effort. // // send expects the reply to have the same txid as txidOut. func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { if verboseDNSForward() { id := forwarderCount.Add(1) f.logf("forwarder.send(%q) [%d] ...", rr.name.Addr, id) defer func() { f.logf("forwarder.send(%q) [%d] = %v, %v", rr.name.Addr, id, len(ret), err) }() } if strings.HasPrefix(rr.name.Addr, "http://") { return f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet) } if strings.HasPrefix(rr.name.Addr, "https://") { // Only known DoH providers are supported currently. Specifically, we // only support DoH providers where we can TCP connect to them on port // 443 at the same IP address they serve normal UDP DNS from (1.1.1.1, // 8.8.8.8, 9.9.9.9, etc.) That's why OpenDNS and custom DoH providers // aren't currently supported. There's no backup DNS resolution path for // them. urlBase := rr.name.Addr if hc, ok := f.getKnownDoHClientForProvider(urlBase); ok { return f.sendDoH(ctx, urlBase, hc, fq.packet) } metricDNSFwdErrorType.Add(1) return nil, fmt.Errorf("arbitrary https:// resolvers not supported yet") } if strings.HasPrefix(rr.name.Addr, "tls://") { metricDNSFwdErrorType.Add(1) return nil, fmt.Errorf("tls:// resolvers not supported yet") } ctx, cancel := context.WithCancel(ctx) defer cancel() isUDPQuery := fq.family == "udp" skipTCP := skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) // Print logs about retries if this was because of a truncated response. var explicitRetry atomic.Bool // true if truncated UDP response retried defer func() { if !explicitRetry.Load() { return } if err == nil { f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr) } else { f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err) } }() firstUDP := func(ctx context.Context) ([]byte, error) { resp, err := f.sendUDP(ctx, fq, rr) if err != nil { return nil, err } if !truncatedFlagSet(resp) { // Successful, non-truncated response; no retry. return resp, nil } // If this is a UDP query, return it regardless of whether the // response is truncated or not; the client can retry // communicating with tailscaled over TCP. There's no point // falling back to TCP for a truncated query if we can't return // the results to the client. if isUDPQuery { return resp, nil } if skipTCP { // Envknob or control knob disabled the TCP retry behaviour; // just return what we have. return resp, nil } // This is a TCP query from the client, and the UDP response // from the upstream DNS server is truncated; map this to an // error to cause our retry helper to immediately kick off the // TCP retry. explicitRetry.Store(true) return nil, truncatedResponseError{resp} } thenTCP := func(ctx context.Context) ([]byte, error) { // If we're skipping the TCP fallback, then wait until the // context is canceled and return that error (i.e. not // returning anything). if skipTCP { <-ctx.Done() return nil, ctx.Err() } return f.sendTCP(ctx, fq, rr) } // If the input query is TCP, then don't have a timeout between // starting UDP and TCP. timeout := udpRaceTimeout if !isUDPQuery { timeout = 0 } // Kick off the race between the UDP and TCP queries. rh := race.New(timeout, firstUDP, thenTCP) resp, err := rh.Start(ctx) if err == nil { return resp, nil } // If we got a truncated UDP response, return that instead of an error. var trErr truncatedResponseError if errors.As(err, &trErr) { return trErr.res, nil } return nil, err } type truncatedResponseError struct { res []byte } func (tr truncatedResponseError) Error() string { return "response truncated" } var errServerFailure = errors.New("response code indicates server issue") var errTxIDMismatch = errors.New("txid doesn't match") func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { ipp, ok := rr.name.IPPort() if !ok { metricDNSFwdErrorType.Add(1) return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr) } metricDNSFwdUDP.Add(1) ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderUDP, f.logf) ln, err := f.packetListener(ipp.Addr()) if err != nil { return nil, err } // Specify the exact UDP family to work around https://github.com/golang/go/issues/52264 udpFam := "udp4" if ipp.Addr().Is6() { udpFam = "udp6" } conn, err := ln.ListenPacket(ctx, udpFam, ":0") if err != nil { f.logf("ListenPacket failed: %v", err) return nil, err } defer conn.Close() fq.closeOnCtxDone.Add(conn) defer fq.closeOnCtxDone.Remove(conn) if _, err := conn.WriteToUDPAddrPort(fq.packet, ipp); err != nil { metricDNSFwdUDPErrorWrite.Add(1) if err := ctx.Err(); err != nil { return nil, err } return nil, err } metricDNSFwdUDPWrote.Add(1) // The 1 extra byte is to detect packet truncation. out := make([]byte, maxResponseBytes+1) n, _, err := conn.ReadFromUDPAddrPort(out) if err != nil { if err := ctx.Err(); err != nil { return nil, err } if neterror.PacketWasTruncated(err) { err = nil } else { metricDNSFwdUDPErrorRead.Add(1) return nil, err } } truncated := n > maxResponseBytes if truncated { n = maxResponseBytes } if n < headerBytes { f.logf("recv: packet too small (%d bytes)", n) } out = out[:n] txid := getTxID(out) if txid != fq.txid { metricDNSFwdUDPErrorTxID.Add(1) return nil, errTxIDMismatch } rcode := getRCode(out) // don't forward transient errors back to the client when the server fails if rcode == dns.RCodeServerFailure { f.logf("recv: response code indicating server failure: %d", rcode) metricDNSFwdUDPErrorServer.Add(1) return nil, errServerFailure } if truncated { // Set the truncated bit if it wasn't already. flags := binary.BigEndian.Uint16(out[2:4]) flags |= dnsFlagTruncated binary.BigEndian.PutUint16(out[2:4], flags) // TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2 // states that truncation should head drop so that the authority // section can be preserved if possible. However, the UDP read with // a too-small buffer has already dropped the end, so that's the // best we can do. } if truncatedFlagSet(out) { metricDNSFwdTruncated.Add(1) } clampEDNSSize(out, maxResponseBytes) metricDNSFwdUDPSuccess.Add(1) return out, nil } func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { ipp, ok := rr.name.IPPort() if !ok { metricDNSFwdErrorType.Add(1) return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr) } metricDNSFwdTCP.Add(1) ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderTCP, f.logf) // Specify the exact family to work around https://github.com/golang/go/issues/52264 tcpFam := "tcp4" if ipp.Addr().Is6() { tcpFam = "tcp6" } ctx, cancel := context.WithTimeout(ctx, tcpQueryTimeout) defer cancel() conn, err := f.dialer.SystemDial(ctx, tcpFam, ipp.String()) if err != nil { return nil, err } defer conn.Close() fq.closeOnCtxDone.Add(conn) defer fq.closeOnCtxDone.Remove(conn) ctxOrErr := func(err2 error) ([]byte, error) { if err := ctx.Err(); err != nil { return nil, err } return nil, err2 } // Write the query to the server. query := make([]byte, len(fq.packet)+2) binary.BigEndian.PutUint16(query, uint16(len(fq.packet))) copy(query[2:], fq.packet) if _, err := conn.Write(query); err != nil { metricDNSFwdTCPErrorWrite.Add(1) return ctxOrErr(err) } metricDNSFwdTCPWrote.Add(1) // Read the header length back from the server var length uint16 if err := binary.Read(conn, binary.BigEndian, &length); err != nil { metricDNSFwdTCPErrorRead.Add(1) return ctxOrErr(err) } // Now read the response out := make([]byte, length) n, err := io.ReadFull(conn, out) if err != nil { metricDNSFwdTCPErrorRead.Add(1) return ctxOrErr(err) } if n < int(length) { f.logf("sendTCP: packet too small (%d bytes)", n) return nil, io.ErrUnexpectedEOF } out = out[:n] txid := getTxID(out) if txid != fq.txid { metricDNSFwdTCPErrorTxID.Add(1) return nil, errTxIDMismatch } rcode := getRCode(out) // don't forward transient errors back to the client when the server fails if rcode == dns.RCodeServerFailure { f.logf("sendTCP: response code indicating server failure: %d", rcode) metricDNSFwdTCPErrorServer.Add(1) return nil, errServerFailure } // TODO(andrew): do we need to do this? //clampEDNSSize(out, maxResponseBytes) metricDNSFwdTCPSuccess.Add(1) return out, nil } // resolvers returns the resolvers to use for domain. func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay { f.mu.Lock() routes := f.routes cloudHostFallback := f.cloudHostFallback f.mu.Unlock() for _, route := range routes { if route.Suffix == "." || route.Suffix.Contains(domain) { return route.Resolvers } } return cloudHostFallback // or nil if no fallback } // forwardQuery is information and state about a forwarded DNS query that's // being sent to 1 or more upstreams. // // In the case of racing against multiple equivalent upstreams // (e.g. Google or CloudFlare's 4 DNS IPs: 2 IPv4 + 2 IPv6), this type // handles racing them more intelligently than just blasting away 4 // queries at once. type forwardQuery struct { txid txid packet []byte family string // "tcp" or "udp" // closeOnCtxDone lets send register values to Close if the // caller's ctx expires. This avoids send from allocating its // own waiting goroutine to interrupt the ReadFrom, as memory // is tight on iOS and we want the number of pending DNS // lookups to be bursty without too much associated // goroutine/memory cost. closeOnCtxDone *closePool // TODO(bradfitz): add race delay state: // mu sync.Mutex // ... } // forwardWithDestChan forwards the query to all upstream nameservers // and waits for the first response. // // It either sends to responseChan and returns nil, or returns a // non-nil error (without sending to the channel). // // If resolvers is non-empty, it's used explicitly (notably, for exit // node DNS proxy queries), otherwise f.resolvers is used. func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, responseChan chan<- packet, resolvers ...resolverAndDelay) error { metricDNSFwd.Add(1) domain, err := nameFromQuery(query.bs) if err != nil { metricDNSFwdErrorName.Add(1) return err } // Guarantee that the ctx we use below is done when this function returns. ctx, cancel := context.WithCancel(ctx) defer cancel() // Drop DNS service discovery spam, primarily for battery life // on mobile. Things like Spotify on iOS generate this traffic, // when browsing for LAN devices. But even when filtering this // out, playing on Sonos still works. if hasRDNSBonjourPrefix(domain) { metricDNSFwdDropBonjour.Add(1) res, err := nxDomainResponse(query) if err != nil { f.logf("error parsing bonjour query: %v", err) // Returning an error will cause an internal retry, there is // nothing we can do if parsing failed. Just drop the packet. return nil } select { case <-ctx.Done(): return ctx.Err() case responseChan <- res: return nil } } if fl := fwdLogAtomic.Load(); fl != nil { fl.addName(string(domain)) } clampEDNSSize(query.bs, maxResponseBytes) if len(resolvers) == 0 { resolvers = f.resolvers(domain) if len(resolvers) == 0 { metricDNSFwdErrorNoUpstream.Add(1) f.logf("no upstream resolvers set, returning SERVFAIL") res, err := servfailResponse(query) if err != nil { f.logf("building servfail response: %v", err) // Returning an error will cause an internal retry, there is // nothing we can do if parsing failed. Just drop the packet. return nil } select { case <-ctx.Done(): return ctx.Err() case responseChan <- res: return nil } } } fq := &forwardQuery{ txid: getTxID(query.bs), packet: query.bs, family: query.family, closeOnCtxDone: new(closePool), } defer fq.closeOnCtxDone.Close() resc := make(chan []byte, 1) // it's fine buffered or not errc := make(chan error, 1) // it's fine buffered or not too for i := range resolvers { go func(rr *resolverAndDelay) { if rr.startDelay > 0 { timer := time.NewTimer(rr.startDelay) select { case <-timer.C: case <-ctx.Done(): timer.Stop() return } } resb, err := f.send(ctx, fq, *rr) if err != nil { select { case errc <- err: case <-ctx.Done(): } return } select { case resc <- resb: case <-ctx.Done(): } }(&resolvers[i]) } var firstErr error var numErr int for { select { case v := <-resc: select { case <-ctx.Done(): metricDNSFwdErrorContext.Add(1) return ctx.Err() case responseChan <- packet{v, query.family, query.addr}: metricDNSFwdSuccess.Add(1) return nil } case err := <-errc: if firstErr == nil { firstErr = err } numErr++ if numErr == len(resolvers) { if errors.Is(firstErr, errServerFailure) { res, err := servfailResponse(query) if err != nil { f.logf("building servfail response: %v", err) return firstErr } select { case <-ctx.Done(): metricDNSFwdErrorContext.Add(1) metricDNSFwdErrorContextGotError.Add(1) case responseChan <- res: } } return firstErr } case <-ctx.Done(): metricDNSFwdErrorContext.Add(1) if firstErr != nil { metricDNSFwdErrorContextGotError.Add(1) return firstErr } return ctx.Err() } } } var initListenConfig func(_ *net.ListenConfig, _ *netmon.Monitor, tunName string) error // nameFromQuery extracts the normalized query name from bs. func nameFromQuery(bs []byte) (dnsname.FQDN, error) { var parser dns.Parser hdr, err := parser.Start(bs) if err != nil { return "", err } if hdr.Response { return "", errNotQuery } q, err := parser.Question() if err != nil { return "", err } n := q.Name.Data[:q.Name.Length] return dnsname.ToFQDN(rawNameToLower(n)) } // nxDomainResponse returns an NXDomain DNS reply for the provided request. func nxDomainResponse(req packet) (res packet, err error) { p := dnsParserPool.Get().(*dnsParser) defer dnsParserPool.Put(p) if err := p.parseQuery(req.bs); err != nil { return packet{}, err } h := p.Header h.Response = true h.RecursionAvailable = h.RecursionDesired h.RCode = dns.RCodeNameError b := dns.NewBuilder(nil, h) // TODO(bradfitz): should we add an SOA record in the Authority // section too? (for the nxdomain negative caching TTL) // For which zone? Does iOS care? res.bs, err = b.Finish() res.addr = req.addr return res, err } // servfailResponse returns a SERVFAIL error reply for the provided request. func servfailResponse(req packet) (res packet, err error) { p := dnsParserPool.Get().(*dnsParser) defer dnsParserPool.Put(p) if err := p.parseQuery(req.bs); err != nil { return packet{}, err } h := p.Header h.Response = true h.Authoritative = true h.RCode = dns.RCodeServerFailure b := dns.NewBuilder(nil, h) b.StartQuestions() b.Question(p.Question) res.bs, err = b.Finish() res.addr = req.addr return res, err } // closePool is a dynamic set of io.Closers to close as a group. // It's intended to be Closed at most once. // // The zero value is ready for use. type closePool struct { mu sync.Mutex m map[io.Closer]bool closed bool } func (p *closePool) Add(c io.Closer) { p.mu.Lock() defer p.mu.Unlock() if p.closed { c.Close() return } if p.m == nil { p.m = map[io.Closer]bool{} } p.m[c] = true } func (p *closePool) Remove(c io.Closer) { p.mu.Lock() defer p.mu.Unlock() if p.closed { return } delete(p.m, c) } func (p *closePool) Close() error { p.mu.Lock() defer p.mu.Unlock() if p.closed { return nil } p.closed = true for c := range p.m { c.Close() } return nil }