const (
	// maxDepth is how deep from the root nameservers we'll recurse when
	// resolving; passing this limit will instead return an error.
	//
	// maxDepth must be at least 20 to resolve "console.aws.amazon.com",
	// which is a domain with a moderately complicated DNS setup. The
	// current value of 30 was chosen semi-arbitrarily to ensure that we
	// have about 50% headroom.
	maxDepth = 30
	// numStartingServers is the number of root nameservers that we use as
	// initial candidates for our recursion. It is applied per address
	// family: newState keeps this many IPv4 roots and, unless NoIPv6 is
	// set, this many IPv6 roots.
	numStartingServers = 3
	// udpQueryTimeout is the amount of time we wait for a UDP response
	// from a nameserver before falling back to a TCP connection.
	udpQueryTimeout = 5 * time.Second

	// These constants aren't typed in the DNS package, so we create typed
	// versions here to avoid having to do repeated type casts.
	qtypeA    dns.Type = dns.Type(dns.TypeA)
	qtypeAAAA dns.Type = dns.Type(dns.TypeAAAA)
)

var (
	// ErrMaxDepth is returned when recursive resolving exceeds the maximum
	// depth limit for this package.
	ErrMaxDepth = fmt.Errorf("exceeded max depth %d when resolving", maxDepth)

	// ErrAuthoritativeNoResponses is the error returned when an
	// authoritative nameserver indicates that there are no responses to
	// the given query. In this package that means a response with the
	// Authoritative flag set and no usable answers.
	ErrAuthoritativeNoResponses = errors.New("authoritative server returned no responses")

	// ErrNoResponses is returned when our resolution process completes
	// with no valid responses from any nameserver, but no authoritative
	// server explicitly returned NXDOMAIN.
	ErrNoResponses = errors.New("no responses to query")
)

// rootServersV4 holds the IPv4 addresses of the root nameservers
// a.root-servers.net through m.root-servers.net.
//
// NOTE(review): root server addresses are occasionally renumbered (b.root
// was, in 2023); verify this table against the current IANA root hints.
var rootServersV4 = []netip.Addr{
	netip.MustParseAddr("198.41.0.4"),     // a.root-servers.net
	netip.MustParseAddr("199.9.14.201"),   // b.root-servers.net
	netip.MustParseAddr("192.33.4.12"),    // c.root-servers.net
	netip.MustParseAddr("199.7.91.13"),    // d.root-servers.net
	netip.MustParseAddr("192.203.230.10"), // e.root-servers.net
	netip.MustParseAddr("192.5.5.241"),    // f.root-servers.net
	netip.MustParseAddr("192.112.36.4"),   // g.root-servers.net
	netip.MustParseAddr("198.97.190.53"),  // h.root-servers.net
	netip.MustParseAddr("192.36.148.17"),  // i.root-servers.net
	netip.MustParseAddr("192.58.128.30"),  // j.root-servers.net
	netip.MustParseAddr("193.0.14.129"),   // k.root-servers.net
	netip.MustParseAddr("199.7.83.42"),    // l.root-servers.net
	netip.MustParseAddr("202.12.27.33"),   // m.root-servers.net
}

// rootServersV6 holds the IPv6 addresses of the same root nameservers, in
// the same a-through-m order as rootServersV4.
//
// NOTE(review): verify against the current IANA root hints, as above.
var rootServersV6 = []netip.Addr{
	netip.MustParseAddr("2001:503:ba3e::2:30"), // a.root-servers.net
	netip.MustParseAddr("2001:500:200::b"),     // b.root-servers.net
	netip.MustParseAddr("2001:500:2::c"),       // c.root-servers.net
	netip.MustParseAddr("2001:500:2d::d"),      // d.root-servers.net
	netip.MustParseAddr("2001:500:a8::e"),      // e.root-servers.net
	netip.MustParseAddr("2001:500:2f::f"),      // f.root-servers.net
	netip.MustParseAddr("2001:500:12::d0d"),    // g.root-servers.net
	netip.MustParseAddr("2001:500:1::53"),      // h.root-servers.net
	netip.MustParseAddr("2001:7fe::53"),        // i.root-servers.net
	netip.MustParseAddr("2001:503:c27::2:30"),  // j.root-servers.net
	netip.MustParseAddr("2001:7fd::1"),         // k.root-servers.net
	netip.MustParseAddr("2001:500:9f::42"),     // l.root-servers.net
	netip.MustParseAddr("2001:dc3::35"),        // m.root-servers.net
}

// debug reports whether verbose per-query logging is enabled through the
// TS_DEBUG_RECURSIVE_DNS environment knob; dlogf and depthlogf consult it.
var debug = envknob.RegisterBool("TS_DEBUG_RECURSIVE_DNS")

// Resolver is a recursive DNS resolver that is designed for looking up A and AAAA records.
//
// The query cache is a plain map with no locking, so a single Resolver
// should not be used from multiple goroutines at the same time.
type Resolver struct {
	// Dialer is used to create outbound connections. If nil, a zero
	// net.Dialer will be used instead.
	Dialer netns.Dialer

	// Logf is the logging function to use; if none is specified, then logs
	// will be dropped.
	Logf logger.Logf

	// NoIPv6, if set, will prevent this package from querying for AAAA
	// records and will avoid contacting nameservers over IPv6.
	NoIPv6 bool

	// Test mocks
	//
	// testQueryHook, if set, replaces queryNameserverProto entirely
	// (including its cache). testExchangeHook, if set, replaces only the
	// network exchange with a nameserver. rootServers, if non-empty,
	// replaces the randomly chosen root servers. timeNow, if set, replaces
	// time.Now.
	testQueryHook    func(name dnsname.FQDN, nameserver netip.Addr, protocol string, qtype dns.Type) (*dns.Msg, error)
	testExchangeHook func(nameserver netip.Addr, network string, msg *dns.Msg) (*dns.Msg, error)
	rootServers      []netip.Addr
	timeNow          func() time.Time

	// Caching
	// NOTE(andrew): if we make resolution parallel, this needs a mutex
	//
	// queryCache maps a (nameserver, name, qtype) question to the last
	// response received for it, together with an expiry time.
	queryCache map[dnsQuery]dnsMsgWithExpiry

	// Possible future additions:
	// - Additional nameservers? From the system maybe?
	// - NoIPv4 for IPv4
	// - DNS-over-HTTPS or DNS-over-TLS support
}

// queryState stores all state during the course of a single query
type queryState struct {
	// rootServers are the root nameservers to start from
	rootServers []netip.Addr

	// TODO: metrics?
}
+} + +type dnsQuery struct { + nameserver netip.Addr + name dnsname.FQDN + qtype dns.Type +} + +func (q dnsQuery) String() string { + return fmt.Sprintf("dnsQuery{nameserver:%q,name:%q,qtype:%v}", q.nameserver.String(), q.name, q.qtype) +} + +type dnsMsgWithExpiry struct { + *dns.Msg + expiresAt time.Time +} + +func (r *Resolver) now() time.Time { + if r.timeNow != nil { + return r.timeNow() + } + return time.Now() +} + +func (r *Resolver) logf(format string, args ...any) { + if r.Logf == nil { + return + } + r.Logf(format, args...) +} + +func (r *Resolver) dlogf(format string, args ...any) { + if r.Logf == nil || !debug() { + return + } + r.Logf(format, args...) +} + +func (r *Resolver) depthlogf(depth int, format string, args ...any) { + if r.Logf == nil || !debug() { + return + } + prefix := fmt.Sprintf("[%d] %s", depth, strings.Repeat(" ", depth)) + r.Logf(prefix+format, args...) +} + +var defaultDialer net.Dialer + +func (r *Resolver) dialer() netns.Dialer { + if r.Dialer != nil { + return r.Dialer + } + + return &defaultDialer +} + +func (r *Resolver) newState() *queryState { + var rootServers []netip.Addr + if len(r.rootServers) > 0 { + rootServers = r.rootServers + } else { + // Select a random subset of root nameservers to start from, since if + // we don't get responses from those, something else has probably gone + // horribly wrong. + roots4 := slices.Clone(rootServersV4) + slicesx.Shuffle(roots4) + roots4 = roots4[:numStartingServers] + + var roots6 []netip.Addr + if !r.NoIPv6 { + roots6 = slices.Clone(rootServersV6) + slicesx.Shuffle(roots6) + roots6 = roots6[:numStartingServers] + } + + // Interleave the root servers so that we try to contact them over + // IPv4, then IPv6, IPv4, IPv6, etc. 
+ rootServers = slicesx.Interleave(roots4, roots6) + } + + return &queryState{ + rootServers: rootServers, + } +} + +// Resolve will perform a recursive DNS resolution for the provided name, +// starting at a randomly-chosen root DNS server, and return the A and AAAA +// responses as a slice of netip.Addrs along with the minimum TTL for the +// returned records. +func (r *Resolver) Resolve(ctx context.Context, name string) (addrs []netip.Addr, minTTL time.Duration, err error) { + dnsName, err := dnsname.ToFQDN(name) + if err != nil { + return nil, 0, err + } + + qstate := r.newState() + + r.logf("querying IPv4 addresses for: %q", name) + addrs4, minTTL4, err4 := r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeA) + + var ( + addrs6 []netip.Addr + minTTL6 time.Duration + err6 error + ) + if !r.NoIPv6 { + r.logf("querying IPv6 addresses for: %q", name) + addrs6, minTTL6, err6 = r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeAAAA) + } + + if err4 != nil && err6 != nil { + if err4 == err6 { + return nil, 0, err4 + } + + return nil, 0, multierr.New(err4, err6) + } + if err4 != nil { + return addrs6, minTTL6, nil + } else if err6 != nil { + return addrs4, minTTL4, nil + } + + minTTL = minTTL4 + if minTTL6 < minTTL { + minTTL = minTTL6 + } + + addrs = append(addrs4, addrs6...) 
+ if len(addrs) == 0 { + return nil, 0, ErrNoResponses + } + + slicesx.Shuffle(addrs) + return addrs, minTTL, nil +} + +func (r *Resolver) resolveRecursiveFromRoot( + ctx context.Context, + qstate *queryState, + depth int, + name dnsname.FQDN, // what we're querying + qtype dns.Type, +) ([]netip.Addr, time.Duration, error) { + r.depthlogf(depth, "resolving %q from root (type: %v)", name, qtype) + + var depthError bool + for _, server := range qstate.rootServers { + addrs, minTTL, err := r.resolveRecursive(ctx, qstate, depth, name, server, qtype) + if err == nil { + return addrs, minTTL, err + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + depthError = true + } + } + + if depthError { + return nil, 0, ErrMaxDepth + } + return nil, 0, ErrNoResponses +} + +func (r *Resolver) resolveRecursive( + ctx context.Context, + qstate *queryState, + depth int, + name dnsname.FQDN, // what we're querying + nameserver netip.Addr, + qtype dns.Type, +) ([]netip.Addr, time.Duration, error) { + if depth == maxDepth { + r.depthlogf(depth, "not recursing past maximum depth") + return nil, 0, ErrMaxDepth + } + + // Ask this nameserver for an answer. + resp, err := r.queryNameserver(ctx, depth, name, nameserver, qtype) + if err != nil { + return nil, 0, err + } + + // If we get an actual answer from the nameserver, then return it. 
+ var ( + answers []netip.Addr + cnames []dnsname.FQDN + minTTL = 24 * 60 * 60 // 24 hours in seconds + ) + for _, answer := range resp.Answer { + if crec, ok := answer.(*dns.CNAME); ok { + cnameFQDN, err := dnsname.ToFQDN(crec.Target) + if err != nil { + r.logf("bad CNAME %q returned: %v", crec.Target, err) + continue + } + + cnames = append(cnames, cnameFQDN) + continue + } + + addr := addrFromRecord(answer) + if !addr.IsValid() { + r.logf("[unexpected] invalid record in %T answer", answer) + } else if addr.Is4() && qtype != qtypeA { + r.logf("[unexpected] got IPv4 answer but qtype=%v", qtype) + } else if addr.Is6() && qtype != qtypeAAAA { + r.logf("[unexpected] got IPv6 answer but qtype=%v", qtype) + } else { + answers = append(answers, addr) + minTTL = min(minTTL, int(answer.Header().Ttl)) + } + } + + if len(answers) > 0 { + r.depthlogf(depth, "got answers for %q: %v", name, answers) + return answers, time.Duration(minTTL) * time.Second, nil + } + + r.depthlogf(depth, "no answers for %q", name) + + // If we have a non-zero number of CNAMEs, then try resolving those + // (from the root again) and return the first one that succeeds. + // + // TODO: return the union of all responses? + // TODO: parallelism? + if len(cnames) > 0 { + r.depthlogf(depth, "got CNAME responses for %q: %v", name, cnames) + } + var cnameDepthError bool + for _, cname := range cnames { + answers, minTTL, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, cname, qtype) + if err == nil { + return answers, minTTL, nil + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + cnameDepthError = true + } + } + + // If this is an authoritative response, then we know that continuing + // to look further is not going to result in any answers and we should + // bail out. + if resp.MsgHdr.Authoritative { + // If we failed to recurse into a CNAME due to a depth limit, + // propagate that here. 
+ if cnameDepthError { + return nil, 0, ErrMaxDepth + } + + r.depthlogf(depth, "got authoritative response with no answers; stopping") + return nil, 0, ErrAuthoritativeNoResponses + } + + r.depthlogf(depth, "got %d NS responses and %d ADDITIONAL responses for %q", len(resp.Ns), len(resp.Extra), name) + + // No CNAMEs and no answers; see if we got any AUTHORITY responses, + // which indicate which nameservers to query next. + var authorities []dnsname.FQDN + for _, rr := range resp.Ns { + ns, ok := rr.(*dns.NS) + if !ok { + continue + } + + nsName, err := dnsname.ToFQDN(ns.Ns) + if err != nil { + r.logf("unexpected bad NS name %q: %v", ns.Ns, err) + continue + } + + authorities = append(authorities, nsName) + } + + // Also check for "glue" records, which are IP addresses provided by + // the DNS server for authority responses; these are required when the + // authority server is a subdomain of what's being resolved. + glueRecords := make(map[dnsname.FQDN][]netip.Addr) + for _, rr := range resp.Extra { + name, err := dnsname.ToFQDN(rr.Header().Name) + if err != nil { + r.logf("unexpected bad Name %q in Extra addr: %v", rr.Header().Name, err) + continue + } + + if addr := addrFromRecord(rr); addr.IsValid() { + glueRecords[name] = append(glueRecords[name], addr) + } else { + r.logf("unexpected bad Extra %T addr", rr) + } + } + + // Try authorities with glue records first, to minimize the number of + // additional DNS queries that we need to make. 
+ authoritiesGlue, authoritiesNoGlue := slicesx.Partition(authorities, func(aa dnsname.FQDN) bool { + return len(glueRecords[aa]) > 0 + }) + + authorityDepthError := false + + r.depthlogf(depth, "authorities with glue records for recursion: %v", authoritiesGlue) + for _, authority := range authoritiesGlue { + for _, nameserver := range glueRecords[authority] { + answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) + if err == nil { + return answers, minTTL, nil + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + authorityDepthError = true + } + } + } + + r.depthlogf(depth, "authorities with no glue records for recursion: %v", authoritiesNoGlue) + for _, authority := range authoritiesNoGlue { + // First, resolve the IP for the authority server from the + // root, querying for both IPv4 and IPv6 addresses regardless + // of what the current question type is. + // + // TODO: check for infinite recursion; it'll get caught by our + // recursion depth, but we want to bail early. + for _, authorityQtype := range []dns.Type{qtypeAAAA, qtypeA} { + answers, _, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, authority, authorityQtype) + if err != nil { + r.depthlogf(depth, "error querying authority %q: %v", authority, err) + continue + } + r.depthlogf(depth, "resolved authority %q (type %v) to: %v", authority, authorityQtype, answers) + + // Now, query this authority for the final address. 
+ for _, nameserver := range answers { + answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) + if err == nil { + return answers, minTTL, nil + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + authorityDepthError = true + } + } + } + } + + if authorityDepthError { + return nil, 0, ErrMaxDepth + } + return nil, 0, ErrNoResponses +} + +func min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +// queryNameserver sends a query for "name" to the nameserver "nameserver" for +// records of type "qtype", trying both UDP and TCP connections as +// appropriate. +func (r *Resolver) queryNameserver( + ctx context.Context, + depth int, + name dnsname.FQDN, // what we're querying + nameserver netip.Addr, // destination of query + qtype dns.Type, +) (*dns.Msg, error) { + // TODO(andrew): we should QNAME minimisation here to avoid sending the + // full name to intermediate/root nameservers. See: + // https://www.rfc-editor.org/rfc/rfc7816 + + // Handle the case where UDP is blocked by adding an explicit timeout + // for the UDP portion of this query. + udpCtx, udpCtxCancel := context.WithTimeout(ctx, udpQueryTimeout) + defer udpCtxCancel() + + msg, err := r.queryNameserverProto(udpCtx, depth, name, nameserver, "udp", qtype) + if err == nil { + return msg, nil + } + + msg, err2 := r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) + if err2 == nil { + return msg, nil + } + + return nil, multierr.New(err, err2) +} + +// queryNameserverProto sends a query for "name" to the nameserver "nameserver" +// for records of type "qtype" over the provided protocol (either "udp" +// or "tcp"), and returns the DNS response or an error. 
+func (r *Resolver) queryNameserverProto( + ctx context.Context, + depth int, + name dnsname.FQDN, // what we're querying + nameserver netip.Addr, // destination of query + protocol string, + qtype dns.Type, +) (resp *dns.Msg, err error) { + if r.testQueryHook != nil { + return r.testQueryHook(name, nameserver, protocol, qtype) + } + + now := r.now() + nameserverStr := nameserver.String() + + cacheKey := dnsQuery{ + nameserver: nameserver, + name: name, + qtype: qtype, + } + cacheEntry, ok := r.queryCache[cacheKey] + if ok && cacheEntry.expiresAt.Before(now) { + r.depthlogf(depth, "using cached response from %s about %q (type: %v)", nameserverStr, name, qtype) + return cacheEntry.Msg, nil + } + + var network string + if nameserver.Is4() { + network = protocol + "4" + } else { + network = protocol + "6" + } + + // Prepare a message asking for an appropriately-typed record + // for the name we're querying. + m := new(dns.Msg) + m.SetQuestion(name.WithTrailingDot(), uint16(qtype)) + + // Allow mocking out the network components with our exchange hook. + if r.testExchangeHook != nil { + resp, err = r.testExchangeHook(nameserver, network, m) + } else { + // Dial the current nameserver using our dialer. + var nconn net.Conn + nconn, err = r.dialer().DialContext(ctx, network, net.JoinHostPort(nameserverStr, "53")) + if err != nil { + return nil, err + } + + var c dns.Client // TODO: share? + conn := &dns.Conn{ + Conn: nconn, + UDPSize: c.UDPSize, + } + + // Send the DNS request to the current nameserver. + // + // TODO(andrew): use ExchangeWithConnContext after this upstream PR is + // merged: + // https://github.com/miekg/dns/pull/1459 + r.depthlogf(depth, "asking %s over %s about %q (type: %v)", nameserverStr, protocol, name, qtype) + resp, _, err = c.ExchangeWithConn(m, conn) + } + if err != nil { + return nil, err + } + + // If the message was truncated and we're using UDP, re-run with TCP. 
+ if resp.MsgHdr.Truncated && protocol == "udp" { + r.depthlogf(depth, "response message truncated; re-running query with TCP") + resp, err = r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) + if err != nil { + return nil, err + } + } + + // Find minimum expiry for all records in this message. + var minTTL int + for _, rr := range resp.Answer { + minTTL = min(minTTL, int(rr.Header().Ttl)) + } + for _, rr := range resp.Ns { + minTTL = min(minTTL, int(rr.Header().Ttl)) + } + for _, rr := range resp.Extra { + minTTL = min(minTTL, int(rr.Header().Ttl)) + } + + mak.Set(&r.queryCache, cacheKey, dnsMsgWithExpiry{ + Msg: resp, + expiresAt: now.Add(time.Duration(minTTL) * time.Second), + }) + return resp, nil +} + +func addrFromRecord(rr dns.RR) netip.Addr { + switch v := rr.(type) { + case *dns.A: + ip, ok := netip.AddrFromSlice(v.A) + if !ok || !ip.Is4() { + return netip.Addr{} + } + return ip + case *dns.AAAA: + ip, ok := netip.AddrFromSlice(v.AAAA) + if !ok || !ip.Is6() { + return netip.Addr{} + } + return ip + } + return netip.Addr{} +} diff --git a/net/dns/recursive/recursive_test.go b/net/dns/recursive/recursive_test.go new file mode 100644 index 000000000..40d5aa15b --- /dev/null +++ b/net/dns/recursive/recursive_test.go @@ -0,0 +1,741 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package recursive + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "net/netip" + "reflect" + "strings" + "testing" + "time" + + "github.com/miekg/dns" + "golang.org/x/exp/slices" + "tailscale.com/envknob" + "tailscale.com/tstest" +) + +const testDomain = "tailscale.com" + +// Recursively resolving the AWS console requires being able to handle CNAMEs, +// glue records, falling back from UDP to TCP for oversize queries, and more; +// it's a great integration test for DNS resolution and they can handle the +// traffic :) +const complicatedTestDomain = "console.aws.amazon.com" + +var flagNetworkAccess = 
// flagNetworkAccess gates the tests below that talk to real nameservers;
// they are skipped unless -enable-network-access is passed.
var flagNetworkAccess = flag.Bool("enable-network-access", false, "run tests that need external network access")

// init turns on the resolver's debug logging for every test in this package.
func init() {
	envknob.Setenv("TS_DEBUG_RECURSIVE_DNS", "true")
}

// newResolver returns a Resolver that logs through tb and reads time from a
// test clock instead of the wall clock.
func newResolver(tb testing.TB) *Resolver {
	// NOTE(review): presumably the clock advances by Step on each Now
	// call; confirm against tstest.Clock.
	clock := &tstest.Clock{
		Step: 50 * time.Millisecond,
	}
	return &Resolver{
		Logf:    tb.Logf,
		timeNow: clock.Now,
	}
}

// TestResolve resolves testDomain against the real DNS hierarchy and expects
// at least one IPv4 and one IPv6 address with a plausible TTL.
func TestResolve(t *testing.T) {
	if !*flagNetworkAccess {
		t.SkipNow()
	}

	ctx := context.Background()
	r := newResolver(t)
	addrs, minTTL, err := r.Resolve(ctx, testDomain)
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("addrs: %+v", addrs)
	t.Logf("minTTL: %v", minTTL)
	if len(addrs) < 1 {
		t.Fatalf("expected at least one address")
	}

	// The TTL must lie strictly between 10 seconds and 24 hours.
	if minTTL <= 10*time.Second || minTTL >= 24*time.Hour {
		t.Errorf("invalid minimum TTL: %v", minTTL)
	}

	var has4, has6 bool
	for _, addr := range addrs {
		has4 = has4 || addr.Is4()
		has6 = has6 || addr.Is6()
	}

	if !has4 {
		t.Errorf("expected at least one IPv4 address")
	}
	if !has6 {
		t.Errorf("expected at least one IPv6 address")
	}
}

// TestResolveComplicated resolves complicatedTestDomain against the real DNS
// hierarchy and expects at least one address with a plausible TTL.
func TestResolveComplicated(t *testing.T) {
	if !*flagNetworkAccess {
		t.SkipNow()
	}

	ctx := context.Background()
	r := newResolver(t)
	addrs, minTTL, err := r.Resolve(ctx, complicatedTestDomain)
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("addrs: %+v", addrs)
	t.Logf("minTTL: %v", minTTL)
	if len(addrs) < 1 {
		t.Fatalf("expected at least one address")
	}

	if minTTL <= 10*time.Second || minTTL >= 24*time.Hour {
		t.Errorf("invalid minimum TTL: %v", minTTL)
	}
}

// TestResolveNoIPv6 resolves testDomain with NoIPv6 set and expects that no
// IPv6 address is returned.
func TestResolveNoIPv6(t *testing.T) {
	if !*flagNetworkAccess {
		t.SkipNow()
	}

	r := newResolver(t)
	r.NoIPv6 = true

	addrs, _, err := r.Resolve(context.Background(), testDomain)
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("addrs: %+v", addrs)
	if len(addrs) < 1 {
		t.Fatalf("expected at least one address")
	}

	for _, addr := range addrs {
		if addr.Is6() {
			t.Errorf("got unexpected IPv6 address: %v", addr)
		}
	}
}
// TestResolveFallbackToTCP checks that a truncated UDP response makes
// queryNameserverProto retry over TCP, and that the final response is stored
// in the query cache and served from it on a repeat query.
func TestResolveFallbackToTCP(t *testing.T) {
	var udpCalls, tcpCalls int
	// The hook answers UDP queries with an empty, truncated reply and TCP
	// queries with a single A record, counting each kind of call.
	hook := func(nameserver netip.Addr, network string, req *dns.Msg) (*dns.Msg, error) {
		if strings.HasPrefix(network, "udp") {
			t.Logf("got %q query; returning truncated result", network)
			udpCalls++
			resp := &dns.Msg{}
			resp.SetReply(req)
			resp.Truncated = true
			return resp, nil
		}

		t.Logf("got %q query; returning real result", network)
		tcpCalls++
		resp := &dns.Msg{}
		resp.SetReply(req)
		resp.Answer = append(resp.Answer, &dns.A{
			Hdr: dns.RR_Header{
				Name:   req.Question[0].Name,
				Rrtype: req.Question[0].Qtype,
				Class:  dns.ClassINET,
				Ttl:    300,
			},
			A: net.IPv4(1, 2, 3, 4),
		})
		return resp, nil
	}

	r := newResolver(t)
	r.testExchangeHook = hook

	ctx := context.Background()
	resp, err := r.queryNameserverProto(ctx, 0, "tailscale.com", netip.MustParseAddr("9.9.9.9"), "udp", dns.Type(dns.TypeA))
	if err != nil {
		t.Fatal(err)
	}

	if len(resp.Answer) < 1 {
		t.Fatalf("no answers in response: %v", resp)
	}
	rrA, ok := resp.Answer[0].(*dns.A)
	if !ok {
		t.Fatalf("invalid RR type: %T", resp.Answer[0])
	}
	if !rrA.A.Equal(net.IPv4(1, 2, 3, 4)) {
		t.Errorf("wanted A response 1.2.3.4, got: %v", rrA.A)
	}
	// Exactly one UDP attempt followed by exactly one TCP retry.
	if tcpCalls != 1 {
		t.Errorf("got %d, want 1 TCP calls", tcpCalls)
	}
	if udpCalls != 1 {
		t.Errorf("got %d, want 1 UDP calls", udpCalls)
	}

	// Verify that we're cached and re-run to fetch from the cache.
	if len(r.queryCache) < 1 {
		t.Errorf("wanted entries in the query cache")
	}

	resp2, err := r.queryNameserverProto(ctx, 0, "tailscale.com", netip.MustParseAddr("9.9.9.9"), "udp", dns.Type(dns.TypeA))
	if err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(resp, resp2) {
		t.Errorf("expected equal responses; old=%+v new=%+v", resp, resp2)
	}

	// We didn't make any more network requests since we loaded from the cache.
	if tcpCalls != 1 {
		t.Errorf("got %d, want 1 TCP calls", tcpCalls)
	}
	if udpCalls != 1 {
		t.Errorf("got %d, want 1 UDP calls", udpCalls)
	}
}

// dnsIPRR builds an A record (for an IPv4 addr) or an AAAA record (for any
// other addr) for name, with class IN and a TTL of 300 seconds.
func dnsIPRR(name string, addr netip.Addr) dns.RR {
	if addr.Is4() {
		return &dns.A{
			Hdr: dns.RR_Header{
				Name:   name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    300,
			},
			A: net.IP(addr.AsSlice()),
		}
	}

	return &dns.AAAA{
		Hdr: dns.RR_Header{
			Name:   name,
			Rrtype: dns.TypeAAAA,
			Class:  dns.ClassINET,
			Ttl:    300,
		},
		AAAA: net.IP(addr.AsSlice()),
	}
}

// cnameRR builds a CNAME record pointing name at target, with class IN and a
// TTL of 300 seconds.
func cnameRR(name, target string) dns.RR {
	return &dns.CNAME{
		Hdr: dns.RR_Header{
			Name:   name,
			Rrtype: dns.TypeCNAME,
			Class:  dns.ClassINET,
			Ttl:    300,
		},
		Target: target,
	}
}

// nsRR builds an NS record naming target as a nameserver for name, with
// class IN and a TTL of 300 seconds.
func nsRR(name, target string) dns.RR {
	return &dns.NS{
		Hdr: dns.RR_Header{
			Name:   name,
			Rrtype: dns.TypeNS,
			Class:  dns.ClassINET,
			Ttl:    300,
		},
		Ns: target,
	}
}

// mockReply is one canned response: the message to return when a question
// for name with type qtype arrives.
type mockReply struct {
	name  string
	qtype dns.Type
	resp  *dns.Msg
}

// replyMock is a table of canned responses keyed by nameserver address; its
// exchangeHook method can stand in for Resolver.testExchangeHook.
type replyMock struct {
	tb      testing.TB
	replies map[netip.Addr][]mockReply
}

// exchangeHook returns a copy of the canned response configured for the
// single question in req on the given nameserver. It fails the test if req
// does not hold exactly one question, if the nameserver has no replies
// configured, or if none of its replies matches the question.
func (r *replyMock) exchangeHook(nameserver netip.Addr, network string, req *dns.Msg) (*dns.Msg, error) {
	if len(req.Question) != 1 {
		r.tb.Fatalf("unsupported multiple or empty question: %v", req.Question)
	}
	question := req.Question[0]

	replies := r.replies[nameserver]
	if len(replies) == 0 {
		r.tb.Fatalf("no configured replies for nameserver: %v", nameserver)
	}

	for _, reply := range replies {
		if reply.name == question.Name && reply.qtype == dns.Type(question.Qtype) {
			// Copy so that callers cannot modify the shared fixture.
			return reply.resp.Copy(), nil
		}
	}

	r.tb.Fatalf("no replies found for query %q of type %v to %v", question.Name, question.Qtype, nameserver)
	panic("unreachable")
}

// responses for mocking, shared between the following tests
var (
	rootServerAddr = netip.MustParseAddr("198.41.0.4") // a.root-servers.net.
	comNSAddr      = netip.MustParseAddr("192.5.6.30") // a.gtld-servers.net.

	// DNS response from the root nameservers for a .com nameserver
	comRecord = &dns.Msg{
		Ns:    []dns.RR{nsRR("com.", "a.gtld-servers.net.")},
		Extra: []dns.RR{dnsIPRR("a.gtld-servers.net.", comNSAddr)},
	}

	// Random Amazon nameservers that we use in glue records
	amazonNS   = netip.MustParseAddr("205.251.192.197")
	amazonNSv6 = netip.MustParseAddr("2600:9000:5306:1600::1")

	// Nameservers for the tailscale.com domain. Only the first of the four
	// has a glue record.
	tailscaleNameservers = &dns.Msg{
		Ns: []dns.RR{
			nsRR("tailscale.com.", "ns-197.awsdns-24.com."),
			nsRR("tailscale.com.", "ns-557.awsdns-05.net."),
			nsRR("tailscale.com.", "ns-1558.awsdns-02.co.uk."),
			nsRR("tailscale.com.", "ns-1359.awsdns-41.org."),
		},
		Extra: []dns.RR{
			dnsIPRR("ns-197.awsdns-24.com.", amazonNS),
		},
	}
)
// TestBasicRecursion walks a mocked root -> .com -> authoritative chain for
// tailscale.com and expects the A and AAAA answers from the authoritative
// server, with the 300-second TTL of the mocked records.
func TestBasicRecursion(t *testing.T) {
	mock := &replyMock{
		tb: t,
		replies: map[netip.Addr][]mockReply{
			// Query to the root server returns the .com server + a glue record
			rootServerAddr: {
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord},
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord},
			},

			// Query to the ".com" server returns the nameservers for tailscale.com
			comNSAddr: {
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},
			},

			// Query to the actual nameserver works.
			amazonNS: {
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{
						dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131")),
						dnsIPRR("tailscale.com.", netip.MustParseAddr("76.223.15.28")),
					},
				}},
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{
						dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b")),
						dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5")),
					},
				}},
			},
		},
	}

	r := newResolver(t)
	r.testExchangeHook = mock.exchangeHook
	r.rootServers = []netip.Addr{rootServerAddr}

	// Query for tailscale.com, verify we get the right responses
	ctx := context.Background()
	addrs, minTTL, err := r.Resolve(ctx, "tailscale.com")
	if err != nil {
		t.Fatal(err)
	}
	wantAddrs := []netip.Addr{
		netip.MustParseAddr("13.248.141.131"),
		netip.MustParseAddr("76.223.15.28"),
		netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"),
		netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5"),
	}
	// Resolve shuffles its result, so compare in sorted order.
	slices.SortFunc(addrs, func(x, y netip.Addr) bool { return x.String() < y.String() })
	slices.SortFunc(wantAddrs, func(x, y netip.Addr) bool { return x.String() < y.String() })

	if !reflect.DeepEqual(addrs, wantAddrs) {
		t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs)
	}

	const wantMinTTL = 5 * time.Minute
	if minTTL != wantMinTTL {
		t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL)
	}
}

// TestNoAnswers checks that an authoritative response with no answers is
// reported by Resolve as ErrAuthoritativeNoResponses.
func TestNoAnswers(t *testing.T) {
	mock := &replyMock{
		tb: t,
		replies: map[netip.Addr][]mockReply{
			// Query to the root server returns the .com server + a glue record
			rootServerAddr: {
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord},
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord},
			},

			// Query to the ".com" server returns the nameservers for tailscale.com
			comNSAddr: {
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},
			},

			// Query to the actual nameserver returns no responses, authoritatively.
			amazonNS: {
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{},
				}},
				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{},
				}},
			},
		},
	}

	r := &Resolver{
		Logf:             t.Logf,
		testExchangeHook: mock.exchangeHook,
		rootServers:      []netip.Addr{rootServerAddr},
	}

	// Query for tailscale.com, verify that the empty authoritative
	// response surfaces as an error.
	_, _, err := r.Resolve(context.Background(), "tailscale.com")
	if err == nil {
		t.Fatalf("got no error, want error")
	}
	if !errors.Is(err, ErrAuthoritativeNoResponses) {
		t.Fatalf("got err=%v, want %v", err, ErrAuthoritativeNoResponses)
	}
}

// TestRecursionCNAME checks that a CNAME answer is followed: resolving
// subdomain.otherdomain.com must return the addresses configured for its
// CNAME target, subdomain.tailscale.com.
func TestRecursionCNAME(t *testing.T) {
	mock := &replyMock{
		tb: t,
		replies: map[netip.Addr][]mockReply{
			// Query to the root server returns the .com server + a glue record
			rootServerAddr: {
				{name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeA), resp: comRecord},
				{name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord},

				{name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord},
				{name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord},
			},

			// Query to the ".com" server returns the same nameserver set
			// (tailscaleNameservers) for both domains
			comNSAddr: {
				{name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
				{name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},

				{name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
				{name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},
			},

			// Query to the actual nameserver works: the first name is a
			// CNAME to the second, which has address records.
			amazonNS: {
				{name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{cnameRR("subdomain.otherdomain.com.", "subdomain.tailscale.com.")},
				}},
				{name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{cnameRR("subdomain.otherdomain.com.", "subdomain.tailscale.com.")},
				}},

				{name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131"))},
				}},
				{name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{
					MsgHdr: dns.MsgHdr{Authoritative: true},
					Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"))},
				}},
			},
		},
	}

	r := &Resolver{
		Logf:             t.Logf,
		testExchangeHook: mock.exchangeHook,
		rootServers:      []netip.Addr{rootServerAddr},
	}

	// Query for subdomain.otherdomain.com, verify we get the addresses of
	// its CNAME target
	addrs, minTTL, err := r.Resolve(context.Background(), "subdomain.otherdomain.com")
	if err != nil {
		t.Fatal(err)
	}
	wantAddrs := []netip.Addr{
		netip.MustParseAddr("13.248.141.131"),
		netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"),
	}
	// Resolve shuffles its result, so compare in sorted order.
	slices.SortFunc(addrs, func(x, y netip.Addr) bool { return x.String() < y.String() })
	slices.SortFunc(wantAddrs, func(x, y netip.Addr) bool { return x.String() < y.String() })

	if !reflect.DeepEqual(addrs, wantAddrs) {
		t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs)
	}

	const wantMinTTL = 5 * time.Minute
	if minTTL != wantMinTTL {
		t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL)
	}
}
netip.MustParseAddr("213.248.216.1") + coukRecord := &dns.Msg{ + Ns: []dns.RR{nsRR("com.", "dns1.nic.uk.")}, + Extra: []dns.RR{dnsIPRR("dns1.nic.uk.", coukNS)}, + } + + intermediateNS := netip.MustParseAddr("205.251.193.66") // g-ns-322.awsdns-02.co.uk. + intermediateRecord := &dns.Msg{ + Ns: []dns.RR{nsRR("awsdns-02.co.uk.", "g-ns-322.awsdns-02.co.uk.")}, + Extra: []dns.RR{dnsIPRR("g-ns-322.awsdns-02.co.uk.", intermediateNS)}, + } + + const amazonNameserver = "ns-1558.awsdns-02.co.uk." + tailscaleNameservers := &dns.Msg{ + Ns: []dns.RR{ + nsRR("tailscale.com.", amazonNameserver), + }, + } + + tailscaleResponses := []mockReply{ + {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131"))}, + }}, + {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"))}, + }}, + } + + mock := &replyMock{ + tb: t, + replies: map[netip.Addr][]mockReply{ + rootServerAddr: { + // Query to the root server returns the .com server + a glue record + {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, + {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, + + // Querying the .co.uk nameserver returns the .co.uk nameserver + a glue record. + {name: amazonNameserver, qtype: dns.Type(dns.TypeA), resp: coukRecord}, + {name: amazonNameserver, qtype: dns.Type(dns.TypeAAAA), resp: coukRecord}, + }, + + // Queries to the ".com" server return the nameservers + // for tailscale.com, which don't contain a glue + // record. 
+ comNSAddr: { + {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, + {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, + }, + + // Queries to the ".co.uk" nameserver returns the + // address of the intermediate Amazon nameserver. + coukNS: { + {name: amazonNameserver, qtype: dns.Type(dns.TypeA), resp: intermediateRecord}, + {name: amazonNameserver, qtype: dns.Type(dns.TypeAAAA), resp: intermediateRecord}, + }, + + // Queries to the intermediate nameserver returns an + // answer for the final Amazon nameserver. + intermediateNS: { + {name: amazonNameserver, qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + Answer: []dns.RR{dnsIPRR(amazonNameserver, amazonNS)}, + }}, + {name: amazonNameserver, qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + Answer: []dns.RR{dnsIPRR(amazonNameserver, amazonNSv6)}, + }}, + }, + + // Queries to the actual nameserver work and return + // responses to the query. 
+ amazonNS: tailscaleResponses, + amazonNSv6: tailscaleResponses, + }, + } + + r := newResolver(t) + r.testExchangeHook = mock.exchangeHook + r.rootServers = []netip.Addr{rootServerAddr} + + // Query for tailscale.com, verify we get the right responses + addrs, minTTL, err := r.Resolve(context.Background(), "tailscale.com") + if err != nil { + t.Fatal(err) + } + wantAddrs := []netip.Addr{ + netip.MustParseAddr("13.248.141.131"), + netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"), + } + slices.SortFunc(addrs, func(x, y netip.Addr) bool { return x.String() < y.String() }) + slices.SortFunc(wantAddrs, func(x, y netip.Addr) bool { return x.String() < y.String() }) + + if !reflect.DeepEqual(addrs, wantAddrs) { + t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs) + } + + const wantMinTTL = 5 * time.Minute + if minTTL != wantMinTTL { + t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL) + } +} + +func TestRecursionLimit(t *testing.T) { + mock := &replyMock{ + tb: t, + replies: map[netip.Addr][]mockReply{}, + } + + // Fill out a CNAME chain equal to our recursion limit; we won't get + // this far since each CNAME is more than 1 level "deep", but this + // ensures that we have more than the limit. 
+ for i := 0; i < maxDepth+1; i++ { + curr := fmt.Sprintf("%d-tailscale.com.", i) + + tailscaleNameservers := &dns.Msg{ + Ns: []dns.RR{nsRR(curr, "ns-197.awsdns-24.com.")}, + Extra: []dns.RR{dnsIPRR("ns-197.awsdns-24.com.", amazonNS)}, + } + + // Query to the root server returns the .com server + a glue record + mock.replies[rootServerAddr] = append(mock.replies[rootServerAddr], + mockReply{name: curr, qtype: dns.Type(dns.TypeA), resp: comRecord}, + mockReply{name: curr, qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, + ) + + // Query to the ".com" server return the nameservers for NN-tailscale.com + mock.replies[comNSAddr] = append(mock.replies[comNSAddr], + mockReply{name: curr, qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, + mockReply{name: curr, qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, + ) + + // Queries to the nameserver return a CNAME for the n+1th server. + next := fmt.Sprintf("%d-tailscale.com.", i+1) + mock.replies[amazonNS] = append(mock.replies[amazonNS], + mockReply{ + name: curr, + qtype: dns.Type(dns.TypeA), + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + Answer: []dns.RR{cnameRR(curr, next)}, + }, + }, + mockReply{ + name: curr, + qtype: dns.Type(dns.TypeAAAA), + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + Answer: []dns.RR{cnameRR(curr, next)}, + }, + }, + ) + } + + r := newResolver(t) + r.testExchangeHook = mock.exchangeHook + r.rootServers = []netip.Addr{rootServerAddr} + + // Query for the first node in the chain, 0-tailscale.com, and verify + // we get a max-depth error. 
+ ctx := context.Background() + _, _, err := r.Resolve(ctx, "0-tailscale.com") + if err == nil { + t.Fatal("expected error, got nil") + } else if !errors.Is(err, ErrMaxDepth) { + t.Fatalf("got err=%v, want ErrMaxDepth", err) + } +} + +func TestInvalidResponses(t *testing.T) { + mock := &replyMock{ + tb: t, + replies: map[netip.Addr][]mockReply{ + // Query to the root server returns the .com server + a glue record + rootServerAddr: { + {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, + {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, + }, + + // Query to the ".com" server return the nameservers for tailscale.com + comNSAddr: { + {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, + {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, + }, + + // Query to the actual nameserver returns an invalid IP address + amazonNS: { + {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + Answer: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "tailscale.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + // Note: this is an IPv6 addr in an IPv4 response + A: net.IP(netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5").AsSlice()), + }}, + }}, + {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{Authoritative: true}, + // This an IPv4 response to an IPv6 query + Answer: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "tailscale.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.IP(netip.MustParseAddr("13.248.141.131").AsSlice()), + }}, + }}, + }, + }, + } + + r := &Resolver{ + Logf: t.Logf, + testExchangeHook: mock.exchangeHook, + rootServers: []netip.Addr{rootServerAddr}, + } + + // Query for tailscale.com, verify we get no responses since the + // addresses are invalid. 
+ _, _, err := r.Resolve(context.Background(), "tailscale.com") + if err == nil { + t.Fatalf("got no error, want error") + } + if !errors.Is(err, ErrAuthoritativeNoResponses) { + t.Fatalf("got err=%v, want %v", err, ErrAuthoritativeNoResponses) + } +} + +// TODO(andrew): test for more edge cases that aren't currently covered: +// * Nameservers that cross between IPv4 and IPv6 +// * Authoritative no replies after following CNAME +// * Authoritative no replies after following non-glue NS record +// * Error querying non-glue NS record followed by success