@ -1,23 +1,25 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package netx contains t he Network type to abstract over either a real
// network or a virtual network for testing .
// Package netx contains t ypes to describe and abstract over how dialing and
// listening are performed .
package netx
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"tailscale.com/net/memnet"
)
// DialFunc is a function that dials a network address.
//
// It's the type implemented by net.Dialer.DialContext or required
// by net/http.Transport.DialContext, etc.
type DialFunc func ( ctx context . Context , network , address string ) ( net . Conn , error )
// Network describes a network that can listen and dial. The two common
// implementations are [RealNetwork], using the net package to use the real
// network, or [MemNetwork], using an in-memory network (typically for testing)
// network, or [ memnet. Network], using an in-memory network (typically for testing)
type Network interface {
NewLocalTCPListener ( ) net . Listener
Listen ( network , address string ) ( net . Listener , error )
@ -44,77 +46,8 @@ func (realNetwork) NewLocalTCPListener() net.Listener {
ln , err := net . Listen ( "tcp" , "127.0.0.1:0" )
if err != nil {
if ln , err = net . Listen ( "tcp6" , "[::1]:0" ) ; err != nil {
panic ( fmt . Sprintf ( " httptest: failed to listen on a port: %v", err ) )
panic ( fmt . Sprintf ( " failed to listen on either IPv4 or IPv6 loc alhost port: %v", err ) )
}
}
return ln
}
// MemNetwork returns a Network implementation that uses an in-memory
// network for testing. It is only suitable for tests that do not
// require real network access.
//
// As of 2025-04-08, it only supports TCP.
func MemNetwork ( ) Network { return & memNetwork { } }
// memNetwork implements [Network] using an in-memory network.
type memNetwork struct {
mu sync . Mutex
lns map [ string ] * memnet . Listener // address -> listener
}
func ( m * memNetwork ) Listen ( network , address string ) ( net . Listener , error ) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil , fmt . Errorf ( "memNetwork: Listen called with unsupported network %q" , network )
}
ap , err := netip . ParseAddrPort ( address )
if err != nil {
return nil , fmt . Errorf ( "memNetwork: Listen called with invalid address %q: %w" , address , err )
}
m . mu . Lock ( )
defer m . mu . Unlock ( )
if m . lns == nil {
m . lns = make ( map [ string ] * memnet . Listener )
}
port := ap . Port ( )
for {
if port == 0 {
port = 33000
}
key := net . JoinHostPort ( ap . Addr ( ) . String ( ) , fmt . Sprint ( port ) )
_ , ok := m . lns [ key ]
if ok {
if ap . Port ( ) != 0 {
return nil , fmt . Errorf ( "memNetwork: Listen called with duplicate address %q" , address )
}
port ++
continue
}
ln := memnet . Listen ( key )
m . lns [ key ] = ln
return ln , nil
}
}
func ( m * memNetwork ) NewLocalTCPListener ( ) net . Listener {
ln , err := m . Listen ( "tcp" , "127.0.0.1:0" )
if err != nil {
panic ( fmt . Sprintf ( "memNetwork: failed to create local TCP listener: %v" , err ) )
}
return ln
}
func ( m * memNetwork ) Dial ( ctx context . Context , network , address string ) ( net . Conn , error ) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil , fmt . Errorf ( "memNetwork: Dial called with unsupported network %q" , network )
}
m . mu . Lock ( )
ln , ok := m . lns [ address ]
m . mu . Unlock ( )
if ! ok {
return nil , fmt . Errorf ( "memNetwork: Dial called on unknown address %q" , address )
}
return ln . Dial ( ctx , network , address )
}