diff --git a/net/dns/manager_test.go b/net/dns/manager_test.go index 679f81cd5..cf0c2458e 100644 --- a/net/dns/manager_test.go +++ b/net/dns/manager_test.go @@ -4,24 +4,36 @@ package dns import ( + "bytes" + "context" "errors" + "io" + "net/http" + "net/http/httptest" "net/netip" "reflect" "runtime" + "slices" "strings" + "sync" "testing" "testing/synctest" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/control/controlknobs" "tailscale.com/health" + "tailscale.com/net/dns/publicdns" "tailscale.com/net/dns/resolver" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" + "tailscale.com/tstest" "tailscale.com/types/dnstype" "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" + "tailscale.com/util/httpm" ) type fakeOSConfigurator struct { @@ -1116,3 +1128,195 @@ func TestTrampleRetrample(t *testing.T) { } }) } + +// TestSystemDNSDoHUpgrade tests that if the user doesn't configure DNS servers +// in their tailnet, and the system DNS happens to be a known DoH provider, +// queries will use DNS-over-HTTPS. +func TestSystemDNSDoHUpgrade(t *testing.T) { + var ( + // This is a non-routable TEST-NET-2 IP (RFC 5737). + testDoHResolverIP = netip.MustParseAddr("198.51.100.1") + // This is a non-routable TEST-NET-1 IP (RFC 5737). + testResponseIP = netip.MustParseAddr("192.0.2.1") + ) + const testDomain = "test.example.com." + + var ( + mu sync.Mutex + dohRequestSeen bool + receivedQuery []byte + ) + dohServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("[DoH Server] received request: %v %v", r.Method, r.URL) + + if r.Method != httpm.POST { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if r.Header.Get("Content-Type") != "application/dns-message" { + http.Error(w, "bad content type", http.StatusBadRequest) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "read error", http.StatusInternalServerError) + return + } + + mu.Lock() + defer mu.Unlock() + + dohRequestSeen = true + receivedQuery = body + + // Build a DNS response + response := buildTestDNSResponse(t, testDomain, testResponseIP) + w.Header().Set("Content-Type", "application/dns-message") + w.Write(response) + })) + t.Cleanup(dohServer.Close) + + // Register the test IP to route to our mock DoH server + cleanup := publicdns.RegisterTestDoHEndpoint(testDoHResolverIP, dohServer.URL) + t.Cleanup(cleanup) + + // This simulates a system with the single DoH-capable DNS server + // configured. + f := &fakeOSConfigurator{ + SplitDNS: false, // non-split DNS required to use the forwarder + BaseConfig: OSConfig{ + Nameservers: []netip.Addr{testDoHResolverIP}, + }, + } + + logf := tstest.WhileTestRunningLogger(t) + bus := eventbustest.NewBus(t) + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) + m := NewManager(logf, f, health.NewTracker(bus), dialer, nil, &controlknobs.Knobs{}, "linux", bus) + t.Cleanup(func() { m.Down() }) + + // Set up hook to capture the resolver config + m.resolver.TestOnlySetHook(f.SetResolver) + + // Configure the manager with routes but no default resolvers, which + // reads BaseConfig from the OS configurator. + config := Config{ + Routes: upstreams("tailscale.com.", "10.0.0.1"), + SearchDomains: fqdns("tailscale.com."), + } + if err := m.Set(config); err != nil { + t.Fatal(err) + } + + // Verify the resolver config has our test IP in Routes["."] + if f.ResolverConfig.Routes == nil { + t.Fatal("ResolverConfig.Routes is nil (SetResolver hook not called)") + } + + const defaultRouteKey = "." + defaultRoute, ok := f.ResolverConfig.Routes[defaultRouteKey] + if !ok { + t.Fatalf("ResolverConfig.Routes[%q] not found", defaultRouteKey) + } + if !slices.ContainsFunc(defaultRoute, func(r *dnstype.Resolver) bool { + return r.Addr == testDoHResolverIP.String() + }) { + t.Errorf("test IP %v not found in Routes[%q], got: %v", testDoHResolverIP, defaultRouteKey, defaultRoute) + } + + // Build a DNS query to something not handled by our split DNS route + // (tailscale.com) above. + query := buildTestDNSQuery(t, testDomain) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + resp, err := m.Query(ctx, query, "udp", netip.MustParseAddrPort("127.0.0.1:12345")) + if err != nil { + t.Fatal(err) + } + if len(resp) == 0 { + t.Fatal("empty response") + } + + // Parse the response to verify we get our test IP back. + var parser dns.Parser + if _, err := parser.Start(resp); err != nil { + t.Fatalf("parsing response header: %v", err) + } + if err := parser.SkipAllQuestions(); err != nil { + t.Fatalf("skipping questions: %v", err) + } + answers, err := parser.AllAnswers() + if err != nil { + t.Fatalf("parsing answers: %v", err) + } + if len(answers) == 0 { + t.Fatal("no answers in response") + } + aRecord, ok := answers[0].Body.(*dns.AResource) + if !ok { + t.Fatalf("first answer is not A record: %T", answers[0].Body) + } + gotIP := netip.AddrFrom4(aRecord.A) + if gotIP != testResponseIP { + t.Errorf("wrong A record IP: got %v, want %v", gotIP, testResponseIP) + } + + // Also verify that our DoH server received the query. + mu.Lock() + defer mu.Unlock() + if !dohRequestSeen { + t.Error("DoH server never received request") + } + if !bytes.Equal(receivedQuery, query) { + t.Errorf("DoH server received wrong query:\ngot: %x\nwant: %x", receivedQuery, query) + } +} + +// buildTestDNSQuery builds a simple DNS A query for the given domain. +func buildTestDNSQuery(t *testing.T, domain string) []byte { + t.Helper() + + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: dns.MustNewName(domain), + Type: dns.TypeA, + Class: dns.ClassINET, + }) + msg, err := builder.Finish() + if err != nil { + t.Fatal(err) + } + + return msg +} + +// buildTestDNSResponse builds a DNS response for the given query with the specified IP. +func buildTestDNSResponse(t *testing.T, domain string, ip netip.Addr) []byte { + t.Helper() + + builder := dns.NewBuilder(nil, dns.Header{Response: true}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: dns.MustNewName(domain), + Type: dns.TypeA, + Class: dns.ClassINET, + }) + + builder.StartAnswers() + builder.AResource(dns.ResourceHeader{ + Name: dns.MustNewName(domain), + Class: dns.ClassINET, + TTL: 300, + }, dns.AResource{A: ip.As4()}) + + msg, err := builder.Finish() + if err != nil { + t.Fatal(err) + } + + return msg +} diff --git a/net/dns/publicdns/publicdns.go b/net/dns/publicdns/publicdns.go index 7ceaf1813..3666bd778 100644 --- a/net/dns/publicdns/publicdns.go +++ b/net/dns/publicdns/publicdns.go @@ -13,12 +13,14 @@ import ( "log" "math/big" "net/netip" + "slices" "sort" "strconv" "strings" "sync" "tailscale.com/feature/buildfeatures" + "tailscale.com/util/testenv" ) // dohOfIP maps from public DNS IPs to their DoH base URL. @@ -367,3 +369,39 @@ func IPIsDoHOnlyServer(ip netip.Addr) bool { controlDv6RangeA.Contains(ip) || controlDv6RangeB.Contains(ip) || ip == controlDv4One || ip == controlDv4Two } + +var testMu sync.Mutex + +// RegisterTestDoHEndpoint registers a test DoH endpoint mapping for use in tests. +// It maps the given IP to the DoH base URL, and the URL back to the IP. +// +// This function panics if called outside of tests, and cannot be called +// concurrently with any usage of this package (i.e. before any DNS forwarders +// are created). It is safe to call concurrently with itself. +// +// It returns a cleanup function that removes the registration. +func RegisterTestDoHEndpoint(ip netip.Addr, dohBase string) func() { + if !testenv.InTest() { + panic("RegisterTestDoHEndpoint called outside of tests") + } + populateOnce.Do(populate) + + testMu.Lock() + defer testMu.Unlock() + + dohOfIP[ip] = dohBase + dohIPsOfBase[dohBase] = append(dohIPsOfBase[dohBase], ip) + + return func() { + testMu.Lock() + defer testMu.Unlock() + + delete(dohOfIP, ip) + dohIPsOfBase[dohBase] = slices.DeleteFunc(dohIPsOfBase[dohBase], func(addr netip.Addr) bool { + return addr == ip + }) + if len(dohIPsOfBase[dohBase]) == 0 { + delete(dohIPsOfBase, dohBase) + } + } +}