@ -449,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
return
return
}
}
func runTestQuery ( tb testing . TB , port uint16 , request [ ] byte , modify func ( * forwarder ) ) ( [ ] byte , error ) {
func runTestQuery ( tb testing . TB , request [ ] byte , modify func ( * forwarder ) , ports ... uint16 ) ( [ ] byte , error ) {
netMon , err := netmon . New ( tb . Logf )
netMon , err := netmon . New ( tb . Logf )
if err != nil {
if err != nil {
tb . Fatal ( err )
tb . Fatal ( err )
@ -463,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
modify ( fwd )
modify ( fwd )
}
}
rr := resolverAndDelay {
resolvers := make ( [ ] resolverAndDelay , len ( ports ) )
name : & dnstype . Resolver { Addr : fmt . Sprintf ( "127.0.0.1:%d" , port ) } ,
for i , port := range ports {
resolvers [ i ] . name = & dnstype . Resolver { Addr : fmt . Sprintf ( "127.0.0.1:%d" , port ) }
}
}
rpkt := packet {
rpkt := packet {
@ -476,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
rchan := make ( chan packet , 1 )
rchan := make ( chan packet , 1 )
ctx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
ctx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
tb . Cleanup ( cancel )
tb . Cleanup ( cancel )
err = fwd . forwardWithDestChan ( ctx , rpkt , rchan , r r)
err = fwd . forwardWithDestChan ( ctx , rpkt , rchan , r esolve rs... )
select {
select {
case res := <- rchan :
case res := <- rchan :
return res . bs , err
return res . bs , err
@ -485,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
}
}
}
}
func mustRunTestQuery ( tb testing . TB , port uint16 , request [ ] byte , modify func ( * forwarder ) ) [ ] byte {
// makeTestRequest returns a new TypeA request for the given domain.
resp , err := runTestQuery ( tb , port , request , modify )
func makeTestRequest ( tb testing . TB , domain string ) [ ] byte {
tb . Helper ( )
name := dns . MustNewName ( domain )
builder := dns . NewBuilder ( nil , dns . Header { } )
builder . StartQuestions ( )
builder . Question ( dns . Question {
Name : name ,
Type : dns . TypeA ,
Class : dns . ClassINET ,
} )
request , err := builder . Finish ( )
if err != nil {
tb . Fatal ( err )
}
return request
}
// makeTestResponse returns a new Type A response for the given domain,
// with the specified status code and zero or more addresses.
func makeTestResponse ( tb testing . TB , domain string , code dns . RCode , addrs ... netip . Addr ) [ ] byte {
tb . Helper ( )
name := dns . MustNewName ( domain )
builder := dns . NewBuilder ( nil , dns . Header {
Response : true ,
Authoritative : true ,
RCode : code ,
} )
builder . StartQuestions ( )
q := dns . Question {
Name : name ,
Type : dns . TypeA ,
Class : dns . ClassINET ,
}
builder . Question ( q )
if len ( addrs ) > 0 {
builder . StartAnswers ( )
for _ , addr := range addrs {
builder . AResource ( dns . ResourceHeader {
Name : q . Name ,
Class : q . Class ,
TTL : 120 ,
} , dns . AResource {
A : addr . As4 ( ) ,
} )
}
}
response , err := builder . Finish ( )
if err != nil {
tb . Fatal ( err )
}
return response
}
func mustRunTestQuery ( tb testing . TB , request [ ] byte , modify func ( * forwarder ) , ports ... uint16 ) [ ] byte {
resp , err := runTestQuery ( tb , request , modify , ports ... )
if err != nil {
if err != nil {
tb . Fatalf ( "error making request: %v" , err )
tb . Fatalf ( "error making request: %v" , err )
}
}
@ -515,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) {
}
}
} )
} )
resp := mustRunTestQuery ( t , port, request, nil )
resp := mustRunTestQuery ( t , request, nil , port )
if ! bytes . Equal ( resp , largeResponse ) {
if ! bytes . Equal ( resp , largeResponse ) {
t . Errorf ( "invalid response\ngot: %+v\nwant: %+v" , resp , largeResponse )
t . Errorf ( "invalid response\ngot: %+v\nwant: %+v" , resp , largeResponse )
}
}
@ -553,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
}
}
} )
} )
resp := mustRunTestQuery ( t , port, request, nil )
resp := mustRunTestQuery ( t , request, nil , port )
if ! bytes . Equal ( resp , largeResponse ) {
if ! bytes . Equal ( resp , largeResponse ) {
t . Errorf ( "invalid response\ngot: %+v\nwant: %+v" , resp , largeResponse )
t . Errorf ( "invalid response\ngot: %+v\nwant: %+v" , resp , largeResponse )
}
}
@ -584,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
}
}
} )
} )
resp := mustRunTestQuery ( t , port, request, func ( fwd * forwarder ) {
resp := mustRunTestQuery ( t , request, func ( fwd * forwarder ) {
// Disable retries for this test.
// Disable retries for this test.
fwd . controlKnobs = & controlknobs . Knobs { }
fwd . controlKnobs = & controlknobs . Knobs { }
fwd . controlKnobs . DisableDNSForwarderTCPRetries . Store ( true )
fwd . controlKnobs . DisableDNSForwarderTCPRetries . Store ( true )
} )
} , port )
wantResp := append ( [ ] byte ( nil ) , largeResponse [ : maxResponseBytes ] ... )
wantResp := append ( [ ] byte ( nil ) , largeResponse [ : maxResponseBytes ] ... )
@ -612,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) {
const domain = "error-response.tailscale.com."
const domain = "error-response.tailscale.com."
// Our response is a SERVFAIL
// Our response is a SERVFAIL
response := func ( ) [ ] byte {
response := makeTestResponse ( t , domain , dns . RCodeServerFailure )
name := dns . MustNewName ( domain )
builder := dns . NewBuilder ( nil , dns . Header {
Response : true ,
RCode : dns . RCodeServerFailure ,
} )
builder . StartQuestions ( )
builder . Question ( dns . Question {
Name : name ,
Type : dns . TypeA ,
Class : dns . ClassINET ,
} )
response , err := builder . Finish ( )
if err != nil {
t . Fatal ( err )
}
return response
} ( )
// Our request is a single A query for the domain in the answer, above.
// Our request is a single A query for the domain in the answer, above.
request := func ( ) [ ] byte {
request := makeTestRequest ( t , domain )
builder := dns . NewBuilder ( nil , dns . Header { } )
builder . StartQuestions ( )
builder . Question ( dns . Question {
Name : dns . MustNewName ( domain ) ,
Type : dns . TypeA ,
Class : dns . ClassINET ,
} )
request , err := builder . Finish ( )
if err != nil {
t . Fatal ( err )
}
return request
} ( )
var sawRequest atomic . Bool
var sawRequest atomic . Bool
port := runDNSServer ( t , nil , response , func ( isTCP bool , gotRequest [ ] byte ) {
port := runDNSServer ( t , nil , response , func ( isTCP bool , gotRequest [ ] byte ) {
@ -656,7 +680,7 @@ func TestForwarderTCPFallbackError(t *testing.T) {
}
}
} )
} )
resp , err := runTestQuery ( t , port, request, nil )
resp , err := runTestQuery ( t , request, nil , port )
if ! sawRequest . Load ( ) {
if ! sawRequest . Load ( ) {
t . Error ( "did not see DNS request" )
t . Error ( "did not see DNS request" )
}
}
@ -673,6 +697,127 @@ func TestForwarderTCPFallbackError(t *testing.T) {
}
}
}
}
// Test to ensure that if we have more than one resolver, and at least one of them
// returns a successful response, we propagate it.
func TestForwarderWithManyResolvers ( t * testing . T ) {
enableDebug ( t )
const domain = "example.com."
request := makeTestRequest ( t , domain )
tests := [ ] struct {
name string
responses [ ] [ ] byte // upstream responses
wantResponses [ ] [ ] byte // we should receive one of these from the forwarder
} {
{
name : "Success" ,
responses : [ ] [ ] byte { // All upstream servers returned successful, but different, response.
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.2" ) ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.3" ) ) ,
} ,
wantResponses : [ ] [ ] byte { // We may forward whichever response is received first.
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.2" ) ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.3" ) ) ,
} ,
} ,
{
name : "ServFail" ,
responses : [ ] [ ] byte { // All upstream servers returned a SERVFAIL.
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
} ,
wantResponses : [ ] [ ] byte {
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
} ,
} ,
{
name : "ServFail+Success" ,
responses : [ ] [ ] byte { // All upstream servers fail except for one.
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
} ,
wantResponses : [ ] [ ] byte { // We should forward the successful response.
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
} ,
} ,
{
name : "NXDomain" ,
responses : [ ] [ ] byte { // All upstream servers returned NXDOMAIN.
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
} ,
wantResponses : [ ] [ ] byte {
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
} ,
} ,
{
name : "NXDomain+Success" ,
responses : [ ] [ ] byte { // All upstream servers returned NXDOMAIN except for one.
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
} ,
wantResponses : [ ] [ ] byte { // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
} ,
} ,
{
name : "Refused" ,
responses : [ ] [ ] byte { // All upstream servers return different failures.
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
} ,
wantResponses : [ ] [ ] byte { // Refused is not considered to be an error and can be forwarded.
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
makeTestResponse ( t , domain , dns . RCodeSuccess , netip . MustParseAddr ( "127.0.0.1" ) ) ,
} ,
} ,
{
name : "MixFail" ,
responses : [ ] [ ] byte { // All upstream servers return different failures.
makeTestResponse ( t , domain , dns . RCodeServerFailure ) ,
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
} ,
wantResponses : [ ] [ ] byte { // Both NXDomain and Refused can be forwarded.
makeTestResponse ( t , domain , dns . RCodeNameError ) ,
makeTestResponse ( t , domain , dns . RCodeRefused ) ,
} ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
ports := make ( [ ] uint16 , len ( tt . responses ) )
for i := range tt . responses {
ports [ i ] = runDNSServer ( t , nil , tt . responses [ i ] , func ( isTCP bool , gotRequest [ ] byte ) { } )
}
gotResponse , err := runTestQuery ( t , request , nil , ports ... )
if err != nil {
t . Fatalf ( "wanted nil, got %v" , err )
}
responseOk := slices . ContainsFunc ( tt . wantResponses , func ( wantResponse [ ] byte ) bool {
return slices . Equal ( gotResponse , wantResponse )
} )
if ! responseOk {
t . Errorf ( "invalid response\ngot: %+v\nwant: %+v" , gotResponse , tt . wantResponses [ 0 ] )
}
} )
}
}
// mdnsResponder at minimum has an expectation that NXDOMAIN must include the
// mdnsResponder at minimum has an expectation that NXDOMAIN must include the
// question, otherwise it will penalize our server (#13511).
// question, otherwise it will penalize our server (#13511).
func TestNXDOMAINIncludesQuestion ( t * testing . T ) {
func TestNXDOMAINIncludesQuestion ( t * testing . T ) {
@ -718,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
port := runDNSServer ( t , nil , response , func ( isTCP bool , gotRequest [ ] byte ) {
port := runDNSServer ( t , nil , response , func ( isTCP bool , gotRequest [ ] byte ) {
} )
} )
res , err := runTestQuery ( t , port, request, nil )
res , err := runTestQuery ( t , request, nil , port )
if err != nil {
if err != nil {
t . Fatal ( err )
t . Fatal ( err )
}
}