@ -9,6 +9,7 @@ import (
"io"
"net"
"net/netip"
"sync"
"testing"
"time"
@ -480,3 +481,198 @@ func TestV6V4(t *testing.T) {
}
}
}
// echoServer is a simple server that just echos back data set to it.
type echoServer struct {
listener net . Listener
addr string
wg sync . WaitGroup
done chan struct { }
}
// newEchoServer creates a new test DNS server on the specified network and address
func newEchoServer ( t * testing . T , network , addr string ) * echoServer {
listener , err := net . Listen ( network , addr )
if err != nil {
t . Fatalf ( "Failed to create test DNS server: %v" , err )
}
server := & echoServer {
listener : listener ,
addr : listener . Addr ( ) . String ( ) ,
done : make ( chan struct { } ) ,
}
server . wg . Add ( 1 )
go server . serve ( )
return server
}
func ( s * echoServer ) serve ( ) {
defer s . wg . Done ( )
for {
select {
case <- s . done :
return
default :
conn , err := s . listener . Accept ( )
if err != nil {
select {
case <- s . done :
return
default :
continue
}
}
go s . handleConnection ( conn )
}
}
}
func ( s * echoServer ) handleConnection ( conn net . Conn ) {
defer conn . Close ( )
// Simple response - just echo back some data to confirm connectivity
buf := make ( [ ] byte , 1024 )
n , err := conn . Read ( buf )
if err != nil {
return
}
conn . Write ( buf [ : n ] )
}
func ( s * echoServer ) close ( ) {
close ( s . done )
s . listener . Close ( )
s . wg . Wait ( )
}
func TestGetResolver ( t * testing . T ) {
tests := [ ] struct {
name string
network string
addr string
} {
{
name : "ipv4_loopback" ,
network : "tcp4" ,
addr : "127.0.0.1:0" ,
} ,
{
name : "ipv6_loopback" ,
network : "tcp6" ,
addr : "[::1]:0" ,
} ,
}
for _ , tc := range tests {
t . Run ( tc . name , func ( t * testing . T ) {
server := newEchoServer ( t , tc . network , tc . addr )
defer server . close ( )
serverAddr := server . addr
resolver := getResolver ( serverAddr )
if resolver == nil {
t . Fatal ( "getResolver returned nil" )
}
netResolver , ok := resolver . ( * net . Resolver )
if ! ok {
t . Fatal ( "getResolver did not return a *net.Resolver" )
}
if netResolver . Dial == nil {
t . Fatal ( "resolver.Dial is nil" )
}
ctx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
defer cancel ( )
conn , err := netResolver . Dial ( ctx , "tcp" , "dummy.address:53" )
if err != nil {
t . Fatalf ( "Failed to dial test DNS server: %v" , err )
}
defer conn . Close ( )
testData := [ ] byte ( "test" )
_ , err = conn . Write ( testData )
if err != nil {
t . Fatalf ( "Failed to write to connection: %v" , err )
}
response := make ( [ ] byte , len ( testData ) )
_ , err = conn . Read ( response )
if err != nil {
t . Fatalf ( "Failed to read from connection: %v" , err )
}
if string ( response ) != string ( testData ) {
t . Fatalf ( "Expected echo response %q, got %q" , testData , response )
}
} )
}
}
func TestGetResolverMultipleServers ( t * testing . T ) {
server1 := newEchoServer ( t , "tcp4" , "127.0.0.1:0" )
defer server1 . close ( )
server2 := newEchoServer ( t , "tcp4" , "127.0.0.1:0" )
defer server2 . close ( )
serverFlag := server1 . addr + ", " + server2 . addr
resolver := getResolver ( serverFlag )
netResolver , ok := resolver . ( * net . Resolver )
if ! ok {
t . Fatal ( "getResolver did not return a *net.Resolver" )
}
ctx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
defer cancel ( )
servers := map [ string ] bool {
server1 . addr : false ,
server2 . addr : false ,
}
// Try up to 1000 times to hit all servers, this should be very quick, and
// if this fails randomness has regressed beyond reason.
for range 1000 {
conn , err := netResolver . Dial ( ctx , "tcp" , "dummy.address:53" )
if err != nil {
t . Fatalf ( "Failed to dial test DNS server: %v" , err )
}
remoteAddr := conn . RemoteAddr ( ) . String ( )
conn . Close ( )
servers [ remoteAddr ] = true
var allDone = true
for _ , done := range servers {
if ! done {
allDone = false
break
}
}
if allDone {
break
}
}
var allDone = true
for _ , done := range servers {
if ! done {
allDone = false
break
}
}
if ! allDone {
t . Errorf ( "after 1000 queries, not all servers were hit, significant lack of randomness: %#v" , servers )
}
}
func TestGetResolverEmpty ( t * testing . T ) {
resolver := getResolver ( "" )
if resolver != net . DefaultResolver {
t . Fatal ( ` getResolver("") should return net.DefaultResolver ` )
}
}