diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 5d375a515..085a58383 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -111,6 +111,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa 💣 tailscale.com/net/netmon from tailscale.com/derp/derphttp+ 💣 tailscale.com/net/netns from tailscale.com/derp/derphttp tailscale.com/net/netutil from tailscale.com/client/local + tailscale.com/net/netx from tailscale.com/net/dnscache+ tailscale.com/net/sockstats from tailscale.com/derp/derphttp tailscale.com/net/stun from tailscale.com/net/stunserver tailscale.com/net/stunserver from tailscale.com/cmd/derper diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 7c87649d1..7fd4c4b21 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -866,6 +866,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ 💣 tailscale.com/net/netns from tailscale.com/derp/derphttp+ W 💣 tailscale.com/net/netstat from tailscale.com/portlist tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/net/connstats+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun tailscale.com/net/ping from tailscale.com/net/netcheck+ diff --git a/cmd/sniproxy/handlers.go b/cmd/sniproxy/handlers.go index 102110fe3..1973eecc0 100644 --- a/cmd/sniproxy/handlers.go +++ b/cmd/sniproxy/handlers.go @@ -14,6 +14,7 @@ import ( "github.com/inetaf/tcpproxy" "tailscale.com/net/netutil" + "tailscale.com/net/netx" ) type tcpRoundRobinHandler struct { @@ -22,7 +23,7 @@ type tcpRoundRobinHandler struct { To []string // DialContext is used to make the outgoing TCP connection. - DialContext func(ctx context.Context, network, address string) (net.Conn, error) + DialContext netx.DialFunc // ReachableIPs enumerates the IP addresses this handler is reachable on. ReachableIPs []netip.Addr diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 431bf7b71..9728a2ff4 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -112,6 +112,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep 💣 tailscale.com/net/netmon from tailscale.com/cmd/tailscale/cli+ 💣 tailscale.com/net/netns from tailscale.com/derp/derphttp+ tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlhttp+ tailscale.com/net/ping from tailscale.com/net/netcheck tailscale.com/net/portmapper from tailscale.com/cmd/tailscale/cli+ tailscale.com/net/sockstats from tailscale.com/control/controlhttp+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 1fbf7caf1..394056295 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -316,6 +316,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ W 💣 tailscale.com/net/netstat from tailscale.com/portlist tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/net/connstats+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun tailscale.com/net/ping from tailscale.com/net/netcheck+ diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 70ebe2f23..c8e885799 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -37,6 +37,7 @@ import ( "tailscale.com/net/dnsfallback" "tailscale.com/net/netmon" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" "tailscale.com/net/tshttpproxy" @@ -272,7 +273,7 @@ func NewDirect(opts Options) (*Direct, error) { tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig) - var dialFunc dialFunc + var dialFunc netx.DialFunc dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial) tr.DialContext = dnscache.Dialer(dialFunc, dnsCache) tr.DialTLSContext = dnscache.TLSDialer(dialFunc, dnsCache, tr.TLSClientConfig) @@ -1749,14 +1750,12 @@ func addLBHeader(req *http.Request, nodeKey key.NodePublic) { } } -type dialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) - // makeScreenTimeDetectingDialFunc returns dialFunc, optionally wrapped (on // Apple systems) with a func that sets the returned atomic.Bool for whether // Screen Time seemed to intercept the connection. // // The returned *atomic.Bool is nil on non-Apple systems. -func makeScreenTimeDetectingDialFunc(dial dialFunc) (dialFunc, *atomic.Bool) { +func makeScreenTimeDetectingDialFunc(dial netx.DialFunc) (netx.DialFunc, *atomic.Bool) { switch runtime.GOOS { case "darwin", "ios": // Continue below. diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 44de6b0df..869bcb599 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -44,6 +44,7 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" @@ -494,7 +495,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad dns = a.resolver() } - var dialer dnscache.DialContextFunc + var dialer netx.DialFunc if a.Dialer != nil { dialer = a.Dialer } else { diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index 80b3fe64c..12038fae4 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -12,6 +12,7 @@ import ( "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -66,7 +67,7 @@ type Dialer struct { // Dialer is the dialer used to make outbound connections. // // If not specified, this defaults to net.Dialer.DialContext. - Dialer dnscache.DialContextFunc + Dialer netx.DialFunc // DNSCache is the caching Resolver used by this Dialer. // diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index aef916ef6..f556640f8 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -26,8 +26,8 @@ import ( "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/health" - "tailscale.com/net/dnscache" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/net/socks5" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" @@ -760,7 +760,7 @@ func TestDialPlan(t *testing.T) { type closeTrackDialer struct { t testing.TB - inner dnscache.DialContextFunc + inner netx.DialFunc mu sync.Mutex conns map[*closeTrackConn]bool } diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 319c02429..21ee4a671 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -35,6 +35,7 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" @@ -587,7 +588,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien // // The primary use for this is the derper mesh mode to connect to each // other over a VPC network. -func (c *Client) SetURLDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) { +func (c *Client) SetURLDialer(dialer netx.DialFunc) { c.dialer = dialer } diff --git a/k8s-operator/sessionrecording/hijacker.go b/k8s-operator/sessionrecording/hijacker.go index 43aa14e61..a9ed65896 100644 --- a/k8s-operator/sessionrecording/hijacker.go +++ b/k8s-operator/sessionrecording/hijacker.go @@ -25,6 +25,7 @@ import ( "tailscale.com/k8s-operator/sessionrecording/spdy" "tailscale.com/k8s-operator/sessionrecording/tsrecorder" "tailscale.com/k8s-operator/sessionrecording/ws" + "tailscale.com/net/netx" "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tsnet" @@ -102,7 +103,7 @@ type Hijacker struct { // connection succeeds. In case of success, returns a list with a single // successful recording attempt and an error channel. If the connection errors // after having been established, an error is sent down the channel. -type RecorderDialFn func(context.Context, []netip.AddrPort, sessionrecording.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) +type RecorderDialFn func(context.Context, []netip.AddrPort, netx.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) // Hijack hijacks a 'kubectl exec' session and configures for the session // contents to be sent to a recorder. diff --git a/k8s-operator/sessionrecording/hijacker_test.go b/k8s-operator/sessionrecording/hijacker_test.go index e166ce63b..880015b22 100644 --- a/k8s-operator/sessionrecording/hijacker_test.go +++ b/k8s-operator/sessionrecording/hijacker_test.go @@ -19,7 +19,7 @@ import ( "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" "tailscale.com/k8s-operator/sessionrecording/fakes" - "tailscale.com/sessionrecording" + "tailscale.com/net/netx" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstest" @@ -80,7 +80,7 @@ func Test_Hijacker(t *testing.T) { h := &Hijacker{ connectToRecorder: func(context.Context, []netip.AddrPort, - sessionrecording.DialFunc, + netx.DialFunc, ) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) { if tt.failRecorderConnect { err = errors.New("test") diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index 11c6bf14c..b005cfff6 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -42,6 +42,7 @@ import ( "tailscale.com/net/netknob" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/paths" @@ -769,7 +770,7 @@ func (p *Policy) Shutdown(ctx context.Context) error { // // The netMon parameter is optional. It should be specified in environments where // Tailscaled is manipulating the routing table. -func MakeDialFunc(netMon *netmon.Monitor, logf logger.Logf) func(ctx context.Context, netw, addr string) (net.Conn, error) { +func MakeDialFunc(netMon *netmon.Monitor, logf logger.Logf) netx.DialFunc { if netMon == nil { netMon = netmon.NewStatic() } diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index c00dea1ae..c7b9439e6 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -31,6 +31,7 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/neterror" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tsdial" "tailscale.com/types/dnstype" @@ -739,7 +740,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn return out, nil } -func (f *forwarder) getDialerType() dnscache.DialContextFunc { +func (f *forwarder) getDialerType() netx.DialFunc { if f.controlKnobs != nil && f.controlKnobs.UserDialUseRoutes.Load() { // It is safe to use UserDial as it dials external servers without going through Tailscale // and closes connections on interface change in the same way as SystemDial does, diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 2cbea6c0f..96550cbb1 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -19,6 +19,7 @@ import ( "time" "tailscale.com/envknob" + "tailscale.com/net/netx" "tailscale.com/types/logger" "tailscale.com/util/cloudenv" "tailscale.com/util/singleflight" @@ -355,10 +356,8 @@ func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Ad } } -type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) - // Dialer returns a wrapped DialContext func that uses the provided dnsCache. -func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { +func Dialer(fwd netx.DialFunc, dnsCache *Resolver) netx.DialFunc { d := &dialer{ fwd: fwd, dnsCache: dnsCache, @@ -369,7 +368,7 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { // dialer is the config and accumulated state for a dial func returned by Dialer. type dialer struct { - fwd DialContextFunc + fwd netx.DialFunc dnsCache *Resolver mu sync.Mutex @@ -653,7 +652,7 @@ func v6addrs(aa []netip.Addr) (ret []netip.Addr) { // TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext. // It returns a *tls.Conn type on success. // On TLS cert validation failure, it can invoke a backup DNS resolution strategy. -func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc { +func TLSDialer(fwd netx.DialFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) netx.DialFunc { tcpDialer := Dialer(fwd, dnsCache) return func(ctx context.Context, network, address string) (net.Conn, error) { host, _, err := net.SplitHostPort(address) diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index c8799bc17..7c2435684 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -6,3 +6,82 @@ // in tests and other situations where you don't want to use the // network. package memnet + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "tailscale.com/net/netx" +) + +var _ netx.Network = (*Network)(nil) + +// Network implements [Network] using an in-memory network, usually +// used for testing. +// +// As of 2025-04-08, it only supports TCP. +// +// Its zero value is a valid [netx.Network] implementation. +type Network struct { + mu sync.Mutex + lns map[string]*Listener // address -> listener +} + +func (m *Network) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network) + } + ap, err := netip.ParseAddrPort(address) + if err != nil { + return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.lns == nil { + m.lns = make(map[string]*Listener) + } + port := ap.Port() + for { + if port == 0 { + port = 33000 + } + key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port)) + _, ok := m.lns[key] + if ok { + if ap.Port() != 0 { + return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address) + } + port++ + continue + } + ln := Listen(key) + m.lns[key] = ln + return ln, nil + } +} + +func (m *Network) NewLocalTCPListener() net.Listener { + ln, err := m.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err)) + } + return ln +} + +func (m *Network) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network) + } + m.mu.Lock() + ln, ok := m.lns[address] + m.mu.Unlock() + if !ok { + return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address) + } + return ln.Dial(ctx, network, address) +} diff --git a/net/netx/netx.go b/net/netx/netx.go index 0be277a15..014daa9a7 100644 --- a/net/netx/netx.go +++ b/net/netx/netx.go @@ -1,23 +1,25 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package netx contains the Network type to abstract over either a real -// network or a virtual network for testing. +// Package netx contains types to describe and abstract over how dialing and +// listening are performed. package netx import ( "context" "fmt" "net" - "net/netip" - "sync" - - "tailscale.com/net/memnet" ) +// DialFunc is a function that dials a network address. +// +// It's the type implemented by net.Dialer.DialContext or required +// by net/http.Transport.DialContext, etc. +type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + // Network describes a network that can listen and dial. The two common // implementations are [RealNetwork], using the net package to use the real -// network, or [MemNetwork], using an in-memory network (typically for testing) +// network, or [memnet.Network], using an in-memory network (typically for testing) type Network interface { NewLocalTCPListener() net.Listener Listen(network, address string) (net.Listener, error) @@ -44,77 +46,8 @@ func (realNetwork) NewLocalTCPListener() net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + panic(fmt.Sprintf("failed to listen on either IPv4 or IPv6 localhost port: %v", err)) } } return ln } - -// MemNetwork returns a Network implementation that uses an in-memory -// network for testing. It is only suitable for tests that do not -// require real network access. -// -// As of 2025-04-08, it only supports TCP. -func MemNetwork() Network { return &memNetwork{} } - -// memNetwork implements [Network] using an in-memory network. -type memNetwork struct { - mu sync.Mutex - lns map[string]*memnet.Listener // address -> listener -} - -func (m *memNetwork) Listen(network, address string) (net.Listener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network) - } - ap, err := netip.ParseAddrPort(address) - if err != nil { - return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err) - } - - m.mu.Lock() - defer m.mu.Unlock() - - if m.lns == nil { - m.lns = make(map[string]*memnet.Listener) - } - port := ap.Port() - for { - if port == 0 { - port = 33000 - } - key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port)) - _, ok := m.lns[key] - if ok { - if ap.Port() != 0 { - return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address) - } - port++ - continue - } - ln := memnet.Listen(key) - m.lns[key] = ln - return ln, nil - } -} - -func (m *memNetwork) NewLocalTCPListener() net.Listener { - ln, err := m.Listen("tcp", "127.0.0.1:0") - if err != nil { - panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err)) - } - return ln -} - -func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network) - } - m.mu.Lock() - ln, ok := m.lns[address] - m.mu.Unlock() - if !ok { - return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address) - } - return ln.Dial(ctx, network, address) -} diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 8fddd63f2..1188a3077 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -23,6 +23,7 @@ import ( "tailscale.com/net/netknob" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -71,7 +72,7 @@ type Dialer struct { netnsDialerOnce sync.Once netnsDialer netns.Dialer - sysDialForTest func(_ context.Context, network, addr string) (net.Conn, error) // or nil + sysDialForTest netx.DialFunc // or nil routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface @@ -364,7 +365,7 @@ func (d *Dialer) logf(format string, args ...any) { // SetSystemDialerForTest sets an alternate function to use for SystemDial // instead of netns.Dialer. This is intended for use with nettest.MemoryNetwork. -func (d *Dialer) SetSystemDialerForTest(fn func(ctx context.Context, network, addr string) (net.Conn, error)) { +func (d *Dialer) SetSystemDialerForTest(fn netx.DialFunc) { testenv.AssertInTest() d.sysDialForTest = fn } diff --git a/sessionrecording/connect.go b/sessionrecording/connect.go index 94761393f..dc697d071 100644 --- a/sessionrecording/connect.go +++ b/sessionrecording/connect.go @@ -20,6 +20,7 @@ import ( "time" "golang.org/x/net/http2" + "tailscale.com/net/netx" "tailscale.com/tailcfg" "tailscale.com/util/httpm" "tailscale.com/util/multierr" @@ -40,9 +41,6 @@ const ( // in tests. var uploadAckWindow = 30 * time.Second -// DialFunc is a function for dialing the recorder. -type DialFunc func(ctx context.Context, network, host string) (net.Conn, error) - // ConnectToRecorder connects to the recorder at any of the provided addresses. // It returns the first successful response, or a multierr if all attempts fail. // @@ -55,7 +53,7 @@ type DialFunc func(ctx context.Context, network, host string) (net.Conn, error) // attempts are in order the recorder(s) was attempted. If successful a // successful connection is made, the last attempt in the slice is the // attempt for connected recorder. -func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { +func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial netx.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { if len(recs) == 0 { return nil, nil, nil, errors.New("no recorders configured") } @@ -293,7 +291,7 @@ func (u *readCounter) Read(buf []byte) (int, error) { // clientHTTP1 returns a claassic http.Client with a per-dial context. It uses // dialCtx and adds a 5s timeout to it. -func clientHTTP1(dialCtx context.Context, dial DialFunc) *http.Client { +func clientHTTP1(dialCtx context.Context, dial netx.DialFunc) *http.Client { tr := http.DefaultTransport.(*http.Transport).Clone() tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) @@ -313,7 +311,7 @@ func clientHTTP1(dialCtx context.Context, dial DialFunc) *http.Client { // clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c // requests (HTTP/2 over plaintext). Unfortunately the same client does not // work for HTTP/1 so we need to split these up. -func clientHTTP2(dialCtx context.Context, dial DialFunc) *http.Client { +func clientHTTP2(dialCtx context.Context, dial netx.DialFunc) *http.Client { return &http.Client{ Transport: &http2.Transport{ // Allow "http://" scheme in URLs. diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index e3ecf0f75..1fa170d87 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -54,6 +54,7 @@ import ( "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/stun" "tailscale.com/syncs" "tailscale.com/tailcfg" @@ -649,7 +650,7 @@ type Server struct { mu sync.Mutex agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all - agentDialer map[*node]DialFunc + agentDialer map[*node]netx.DialFunc } func (s *Server) logf(format string, args ...any) { @@ -664,8 +665,6 @@ func (s *Server) SetLoggerForTest(logf func(format string, args ...any)) { s.optLogf = logf } -type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) - var derpMap = &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { @@ -2130,7 +2129,7 @@ type NodeAgentClient struct { HTTPClient *http.Client } -func (s *Server) NodeAgentDialer(n *Node) DialFunc { +func (s *Server) NodeAgentDialer(n *Node) netx.DialFunc { s.mu.Lock() defer s.mu.Unlock() diff --git a/tstest/nettest/nettest.go b/tstest/nettest/nettest.go index 98662fe39..c78677dd4 100644 --- a/tstest/nettest/nettest.go +++ b/tstest/nettest/nettest.go @@ -14,6 +14,7 @@ import ( "sync" "testing" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" "tailscale.com/net/netx" "tailscale.com/util/testenv" @@ -42,7 +43,7 @@ func PreferMemNetwork() bool { func GetNetwork(tb testing.TB) netx.Network { var n netx.Network if PreferMemNetwork() { - n = netx.MemNetwork() + n = &memnet.Network{} } else { n = netx.RealNetwork() } diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 591bedde4..04bab0cf9 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -38,6 +38,7 @@ import ( "tailscale.com/net/dns" "tailscale.com/net/ipset" "tailscale.com/net/netaddr" + "tailscale.com/net/netx" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" @@ -208,7 +209,7 @@ type Impl struct { // TCP connection to another host (e.g. in subnet router mode). // // This is currently only used in tests. - forwardDialFunc func(context.Context, string, string) (net.Conn, error) + forwardDialFunc netx.DialFunc // forwardInFlightPerClientDropped is a metric that tracks how many // in-flight TCP forward requests were dropped due to the per-client @@ -1457,7 +1458,7 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. }() // Attempt to dial the outbound connection before we accept the inbound one. - var dialFunc func(context.Context, string, string) (net.Conn, error) + var dialFunc netx.DialFunc if ns.forwardDialFunc != nil { dialFunc = ns.forwardDialFunc } else { diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index 823acee91..79a380e84 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -22,6 +22,7 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/metrics" + "tailscale.com/net/netx" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" @@ -512,7 +513,7 @@ func tcp4syn(tb testing.TB, src, dst netip.Addr, sport, dport uint16) []byte { // makeHangDialer returns a dialer that notifies the returned channel when a // connection is dialed and then hangs until the test finishes. -func makeHangDialer(tb testing.TB) (func(context.Context, string, string) (net.Conn, error), chan struct{}) { +func makeHangDialer(tb testing.TB) (netx.DialFunc, chan struct{}) { done := make(chan struct{}) tb.Cleanup(func() { close(done)