// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause // Package recursive implements a simple recursive DNS resolver. package recursive import ( "context" "errors" "fmt" "net" "net/netip" "slices" "strings" "time" "github.com/miekg/dns" "tailscale.com/envknob" "tailscale.com/net/netns" "tailscale.com/types/logger" "tailscale.com/util/dnsname" "tailscale.com/util/mak" "tailscale.com/util/multierr" "tailscale.com/util/slicesx" ) const ( // maxDepth is how deep from the root nameservers we'll recurse when // resolving; passing this limit will instead return an error. // // maxDepth must be at least 20 to resolve "console.aws.amazon.com", // which is a domain with a moderately complicated DNS setup. The // current value of 30 was chosen semi-arbitrarily to ensure that we // have about 50% headroom. maxDepth = 30 // numStartingServers is the number of root nameservers that we use as // initial candidates for our recursion. numStartingServers = 3 // udpQueryTimeout is the amount of time we wait for a UDP response // from a nameserver before falling back to a TCP connection. udpQueryTimeout = 5 * time.Second // These constants aren't typed in the DNS package, so we create typed // versions here to avoid having to do repeated type casts. qtypeA dns.Type = dns.Type(dns.TypeA) qtypeAAAA dns.Type = dns.Type(dns.TypeAAAA) ) var ( // ErrMaxDepth is returned when recursive resolving exceeds the maximum // depth limit for this package. ErrMaxDepth = fmt.Errorf("exceeded max depth %d when resolving", maxDepth) // ErrAuthoritativeNoResponses is the error returned when an // authoritative nameserver indicates that there are no responses to // the given query. ErrAuthoritativeNoResponses = errors.New("authoritative server returned no responses") // ErrNoResponses is returned when our resolution process completes // with no valid responses from any nameserver, but no authoritative // server explicitly returned NXDOMAIN. ErrNoResponses = errors.New("no responses to query") ) var rootServersV4 = []netip.Addr{ netip.MustParseAddr("198.41.0.4"), // a.root-servers.net netip.MustParseAddr("170.247.170.2"), // b.root-servers.net netip.MustParseAddr("192.33.4.12"), // c.root-servers.net netip.MustParseAddr("199.7.91.13"), // d.root-servers.net netip.MustParseAddr("192.203.230.10"), // e.root-servers.net netip.MustParseAddr("192.5.5.241"), // f.root-servers.net netip.MustParseAddr("192.112.36.4"), // g.root-servers.net netip.MustParseAddr("198.97.190.53"), // h.root-servers.net netip.MustParseAddr("192.36.148.17"), // i.root-servers.net netip.MustParseAddr("192.58.128.30"), // j.root-servers.net netip.MustParseAddr("193.0.14.129"), // k.root-servers.net netip.MustParseAddr("199.7.83.42"), // l.root-servers.net netip.MustParseAddr("202.12.27.33"), // m.root-servers.net } var rootServersV6 = []netip.Addr{ netip.MustParseAddr("2001:503:ba3e::2:30"), // a.root-servers.net netip.MustParseAddr("2801:1b8:10::b"), // b.root-servers.net netip.MustParseAddr("2001:500:2::c"), // c.root-servers.net netip.MustParseAddr("2001:500:2d::d"), // d.root-servers.net netip.MustParseAddr("2001:500:a8::e"), // e.root-servers.net netip.MustParseAddr("2001:500:2f::f"), // f.root-servers.net netip.MustParseAddr("2001:500:12::d0d"), // g.root-servers.net netip.MustParseAddr("2001:500:1::53"), // h.root-servers.net netip.MustParseAddr("2001:7fe::53"), // i.root-servers.net netip.MustParseAddr("2001:503:c27::2:30"), // j.root-servers.net netip.MustParseAddr("2001:7fd::1"), // k.root-servers.net netip.MustParseAddr("2001:500:9f::42"), // l.root-servers.net netip.MustParseAddr("2001:dc3::35"), // m.root-servers.net } var debug = envknob.RegisterBool("TS_DEBUG_RECURSIVE_DNS") // Resolver is a recursive DNS resolver that is designed for looking up A and AAAA records. type Resolver struct { // Dialer is used to create outbound connections. If nil, a zero // net.Dialer will be used instead. Dialer netns.Dialer // Logf is the logging function to use; if none is specified, then logs // will be dropped. Logf logger.Logf // NoIPv6, if set, will prevent this package from querying for AAAA // records and will avoid contacting nameservers over IPv6. NoIPv6 bool // Test mocks testQueryHook func(name dnsname.FQDN, nameserver netip.Addr, protocol string, qtype dns.Type) (*dns.Msg, error) testExchangeHook func(nameserver netip.Addr, network string, msg *dns.Msg) (*dns.Msg, error) rootServers []netip.Addr timeNow func() time.Time // Caching // NOTE(andrew): if we make resolution parallel, this needs a mutex queryCache map[dnsQuery]dnsMsgWithExpiry // Possible future additions: // - Additional nameservers? From the system maybe? // - NoIPv4 for IPv4 // - DNS-over-HTTPS or DNS-over-TLS support } // queryState stores all state during the course of a single query type queryState struct { // rootServers are the root nameservers to start from rootServers []netip.Addr // TODO: metrics? } type dnsQuery struct { nameserver netip.Addr name dnsname.FQDN qtype dns.Type } func (q dnsQuery) String() string { return fmt.Sprintf("dnsQuery{nameserver:%q,name:%q,qtype:%v}", q.nameserver.String(), q.name, q.qtype) } type dnsMsgWithExpiry struct { *dns.Msg expiresAt time.Time } func (r *Resolver) now() time.Time { if r.timeNow != nil { return r.timeNow() } return time.Now() } func (r *Resolver) logf(format string, args ...any) { if r.Logf == nil { return } r.Logf(format, args...) } func (r *Resolver) depthlogf(depth int, format string, args ...any) { if r.Logf == nil || !debug() { return } prefix := fmt.Sprintf("[%d] %s", depth, strings.Repeat(" ", depth)) r.Logf(prefix+format, args...) } var defaultDialer net.Dialer func (r *Resolver) dialer() netns.Dialer { if r.Dialer != nil { return r.Dialer } return &defaultDialer } func (r *Resolver) newState() *queryState { var rootServers []netip.Addr if len(r.rootServers) > 0 { rootServers = r.rootServers } else { // Select a random subset of root nameservers to start from, since if // we don't get responses from those, something else has probably gone // horribly wrong. roots4 := slices.Clone(rootServersV4) slicesx.Shuffle(roots4) roots4 = roots4[:numStartingServers] var roots6 []netip.Addr if !r.NoIPv6 { roots6 = slices.Clone(rootServersV6) slicesx.Shuffle(roots6) roots6 = roots6[:numStartingServers] } // Interleave the root servers so that we try to contact them over // IPv4, then IPv6, IPv4, IPv6, etc. rootServers = slicesx.Interleave(roots4, roots6) } return &queryState{ rootServers: rootServers, } } // Resolve will perform a recursive DNS resolution for the provided name, // starting at a randomly-chosen root DNS server, and return the A and AAAA // responses as a slice of netip.Addrs along with the minimum TTL for the // returned records. func (r *Resolver) Resolve(ctx context.Context, name string) (addrs []netip.Addr, minTTL time.Duration, err error) { dnsName, err := dnsname.ToFQDN(name) if err != nil { return nil, 0, err } qstate := r.newState() r.logf("querying IPv4 addresses for: %q", name) addrs4, minTTL4, err4 := r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeA) var ( addrs6 []netip.Addr minTTL6 time.Duration err6 error ) if !r.NoIPv6 { r.logf("querying IPv6 addresses for: %q", name) addrs6, minTTL6, err6 = r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeAAAA) } if err4 != nil && err6 != nil { if err4 == err6 { return nil, 0, err4 } return nil, 0, multierr.New(err4, err6) } if err4 != nil { return addrs6, minTTL6, nil } else if err6 != nil { return addrs4, minTTL4, nil } minTTL = minTTL4 if minTTL6 < minTTL { minTTL = minTTL6 } addrs = append(addrs4, addrs6...) if len(addrs) == 0 { return nil, 0, ErrNoResponses } slicesx.Shuffle(addrs) return addrs, minTTL, nil } func (r *Resolver) resolveRecursiveFromRoot( ctx context.Context, qstate *queryState, depth int, name dnsname.FQDN, // what we're querying qtype dns.Type, ) ([]netip.Addr, time.Duration, error) { r.depthlogf(depth, "resolving %q from root (type: %v)", name, qtype) var depthError bool for _, server := range qstate.rootServers { addrs, minTTL, err := r.resolveRecursive(ctx, qstate, depth, name, server, qtype) if err == nil { return addrs, minTTL, err } else if errors.Is(err, ErrAuthoritativeNoResponses) { return nil, 0, ErrAuthoritativeNoResponses } else if errors.Is(err, ErrMaxDepth) { depthError = true } } if depthError { return nil, 0, ErrMaxDepth } return nil, 0, ErrNoResponses } func (r *Resolver) resolveRecursive( ctx context.Context, qstate *queryState, depth int, name dnsname.FQDN, // what we're querying nameserver netip.Addr, qtype dns.Type, ) ([]netip.Addr, time.Duration, error) { if depth == maxDepth { r.depthlogf(depth, "not recursing past maximum depth") return nil, 0, ErrMaxDepth } // Ask this nameserver for an answer. resp, err := r.queryNameserver(ctx, depth, name, nameserver, qtype) if err != nil { return nil, 0, err } // If we get an actual answer from the nameserver, then return it. var ( answers []netip.Addr cnames []dnsname.FQDN minTTL = 24 * 60 * 60 // 24 hours in seconds ) for _, answer := range resp.Answer { if crec, ok := answer.(*dns.CNAME); ok { cnameFQDN, err := dnsname.ToFQDN(crec.Target) if err != nil { r.logf("bad CNAME %q returned: %v", crec.Target, err) continue } cnames = append(cnames, cnameFQDN) continue } addr := addrFromRecord(answer) if !addr.IsValid() { r.logf("[unexpected] invalid record in %T answer", answer) } else if addr.Is4() && qtype != qtypeA { r.logf("[unexpected] got IPv4 answer but qtype=%v", qtype) } else if addr.Is6() && qtype != qtypeAAAA { r.logf("[unexpected] got IPv6 answer but qtype=%v", qtype) } else { answers = append(answers, addr) minTTL = min(minTTL, int(answer.Header().Ttl)) } } if len(answers) > 0 { r.depthlogf(depth, "got answers for %q: %v", name, answers) return answers, time.Duration(minTTL) * time.Second, nil } r.depthlogf(depth, "no answers for %q", name) // If we have a non-zero number of CNAMEs, then try resolving those // (from the root again) and return the first one that succeeds. // // TODO: return the union of all responses? // TODO: parallelism? if len(cnames) > 0 { r.depthlogf(depth, "got CNAME responses for %q: %v", name, cnames) } var cnameDepthError bool for _, cname := range cnames { answers, minTTL, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, cname, qtype) if err == nil { return answers, minTTL, nil } else if errors.Is(err, ErrAuthoritativeNoResponses) { return nil, 0, ErrAuthoritativeNoResponses } else if errors.Is(err, ErrMaxDepth) { cnameDepthError = true } } // If this is an authoritative response, then we know that continuing // to look further is not going to result in any answers and we should // bail out. if resp.MsgHdr.Authoritative { // If we failed to recurse into a CNAME due to a depth limit, // propagate that here. if cnameDepthError { return nil, 0, ErrMaxDepth } r.depthlogf(depth, "got authoritative response with no answers; stopping") return nil, 0, ErrAuthoritativeNoResponses } r.depthlogf(depth, "got %d NS responses and %d ADDITIONAL responses for %q", len(resp.Ns), len(resp.Extra), name) // No CNAMEs and no answers; see if we got any AUTHORITY responses, // which indicate which nameservers to query next. var authorities []dnsname.FQDN for _, rr := range resp.Ns { ns, ok := rr.(*dns.NS) if !ok { continue } nsName, err := dnsname.ToFQDN(ns.Ns) if err != nil { r.logf("unexpected bad NS name %q: %v", ns.Ns, err) continue } authorities = append(authorities, nsName) } // Also check for "glue" records, which are IP addresses provided by // the DNS server for authority responses; these are required when the // authority server is a subdomain of what's being resolved. glueRecords := make(map[dnsname.FQDN][]netip.Addr) for _, rr := range resp.Extra { name, err := dnsname.ToFQDN(rr.Header().Name) if err != nil { r.logf("unexpected bad Name %q in Extra addr: %v", rr.Header().Name, err) continue } if addr := addrFromRecord(rr); addr.IsValid() { glueRecords[name] = append(glueRecords[name], addr) } else { r.logf("unexpected bad Extra %T addr", rr) } } // Try authorities with glue records first, to minimize the number of // additional DNS queries that we need to make. authoritiesGlue, authoritiesNoGlue := slicesx.Partition(authorities, func(aa dnsname.FQDN) bool { return len(glueRecords[aa]) > 0 }) authorityDepthError := false r.depthlogf(depth, "authorities with glue records for recursion: %v", authoritiesGlue) for _, authority := range authoritiesGlue { for _, nameserver := range glueRecords[authority] { answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) if err == nil { return answers, minTTL, nil } else if errors.Is(err, ErrAuthoritativeNoResponses) { return nil, 0, ErrAuthoritativeNoResponses } else if errors.Is(err, ErrMaxDepth) { authorityDepthError = true } } } r.depthlogf(depth, "authorities with no glue records for recursion: %v", authoritiesNoGlue) for _, authority := range authoritiesNoGlue { // First, resolve the IP for the authority server from the // root, querying for both IPv4 and IPv6 addresses regardless // of what the current question type is. // // TODO: check for infinite recursion; it'll get caught by our // recursion depth, but we want to bail early. for _, authorityQtype := range []dns.Type{qtypeAAAA, qtypeA} { answers, _, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, authority, authorityQtype) if err != nil { r.depthlogf(depth, "error querying authority %q: %v", authority, err) continue } r.depthlogf(depth, "resolved authority %q (type %v) to: %v", authority, authorityQtype, answers) // Now, query this authority for the final address. for _, nameserver := range answers { answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) if err == nil { return answers, minTTL, nil } else if errors.Is(err, ErrAuthoritativeNoResponses) { return nil, 0, ErrAuthoritativeNoResponses } else if errors.Is(err, ErrMaxDepth) { authorityDepthError = true } } } } if authorityDepthError { return nil, 0, ErrMaxDepth } return nil, 0, ErrNoResponses } // queryNameserver sends a query for "name" to the nameserver "nameserver" for // records of type "qtype", trying both UDP and TCP connections as // appropriate. func (r *Resolver) queryNameserver( ctx context.Context, depth int, name dnsname.FQDN, // what we're querying nameserver netip.Addr, // destination of query qtype dns.Type, ) (*dns.Msg, error) { // TODO(andrew): we should QNAME minimisation here to avoid sending the // full name to intermediate/root nameservers. See: // https://www.rfc-editor.org/rfc/rfc7816 // Handle the case where UDP is blocked by adding an explicit timeout // for the UDP portion of this query. udpCtx, udpCtxCancel := context.WithTimeout(ctx, udpQueryTimeout) defer udpCtxCancel() msg, err := r.queryNameserverProto(udpCtx, depth, name, nameserver, "udp", qtype) if err == nil { return msg, nil } msg, err2 := r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) if err2 == nil { return msg, nil } return nil, multierr.New(err, err2) } // queryNameserverProto sends a query for "name" to the nameserver "nameserver" // for records of type "qtype" over the provided protocol (either "udp" // or "tcp"), and returns the DNS response or an error. func (r *Resolver) queryNameserverProto( ctx context.Context, depth int, name dnsname.FQDN, // what we're querying nameserver netip.Addr, // destination of query protocol string, qtype dns.Type, ) (resp *dns.Msg, err error) { if r.testQueryHook != nil { return r.testQueryHook(name, nameserver, protocol, qtype) } now := r.now() nameserverStr := nameserver.String() cacheKey := dnsQuery{ nameserver: nameserver, name: name, qtype: qtype, } cacheEntry, ok := r.queryCache[cacheKey] if ok && cacheEntry.expiresAt.Before(now) { r.depthlogf(depth, "using cached response from %s about %q (type: %v)", nameserverStr, name, qtype) return cacheEntry.Msg, nil } var network string if nameserver.Is4() { network = protocol + "4" } else { network = protocol + "6" } // Prepare a message asking for an appropriately-typed record // for the name we're querying. m := new(dns.Msg) m.SetQuestion(name.WithTrailingDot(), uint16(qtype)) // Allow mocking out the network components with our exchange hook. if r.testExchangeHook != nil { resp, err = r.testExchangeHook(nameserver, network, m) } else { // Dial the current nameserver using our dialer. var nconn net.Conn nconn, err = r.dialer().DialContext(ctx, network, net.JoinHostPort(nameserverStr, "53")) if err != nil { return nil, err } var c dns.Client // TODO: share? conn := &dns.Conn{ Conn: nconn, UDPSize: c.UDPSize, } // Send the DNS request to the current nameserver. r.depthlogf(depth, "asking %s over %s about %q (type: %v)", nameserverStr, protocol, name, qtype) resp, _, err = c.ExchangeWithConnContext(ctx, m, conn) } if err != nil { return nil, err } // If the message was truncated and we're using UDP, re-run with TCP. if resp.MsgHdr.Truncated && protocol == "udp" { r.depthlogf(depth, "response message truncated; re-running query with TCP") resp, err = r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) if err != nil { return nil, err } } // Find minimum expiry for all records in this message. var minTTL int for _, rr := range resp.Answer { minTTL = min(minTTL, int(rr.Header().Ttl)) } for _, rr := range resp.Ns { minTTL = min(minTTL, int(rr.Header().Ttl)) } for _, rr := range resp.Extra { minTTL = min(minTTL, int(rr.Header().Ttl)) } mak.Set(&r.queryCache, cacheKey, dnsMsgWithExpiry{ Msg: resp, expiresAt: now.Add(time.Duration(minTTL) * time.Second), }) return resp, nil } func addrFromRecord(rr dns.RR) netip.Addr { switch v := rr.(type) { case *dns.A: ip, ok := netip.AddrFromSlice(v.A) if !ok || !ip.Is4() { return netip.Addr{} } return ip case *dns.AAAA: ip, ok := netip.AddrFromSlice(v.AAAA) if !ok || !ip.Is6() { return netip.Addr{} } return ip } return netip.Addr{} }