|
|
|
|
@ -7,7 +7,6 @@ import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"context"
|
|
|
|
|
"encoding/binary"
|
|
|
|
|
"errors"
|
|
|
|
|
"flag"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
@ -450,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) {
|
|
|
|
|
func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
|
|
|
|
|
netMon, err := netmon.New(tb.Logf)
|
|
|
|
|
if err != nil {
|
|
|
|
|
tb.Fatal(err)
|
|
|
|
|
@ -464,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
|
|
|
|
|
modify(fwd)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rr := resolverAndDelay{
|
|
|
|
|
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
|
|
|
|
|
resolvers := make([]resolverAndDelay, len(ports))
|
|
|
|
|
for i, port := range ports {
|
|
|
|
|
resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rpkt := packet{
|
|
|
|
|
@ -477,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
|
|
|
|
|
rchan := make(chan packet, 1)
|
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
|
|
|
tb.Cleanup(cancel)
|
|
|
|
|
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr)
|
|
|
|
|
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...)
|
|
|
|
|
select {
|
|
|
|
|
case res := <-rchan:
|
|
|
|
|
return res.bs, err
|
|
|
|
|
@ -486,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte {
|
|
|
|
|
resp, err := runTestQuery(tb, port, request, modify)
|
|
|
|
|
// makeTestRequest returns a new TypeA request for the given domain.
|
|
|
|
|
func makeTestRequest(tb testing.TB, domain string) []byte {
|
|
|
|
|
tb.Helper()
|
|
|
|
|
name := dns.MustNewName(domain)
|
|
|
|
|
builder := dns.NewBuilder(nil, dns.Header{})
|
|
|
|
|
builder.StartQuestions()
|
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
|
Name: name,
|
|
|
|
|
Type: dns.TypeA,
|
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
|
})
|
|
|
|
|
request, err := builder.Finish()
|
|
|
|
|
if err != nil {
|
|
|
|
|
tb.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
return request
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// makeTestResponse returns a new Type A response for the given domain,
|
|
|
|
|
// with the specified status code and zero or more addresses.
|
|
|
|
|
func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
|
|
|
|
|
tb.Helper()
|
|
|
|
|
name := dns.MustNewName(domain)
|
|
|
|
|
builder := dns.NewBuilder(nil, dns.Header{
|
|
|
|
|
Response: true,
|
|
|
|
|
Authoritative: true,
|
|
|
|
|
RCode: code,
|
|
|
|
|
})
|
|
|
|
|
builder.StartQuestions()
|
|
|
|
|
q := dns.Question{
|
|
|
|
|
Name: name,
|
|
|
|
|
Type: dns.TypeA,
|
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
|
}
|
|
|
|
|
builder.Question(q)
|
|
|
|
|
if len(addrs) > 0 {
|
|
|
|
|
builder.StartAnswers()
|
|
|
|
|
for _, addr := range addrs {
|
|
|
|
|
builder.AResource(dns.ResourceHeader{
|
|
|
|
|
Name: q.Name,
|
|
|
|
|
Class: q.Class,
|
|
|
|
|
TTL: 120,
|
|
|
|
|
}, dns.AResource{
|
|
|
|
|
A: addr.As4(),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
response, err := builder.Finish()
|
|
|
|
|
if err != nil {
|
|
|
|
|
tb.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
return response
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte {
|
|
|
|
|
resp, err := runTestQuery(tb, request, modify, ports...)
|
|
|
|
|
if err != nil {
|
|
|
|
|
tb.Fatalf("error making request: %v", err)
|
|
|
|
|
}
|
|
|
|
|
@ -516,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
resp := mustRunTestQuery(t, port, request, nil)
|
|
|
|
|
resp := mustRunTestQuery(t, request, nil, port)
|
|
|
|
|
if !bytes.Equal(resp, largeResponse) {
|
|
|
|
|
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
|
|
|
|
}
|
|
|
|
|
@ -554,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
resp := mustRunTestQuery(t, port, request, nil)
|
|
|
|
|
resp := mustRunTestQuery(t, request, nil, port)
|
|
|
|
|
if !bytes.Equal(resp, largeResponse) {
|
|
|
|
|
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
|
|
|
|
}
|
|
|
|
|
@ -585,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) {
|
|
|
|
|
resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
|
|
|
|
|
// Disable retries for this test.
|
|
|
|
|
fwd.controlKnobs = &controlknobs.Knobs{}
|
|
|
|
|
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
|
|
|
|
|
})
|
|
|
|
|
}, port)
|
|
|
|
|
|
|
|
|
|
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
|
|
|
|
|
|
|
|
|
|
@ -613,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) {
|
|
|
|
|
const domain = "error-response.tailscale.com."
|
|
|
|
|
|
|
|
|
|
// Our response is a SERVFAIL
|
|
|
|
|
response := func() []byte {
|
|
|
|
|
name := dns.MustNewName(domain)
|
|
|
|
|
|
|
|
|
|
builder := dns.NewBuilder(nil, dns.Header{
|
|
|
|
|
Response: true,
|
|
|
|
|
RCode: dns.RCodeServerFailure,
|
|
|
|
|
})
|
|
|
|
|
builder.StartQuestions()
|
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
|
Name: name,
|
|
|
|
|
Type: dns.TypeA,
|
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
|
})
|
|
|
|
|
response, err := builder.Finish()
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
return response
|
|
|
|
|
}()
|
|
|
|
|
response := makeTestResponse(t, domain, dns.RCodeServerFailure)
|
|
|
|
|
|
|
|
|
|
// Our request is a single A query for the domain in the answer, above.
|
|
|
|
|
request := func() []byte {
|
|
|
|
|
builder := dns.NewBuilder(nil, dns.Header{})
|
|
|
|
|
builder.StartQuestions()
|
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
|
Name: dns.MustNewName(domain),
|
|
|
|
|
Type: dns.TypeA,
|
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
|
})
|
|
|
|
|
request, err := builder.Finish()
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
return request
|
|
|
|
|
}()
|
|
|
|
|
request := makeTestRequest(t, domain)
|
|
|
|
|
|
|
|
|
|
var sawRequest atomic.Bool
|
|
|
|
|
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
|
|
|
|
@ -657,14 +680,141 @@ func TestForwarderTCPFallbackError(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
_, err := runTestQuery(t, port, request, nil)
|
|
|
|
|
resp, err := runTestQuery(t, request, nil, port)
|
|
|
|
|
if !sawRequest.Load() {
|
|
|
|
|
t.Error("did not see DNS request")
|
|
|
|
|
}
|
|
|
|
|
if err == nil {
|
|
|
|
|
t.Error("wanted error, got nil")
|
|
|
|
|
} else if !errors.Is(err, errServerFailure) {
|
|
|
|
|
t.Errorf("wanted errServerFailure, got: %v", err)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("wanted nil, got %v", err)
|
|
|
|
|
}
|
|
|
|
|
var parser dns.Parser
|
|
|
|
|
respHeader, err := parser.Start(resp)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("parser.Start() failed: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want {
|
|
|
|
|
t.Errorf("wanted %v, got %v", want, got)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Test to ensure that if we have more than one resolver, and at least one of them
|
|
|
|
|
// returns a successful response, we propagate it.
|
|
|
|
|
func TestForwarderWithManyResolvers(t *testing.T) {
|
|
|
|
|
enableDebug(t)
|
|
|
|
|
|
|
|
|
|
const domain = "example.com."
|
|
|
|
|
request := makeTestRequest(t, domain)
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
responses [][]byte // upstream responses
|
|
|
|
|
wantResponses [][]byte // we should receive one of these from the forwarder
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "Success",
|
|
|
|
|
responses: [][]byte{ // All upstream servers returned successful, but different, response.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
|
|
|
|
|
},
|
|
|
|
|
wantResponses: [][]byte{ // We may forward whichever response is received first.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "ServFail",
|
|
|
|
|
responses: [][]byte{ // All upstream servers returned a SERVFAIL.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
},
|
|
|
|
|
wantResponses: [][]byte{
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "ServFail+Success",
|
|
|
|
|
responses: [][]byte{ // All upstream servers fail except for one.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
},
|
|
|
|
|
wantResponses: [][]byte{ // We should forward the successful response.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "NXDomain",
|
|
|
|
|
responses: [][]byte{ // All upstream servers returned NXDOMAIN.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
},
|
|
|
|
|
wantResponses: [][]byte{
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "NXDomain+Success",
|
|
|
|
|
responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
},
|
|
|
|
|
wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "Refused",
|
|
|
|
|
responses: [][]byte{ // All upstream servers return different failures.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
},
|
|
|
|
|
wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "MixFail",
|
|
|
|
|
responses: [][]byte{ // All upstream servers return different failures.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
},
|
|
|
|
|
wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded.
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeNameError),
|
|
|
|
|
makeTestResponse(t, domain, dns.RCodeRefused),
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
ports := make([]uint16, len(tt.responses))
|
|
|
|
|
for i := range tt.responses {
|
|
|
|
|
ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
|
|
|
|
|
}
|
|
|
|
|
gotResponse, err := runTestQuery(t, request, nil, ports...)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("wanted nil, got %v", err)
|
|
|
|
|
}
|
|
|
|
|
responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool {
|
|
|
|
|
return slices.Equal(gotResponse, wantResponse)
|
|
|
|
|
})
|
|
|
|
|
if !responseOk {
|
|
|
|
|
t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0])
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -713,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
|
|
|
|
|
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
res, err := runTestQuery(t, port, request, nil)
|
|
|
|
|
res, err := runTestQuery(t, request, nil, port)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|