From f07ff47922c11377374ffe91a8dbe0fa12fb1b56 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 7 Oct 2024 17:08:22 -0500 Subject: [PATCH] net/dns/resolver: add tests for using a forwarder with multiple upstream resolvers If multiple upstream DNS servers are available, quad-100 sends requests to all of them and forwards the first successful response, if any. If no successful responses are received, it propagates the first failure from any of them. This PR adds some test coverage for these scenarios. Updates #13571 Signed-off-by: Nick Khyl --- net/dns/resolver/forwarder_test.go | 235 +++++++++++++++++++++++------ 1 file changed, 190 insertions(+), 45 deletions(-) diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 9c0964e93..e341186ec 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -449,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) return } -func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) { +func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { netMon, err := netmon.New(tb.Logf) if err != nil { tb.Fatal(err) @@ -463,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa modify(fwd) } - rr := resolverAndDelay{ - name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, + resolvers := make([]resolverAndDelay, len(ports)) + for i, port := range ports { + resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)} } rpkt := packet{ @@ -476,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa rchan := make(chan packet, 1) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) tb.Cleanup(cancel) - err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr) + err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...) select { case res := <-rchan: return res.bs, err @@ -485,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa } } -func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte { - resp, err := runTestQuery(tb, port, request, modify) +// makeTestRequest returns a new TypeA request for the given domain. +func makeTestRequest(tb testing.TB, domain string) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + request, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return request +} + +// makeTestResponse returns a new Type A response for the given domain, +// with the specified status code and zero or more addresses. +func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{ + Response: true, + Authoritative: true, + RCode: code, + }) + builder.StartQuestions() + q := dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + } + builder.Question(q) + if len(addrs) > 0 { + builder.StartAnswers() + for _, addr := range addrs { + builder.AResource(dns.ResourceHeader{ + Name: q.Name, + Class: q.Class, + TTL: 120, + }, dns.AResource{ + A: addr.As4(), + }) + } + } + response, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return response +} + +func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte { + resp, err := runTestQuery(tb, request, modify, ports...) if err != nil { tb.Fatalf("error making request: %v", err) } @@ -515,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -553,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -584,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) { + resp := mustRunTestQuery(t, request, func(fwd *forwarder) { // Disable retries for this test. fwd.controlKnobs = &controlknobs.Knobs{} fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) - }) + }, port) wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) @@ -612,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) { const domain = "error-response.tailscale.com." // Our response is a SERVFAIL - response := func() []byte { - name := dns.MustNewName(domain) - - builder := dns.NewBuilder(nil, dns.Header{ - Response: true, - RCode: dns.RCodeServerFailure, - }) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: name, - Type: dns.TypeA, - Class: dns.ClassINET, - }) - response, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return response - }() + response := makeTestResponse(t, domain, dns.RCodeServerFailure) // Our request is a single A query for the domain in the answer, above. - request := func() []byte { - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypeA, - Class: dns.ClassINET, - }) - request, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return request - }() + request := makeTestRequest(t, domain) var sawRequest atomic.Bool port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { @@ -656,7 +680,7 @@ func TestForwarderTCPFallbackError(t *testing.T) { } }) - resp, err := runTestQuery(t, port, request, nil) + resp, err := runTestQuery(t, request, nil, port) if !sawRequest.Load() { t.Error("did not see DNS request") } @@ -673,6 +697,127 @@ func TestForwarderTCPFallbackError(t *testing.T) { } } +// Test to ensure that if we have more than one resolver, and at least one of them +// returns a successful response, we propagate it. +func TestForwarderWithManyResolvers(t *testing.T) { + enableDebug(t) + + const domain = "example.com." + request := makeTestRequest(t, domain) + + tests := []struct { + name string + responses [][]byte // upstream responses + wantResponses [][]byte // we should receive one of these from the forwarder + }{ + { + name: "Success", + responses: [][]byte{ // All upstream servers returned successful, but different, response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + wantResponses: [][]byte{ // We may forward whichever response is received first. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + }, + { + name: "ServFail", + responses: [][]byte{ // All upstream servers returned a SERVFAIL. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, + { + name: "ServFail+Success", + responses: [][]byte{ // All upstream servers fail except for one. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ // We should forward the successful response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "NXDomain", + responses: [][]byte{ // All upstream servers returned NXDOMAIN. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeNameError), + }, + }, + { + name: "NXDomain+Success", + responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "Refused", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "MixFail", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ports := make([]uint16, len(tt.responses)) + for i := range tt.responses { + ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {}) + } + gotResponse, err := runTestQuery(t, request, nil, ports...) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool { + return slices.Equal(gotResponse, wantResponse) + }) + if !responseOk { + t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0]) + } + }) + } +} + // mdnsResponder at minimum has an expectation that NXDOMAIN must include the // question, otherwise it will penalize our server (#13511). func TestNXDOMAINIncludesQuestion(t *testing.T) { @@ -718,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { }) - res, err := runTestQuery(t, port, request, nil) + res, err := runTestQuery(t, request, nil, port) if err != nil { t.Fatal(err) }