@ -5,19 +5,23 @@
package dns
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/net/tsdial"
"tailscale.com/tstest"
"tailscale.com/util/dnsname"
)
func mkDNSRequest ( domain dnsname . FQDN , tp dns . Type ) [ ] byte {
func mkDNSRequest ( domain dnsname . FQDN , tp dns . Type , modify func ( * dns . Builder ) ) [ ] byte {
var dnsHeader dns . Header
question := dns . Question {
Name : dns . MustNewName ( domain . WithTrailingDot ( ) ) ,
@ -37,6 +41,15 @@ func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
panic ( err )
}
if modify != nil {
modify ( & builder )
}
payload , _ := builder . Finish ( )
return payload
}
func addEDNS ( builder * dns . Builder ) {
ednsHeader := dns . ResourceHeader {
Name : dns . MustNewName ( "." ) ,
Type : dns . TypeOPT ,
@ -46,10 +59,25 @@ func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
if err := builder . OPTResource ( ednsHeader , dns . OPTResource { } ) ; err != nil {
panic ( err )
}
}
payload , _ := builder . Finish ( )
func mkLargeDNSRequest ( domain dnsname . FQDN , tp dns . Type ) [ ] byte {
return mkDNSRequest ( domain , tp , func ( builder * dns . Builder ) {
ednsHeader := dns . ResourceHeader {
Name : dns . MustNewName ( "." ) ,
Type : dns . TypeOPT ,
Class : dns . Class ( 4095 ) ,
}
return payload
if err := builder . OPTResource ( ednsHeader , dns . OPTResource {
Options : [ ] dns . Option { {
Code : 1234 ,
Data : bytes . Repeat ( [ ] byte ( "A" ) , maxReqSizeTCP ) ,
} } ,
} ) ; err != nil {
panic ( err )
}
} )
}
func TestDNSOverTCP ( t * testing . T ) {
@ -82,7 +110,7 @@ func TestDNSOverTCP(t *testing.T) {
}
for domain , _ := range wantResults {
b := mkDNSRequest ( domain , dns . TypeA )
b := mkDNSRequest ( domain , dns . TypeA , addEDNS )
binary . Write ( c , binary . BigEndian , uint16 ( len ( b ) ) )
c . Write ( b )
}
@ -134,3 +162,69 @@ func TestDNSOverTCP(t *testing.T) {
t . Errorf ( "wrong results (-got+want)\n%s" , diff )
}
}
func TestDNSOverTCP_TooLarge ( t * testing . T ) {
log := tstest . WhileTestRunningLogger ( t )
f := fakeOSConfigurator {
SplitDNS : true ,
BaseConfig : OSConfig {
Nameservers : mustIPs ( "8.8.8.8" ) ,
SearchDomains : fqdns ( "coffee.shop" ) ,
} ,
}
m := NewManager ( log , & f , nil , new ( tsdial . Dialer ) , nil )
m . resolver . TestOnlySetHook ( f . SetResolver )
m . Set ( Config {
Hosts : hosts ( "andrew.ts.com." , "1.2.3.4" ) ,
Routes : upstreams ( "ts.com" , "" ) ,
SearchDomains : fqdns ( "tailscale.com" ) ,
} )
defer m . Down ( )
c , s := net . Pipe ( )
defer s . Close ( )
go m . HandleTCPConn ( s , netip . AddrPort { } )
defer c . Close ( )
var b [ ] byte
domain := dnsname . FQDN ( "andrew.ts.com." )
// Write a successful request, then a large one that will fail; this
// exercises the data race in tailscale/tailscale#6725
b = mkDNSRequest ( domain , dns . TypeA , addEDNS )
binary . Write ( c , binary . BigEndian , uint16 ( len ( b ) ) )
if _ , err := c . Write ( b ) ; err != nil {
t . Fatal ( err )
}
c . SetWriteDeadline ( time . Now ( ) . Add ( 5 * time . Second ) )
b = mkLargeDNSRequest ( domain , dns . TypeA )
if err := binary . Write ( c , binary . BigEndian , uint16 ( len ( b ) ) ) ; err != nil {
t . Fatal ( err )
}
if _ , err := c . Write ( b ) ; err != nil {
// It's possible that we get an error here, since the
// net.Pipe() implementation enforces synchronous reads. So,
// handleReads could read the size, then error, and this write
// fails. That's actually a success for this test!
if errors . Is ( err , io . ErrClosedPipe ) {
t . Logf ( "pipe (correctly) closed when writing large response" )
return
}
t . Fatal ( err )
}
t . Logf ( "reading responses" )
c . SetReadDeadline ( time . Now ( ) . Add ( 5 * time . Second ) )
// We expect an EOF now, since the connection will have been closed due
// to a too-large query.
var respLength uint16
err := binary . Read ( c , binary . BigEndian , & respLength )
if ! errors . Is ( err , io . EOF ) && ! errors . Is ( err , io . ErrClosedPipe ) {
t . Errorf ( "expected EOF on large read; got %v" , err )
}
}