net/dns/resolver: add tests for using a forwarder with multiple upstream resolvers

If multiple upstream DNS servers are available, quad-100 sends requests to all of them
and forwards the first successful response, if any. If no successful responses are received,
it propagates the first failure from any of them.

This PR adds some test coverage for these scenarios.

Updates #13571

Signed-off-by: Nick Khyl <nickk@tailscale.com>
pull/13793/head
Nick Khyl 2 months ago committed by Nick Khyl
parent c2144c44a3
commit f07ff47922

@ -449,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
return return
} }
func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) { func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
netMon, err := netmon.New(tb.Logf) netMon, err := netmon.New(tb.Logf)
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
@ -463,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
modify(fwd) modify(fwd)
} }
rr := resolverAndDelay{ resolvers := make([]resolverAndDelay, len(ports))
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, for i, port := range ports {
resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}
} }
rpkt := packet{ rpkt := packet{
@ -476,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
rchan := make(chan packet, 1) rchan := make(chan packet, 1)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
tb.Cleanup(cancel) tb.Cleanup(cancel)
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr) err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...)
select { select {
case res := <-rchan: case res := <-rchan:
return res.bs, err return res.bs, err
@ -485,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
} }
} }
func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte { // makeTestRequest returns a new TypeA request for the given domain.
resp, err := runTestQuery(tb, port, request, modify) func makeTestRequest(tb testing.TB, domain string) []byte {
tb.Helper()
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
tb.Fatal(err)
}
return request
}
// makeTestResponse returns a new Type A response for the given domain,
// with the specified status code and zero or more addresses.
func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
tb.Helper()
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
Authoritative: true,
RCode: code,
})
builder.StartQuestions()
q := dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
}
builder.Question(q)
if len(addrs) > 0 {
builder.StartAnswers()
for _, addr := range addrs {
builder.AResource(dns.ResourceHeader{
Name: q.Name,
Class: q.Class,
TTL: 120,
}, dns.AResource{
A: addr.As4(),
})
}
}
response, err := builder.Finish()
if err != nil {
tb.Fatal(err)
}
return response
}
func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte {
resp, err := runTestQuery(tb, request, modify, ports...)
if err != nil { if err != nil {
tb.Fatalf("error making request: %v", err) tb.Fatalf("error making request: %v", err)
} }
@ -515,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) {
} }
}) })
resp := mustRunTestQuery(t, port, request, nil) resp := mustRunTestQuery(t, request, nil, port)
if !bytes.Equal(resp, largeResponse) { if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
} }
@ -553,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
} }
}) })
resp := mustRunTestQuery(t, port, request, nil) resp := mustRunTestQuery(t, request, nil, port)
if !bytes.Equal(resp, largeResponse) { if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
} }
@ -584,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
} }
}) })
resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) { resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
// Disable retries for this test. // Disable retries for this test.
fwd.controlKnobs = &controlknobs.Knobs{} fwd.controlKnobs = &controlknobs.Knobs{}
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
}) }, port)
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
@ -612,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) {
const domain = "error-response.tailscale.com." const domain = "error-response.tailscale.com."
// Our response is a SERVFAIL // Our response is a SERVFAIL
response := func() []byte { response := makeTestResponse(t, domain, dns.RCodeServerFailure)
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
RCode: dns.RCodeServerFailure,
})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
response, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return response
}()
// Our request is a single A query for the domain in the answer, above. // Our request is a single A query for the domain in the answer, above.
request := func() []byte { request := makeTestRequest(t, domain)
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName(domain),
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return request
}()
var sawRequest atomic.Bool var sawRequest atomic.Bool
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
@ -656,7 +680,7 @@ func TestForwarderTCPFallbackError(t *testing.T) {
} }
}) })
resp, err := runTestQuery(t, port, request, nil) resp, err := runTestQuery(t, request, nil, port)
if !sawRequest.Load() { if !sawRequest.Load() {
t.Error("did not see DNS request") t.Error("did not see DNS request")
} }
@ -673,6 +697,127 @@ func TestForwarderTCPFallbackError(t *testing.T) {
} }
} }
// Test to ensure that if we have more than one resolver, and at least one of them
// returns a successful response, we propagate it.
func TestForwarderWithManyResolvers(t *testing.T) {
enableDebug(t)
const domain = "example.com."
request := makeTestRequest(t, domain)
tests := []struct {
name string
responses [][]byte // upstream responses
wantResponses [][]byte // we should receive one of these from the forwarder
}{
{
name: "Success",
responses: [][]byte{ // All upstream servers returned successful, but different, response.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
},
wantResponses: [][]byte{ // We may forward whichever response is received first.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
},
},
{
name: "ServFail",
responses: [][]byte{ // All upstream servers returned a SERVFAIL.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
wantResponses: [][]byte{
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
},
{
name: "ServFail+Success",
responses: [][]byte{ // All upstream servers fail except for one.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
wantResponses: [][]byte{ // We should forward the successful response.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "NXDomain",
responses: [][]byte{ // All upstream servers returned NXDOMAIN.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
},
wantResponses: [][]byte{
makeTestResponse(t, domain, dns.RCodeNameError),
},
},
{
name: "NXDomain+Success",
responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "Refused",
responses: [][]byte{ // All upstream servers return different failures.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "MixFail",
responses: [][]byte{ // All upstream servers return different failures.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ports := make([]uint16, len(tt.responses))
for i := range tt.responses {
ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
}
gotResponse, err := runTestQuery(t, request, nil, ports...)
if err != nil {
t.Fatalf("wanted nil, got %v", err)
}
responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool {
return slices.Equal(gotResponse, wantResponse)
})
if !responseOk {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0])
}
})
}
}
// mdnsResponder at minimum has an expectation that NXDOMAIN must include the // mdnsResponder at minimum has an expectation that NXDOMAIN must include the
// question, otherwise it will penalize our server (#13511). // question, otherwise it will penalize our server (#13511).
func TestNXDOMAINIncludesQuestion(t *testing.T) { func TestNXDOMAINIncludesQuestion(t *testing.T) {
@ -718,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
}) })
res, err := runTestQuery(t, port, request, nil) res, err := runTestQuery(t, request, nil, port)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

Loading…
Cancel
Save