@ -6,11 +6,23 @@
package nettest
package nettest
import (
import (
"context"
"flag"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"testing"
"testing"
"tailscale.com/net/memnet"
"tailscale.com/net/netmon"
"tailscale.com/net/netmon"
"tailscale.com/util/testenv"
)
)
var useMemNet = flag . Bool ( "use-test-memnet" , false , "prefer using in-memory network for tests" )
// SkipIfNoNetwork skips the test if it looks like there's no network
// SkipIfNoNetwork skips the test if it looks like there's no network
// access.
// access.
func SkipIfNoNetwork ( t testing . TB ) {
func SkipIfNoNetwork ( t testing . TB ) {
@ -19,3 +31,190 @@ func SkipIfNoNetwork(t testing.TB) {
t . Skip ( "skipping; test requires network but no interface is up" )
t . Skip ( "skipping; test requires network but no interface is up" )
}
}
}
}
// Network is an interface for use in tests that describes either [RealNetwork]
// or [MemNetwork].
type Network interface {
NewLocalTCPListener ( ) net . Listener
Listen ( network , address string ) ( net . Listener , error )
Dial ( ctx context . Context , network , address string ) ( net . Conn , error )
}
// PreferMemNetwork reports whether the --use-test-memnet flag is set.
func PreferMemNetwork ( ) bool {
return * useMemNet
}
// GetNetwork returns the appropriate Network implementation based on
// whether the --use-test-memnet flag is set.
//
// Each call generates a new network.
func GetNetwork ( tb testing . TB ) Network {
var n Network
if PreferMemNetwork ( ) {
n = MemNetwork ( )
} else {
n = RealNetwork ( )
}
detectLeaks := PreferMemNetwork ( ) || ! testenv . InParallelTest ( tb )
if detectLeaks {
tb . Cleanup ( func ( ) {
// TODO: leak detection, making sure no connections
// remain at the end of the test. For real network,
// snapshot conns in pid table before & after.
} )
}
return n
}
// RealNetwork returns a Network implementation that uses the real
// net package.
func RealNetwork ( ) Network { return realNetwork { } }
// realNetwork implements [Network] using the real net package.
type realNetwork struct { }
func ( realNetwork ) Listen ( network , address string ) ( net . Listener , error ) {
return net . Listen ( network , address )
}
func ( realNetwork ) Dial ( ctx context . Context , network , address string ) ( net . Conn , error ) {
var d net . Dialer
return d . DialContext ( ctx , network , address )
}
func ( realNetwork ) NewLocalTCPListener ( ) net . Listener {
ln , err := net . Listen ( "tcp" , "127.0.0.1:0" )
if err != nil {
if ln , err = net . Listen ( "tcp6" , "[::1]:0" ) ; err != nil {
panic ( fmt . Sprintf ( "httptest: failed to listen on a port: %v" , err ) )
}
}
return ln
}
// MemNetwork returns a Network implementation that uses an in-memory
// network for testing. It is only suitable for tests that do not
// require real network access.
func MemNetwork ( ) Network { return & memNetwork { } }
// memNetwork implements [Network] using an in-memory network.
type memNetwork struct {
mu sync . Mutex
lns map [ string ] * memnet . Listener // address -> listener
}
func ( m * memNetwork ) Listen ( network , address string ) ( net . Listener , error ) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil , fmt . Errorf ( "memNetwork: Listen called with unsupported network %q" , network )
}
ap , err := netip . ParseAddrPort ( address )
if err != nil {
return nil , fmt . Errorf ( "memNetwork: Listen called with invalid address %q: %w" , address , err )
}
m . mu . Lock ( )
defer m . mu . Unlock ( )
if m . lns == nil {
m . lns = make ( map [ string ] * memnet . Listener )
}
port := ap . Port ( )
for {
if port == 0 {
port = 33000
}
key := net . JoinHostPort ( ap . Addr ( ) . String ( ) , fmt . Sprint ( port ) )
_ , ok := m . lns [ key ]
if ok {
if ap . Port ( ) != 0 {
return nil , fmt . Errorf ( "memNetwork: Listen called with duplicate address %q" , address )
}
port ++
continue
}
ln := memnet . Listen ( key )
m . lns [ key ] = ln
return ln , nil
}
}
func ( m * memNetwork ) NewLocalTCPListener ( ) net . Listener {
ln , err := m . Listen ( "tcp" , "127.0.0.1:0" )
if err != nil {
panic ( fmt . Sprintf ( "memNetwork: failed to create local TCP listener: %v" , err ) )
}
return ln
}
func ( m * memNetwork ) Dial ( ctx context . Context , network , address string ) ( net . Conn , error ) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil , fmt . Errorf ( "memNetwork: Dial called with unsupported network %q" , network )
}
m . mu . Lock ( )
ln , ok := m . lns [ address ]
m . mu . Unlock ( )
if ! ok {
return nil , fmt . Errorf ( "memNetwork: Dial called on unknown address %q" , address )
}
return ln . Dial ( ctx , network , address )
}
// NewHTTPServer starts and returns a new [httptest.Server].
// The caller should call Close when finished, to shut it down.
func NewHTTPServer ( net Network , handler http . Handler ) * httptest . Server {
ts := NewUnstartedHTTPServer ( net , handler )
ts . Start ( )
return ts
}
// NewUnstartedHTTPServer returns a new [httptest.Server] but doesn't start it.
//
// After changing its configuration, the caller should call Start or
// StartTLS.
//
// The caller should call Close when finished, to shut it down.
func NewUnstartedHTTPServer ( nw Network , handler http . Handler ) * httptest . Server {
s := & httptest . Server {
Config : & http . Server { Handler : handler } ,
}
ln := nw . NewLocalTCPListener ( )
s . Listener = & listenerOnAddrOnce {
Listener : ln ,
fn : func ( ) {
c := s . Client ( )
if c == nil {
// This httptest.Server.Start initialization order has been true
// for over 10 years. Let's keep counting on it.
panic ( "httptest.Server: Client not initialized before Addr called" )
}
if c . Transport == nil {
c . Transport = & http . Transport { }
}
tr := c . Transport . ( * http . Transport )
if tr . Dial != nil || tr . DialContext != nil {
panic ( "unexpected non-nil Dial or DialContext in httptest.Server.Client.Transport" )
}
tr . DialContext = func ( ctx context . Context , network , addr string ) ( net . Conn , error ) {
return nw . Dial ( ctx , network , addr )
}
} ,
}
return s
}
// listenerOnAddrOnce is a net.Listener that wraps another net.Listener
// and calls a function the first time its Addr is called.
type listenerOnAddrOnce struct {
net . Listener
once sync . Once
fn func ( )
}
func ( ln * listenerOnAddrOnce ) Addr ( ) net . Addr {
ln . once . Do ( func ( ) {
ln . fn ( )
} )
return ln . Listener . Addr ( )
}