diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index 63b933803..0d4518efe 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -489,7 +489,15 @@ func runTS2021(ctx context.Context, args []string) error { return c, err } - conn, err := controlhttp.Dial(ctx, ts2021Args.host, "80", "443", machinePrivate, keys.PublicKey, uint16(ts2021Args.version), dialFunc) + conn, err := (&controlhttp.Dialer{ + Hostname: ts2021Args.host, + HTTPPort: "80", + HTTPSPort: "443", + MachineKey: machinePrivate, + ControlKey: keys.PublicKey, + ProtocolVersion: uint16(ts2021Args.version), + Dialer: dialFunc, + }).Dial(ctx) log.Printf("controlhttp.Dial = %p, %v", conn, err) if err != nil { return err diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index f5b6b12a3..29f2b025c 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -165,7 +165,15 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { // thousand version numbers before getting to this point. panic("capability version is too high to fit in the wire protocol") } - conn, err := controlhttp.Dial(ctx, nc.host, nc.httpPort, nc.httpsPort, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion), nc.dialer.SystemDial) + conn, err := (&controlhttp.Dialer{ + Hostname: nc.host, + HTTPPort: nc.httpPort, + HTTPSPort: nc.httpsPort, + MachineKey: nc.privKey, + ControlKey: nc.serverPubKey, + ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion), + Dialer: nc.dialer.SystemDial, + }).Dial(ctx) if err != nil { return nil, err } diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 7717e931d..3f987a692 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -40,57 +40,49 @@ import ( "tailscale.com/net/netutil" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" - "tailscale.com/types/key" ) -// Dial connects to the HTTP server at host:httpPort, requests to switch to the -// Tailscale control protocol, and returns an established control +var stdDialer net.Dialer + +// Dial connects to the HTTP server at this Dialer's Host:HTTPPort, requests to +// switch to the Tailscale control protocol, and returns an established control // protocol connection. // -// If Dial fails to connect using addr, it also tries to tunnel over -// TLS to host:httpsPort as a compatibility fallback. +// If Dial fails to connect using HTTP, it also tries to tunnel over TLS to the +// Dialer's Host:HTTPSPort as a compatibility fallback. // // The provided ctx is only used for the initial connection, until // Dial returns. It does not affect the connection once established. -func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) { - a := &dialParams{ - host: host, - httpPort: httpPort, - httpsPort: httpsPort, - machineKey: machineKey, - controlKey: controlKey, - version: protocolVersion, - proxyFunc: tshttpproxy.ProxyFromEnvironment, - dialer: dialer, +func (a *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) { + if a.Hostname == "" { + return nil, errors.New("required Dialer.Hostname empty") } return a.dial(ctx) } -type dialParams struct { - host string - httpPort string - httpsPort string - machineKey key.MachinePrivate - controlKey key.MachinePublic - version uint16 - proxyFunc func(*http.Request) (*url.URL, error) // or nil - dialer dnscache.DialContextFunc +func (a *Dialer) logf(format string, args ...any) { + if a.Logf != nil { + a.Logf(format, args...) + } +} - // For tests only - insecureTLS bool - testFallbackDelay time.Duration +func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) { + if a.proxyFunc != nil { + return a.proxyFunc + } + return tshttpproxy.ProxyFromEnvironment } -// httpsFallbackDelay is how long we'll wait for a.httpPort to work before -// starting to try a.httpsPort. -func (a *dialParams) httpsFallbackDelay() time.Duration { +// httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before +// starting to try a.HTTPSPort. +func (a *Dialer) httpsFallbackDelay() time.Duration { if v := a.testFallbackDelay; v != 0 { return v } return 500 * time.Millisecond } -func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) { +func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) { // Create one shared context used by both port 80 and port 443 dials. // If port 80 is still in flight when 443 returns, this deferred cancel // will stop the port 80 dial. @@ -102,12 +94,12 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) { // we'll speak Noise. u80 := &url.URL{ Scheme: "http", - Host: net.JoinHostPort(a.host, a.httpPort), + Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPPort, "80")), Path: serverUpgradePath, } u443 := &url.URL{ Scheme: "https", - Host: net.JoinHostPort(a.host, a.httpsPort), + Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")), Path: serverUpgradePath, } @@ -169,8 +161,8 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) { } // dialURL attempts to connect to the given URL. -func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) { - init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) +func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) { + init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion) if err != nil { return nil, err } @@ -189,26 +181,34 @@ func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. // // Only the provided ctx is used, not a.ctx. -func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) { +func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) { dns := &dnscache.Resolver{ Forward: dnscache.Get().Forward, LookupIPFallback: dnsfallback.Lookup, UseLastGood: true, } + + var dialer dnscache.DialContextFunc + if a.Dialer != nil { + dialer = a.Dialer + } else { + dialer = stdDialer.DialContext + } + tr := http.DefaultTransport.(*http.Transport).Clone() defer tr.CloseIdleConnections() - tr.Proxy = a.proxyFunc + tr.Proxy = a.getProxyFunc() tshttpproxy.SetTransportGetProxyConnectHeader(tr) - tr.DialContext = dnscache.Dialer(a.dialer, dns) + tr.DialContext = dnscache.Dialer(dialer, dns) // Disable HTTP2, since h2 can't do protocol switching. tr.TLSClientConfig.NextProtos = []string{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} - tr.TLSClientConfig = tlsdial.Config(a.host, tr.TLSClientConfig) + tr.TLSClientConfig = tlsdial.Config(a.Hostname, tr.TLSClientConfig) if a.insecureTLS { tr.TLSClientConfig.InsecureSkipVerify = true tr.TLSClientConfig.VerifyConnection = nil } - tr.DialTLSContext = dnscache.TLSDialer(a.dialer, dns, tr.TLSClientConfig) + tr.DialTLSContext = dnscache.TLSDialer(dialer, dns, tr.TLSClientConfig) tr.DisableCompression = true // (mis)use httptrace to extract the underlying net.Conn from the diff --git a/control/controlhttp/client_js.go b/control/controlhttp/client_js.go index 850fd4de9..3b1ad4151 100644 --- a/control/controlhttp/client_js.go +++ b/control/controlhttp/client_js.go @@ -7,27 +7,31 @@ package controlhttp import ( "context" "encoding/base64" + "errors" "net" "net/url" "nhooyr.io/websocket" "tailscale.com/control/controlbase" - "tailscale.com/net/dnscache" - "tailscale.com/types/key" ) // Variant of Dial that tunnels the request over WebSockets, since we cannot do // bi-directional communication over an HTTP connection when in JS. -func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) { - init, cont, err := controlbase.ClientDeferred(machineKey, controlKey, protocolVersion) +func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) { + if d.Hostname == "" { + return nil, errors.New("required Dialer.Hostname empty") + } + + init, cont, err := controlbase.ClientDeferred(d.MachineKey, d.ControlKey, d.ProtocolVersion) if err != nil { return nil, err } wsScheme := "wss" + host := d.Hostname if host == "localhost" { wsScheme = "ws" - host = net.JoinHostPort(host, httpPort) + host = net.JoinHostPort(host, strDef(d.HTTPPort, "80")) } wsURL := &url.URL{ Scheme: wsScheme, @@ -52,5 +56,4 @@ func Dial(ctx context.Context, host string, httpPort string, httpsPort string, m return nil, err } return cbConn, nil - } diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index 216adc269..b254e5be6 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -4,6 +4,16 @@ package controlhttp +import ( + "net/http" + "net/url" + "time" + + "tailscale.com/net/dnscache" + "tailscale.com/types/key" + "tailscale.com/types/logger" +) + const ( // upgradeHeader is the value of the Upgrade HTTP header used to // indicate the Tailscale control protocol. @@ -18,3 +28,58 @@ const ( // to do the protocol switch is located. serverUpgradePath = "/ts2021" ) + +// Dialer contains configuration on how to dial the Tailscale control server. +type Dialer struct { + // Hostname is the hostname to connect to, with no port number. + // + // This field is required. + Hostname string + + // MachineKey contains the current machine's private key. + // + // This field is required. + MachineKey key.MachinePrivate + + // ControlKey contains the expected public key for the control server. + // + // This field is required. + ControlKey key.MachinePublic + + // ProtocolVersion is the expected protocol version to negotiate. + // + // This field is required. + ProtocolVersion uint16 + + // HTTPPort is the port number to use when making a HTTP connection. + // + // If not specified, this defaults to port 80. + HTTPPort string + + // HTTPSPort is the port number to use when making a HTTPS connection. + // + // If not specified, this defaults to port 443. + HTTPSPort string + + // Dialer is the dialer used to make outbound connections. + // + // If not specified, this defaults to net.Dialer.DialContext. + Dialer dnscache.DialContextFunc + + // Logf, if set, is a logging function to use; if unset, logs are + // dropped. + Logf logger.Logf + + proxyFunc func(*http.Request) (*url.URL, error) // or nil + + // For tests only + insecureTLS bool + testFallbackDelay time.Duration +} + +func strDef(v1, v2 string) string { + if v1 != "" { + return v1 + } + return v2 +} diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 545b4c303..b4f10eabf 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -170,15 +170,16 @@ func testControlHTTP(t *testing.T, param httpTestParam) { defer cancel() } - a := dialParams{ - host: "localhost", - httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), - httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), - machineKey: client, - controlKey: server.Public(), - version: testProtocolVersion, + a := &Dialer{ + Hostname: "localhost", + HTTPPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), + HTTPSPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), + MachineKey: client, + ControlKey: server.Public(), + ProtocolVersion: testProtocolVersion, + Dialer: new(tsdial.Dialer).SystemDial, + Logf: t.Logf, insecureTLS: true, - dialer: new(tsdial.Dialer).SystemDial, testFallbackDelay: 50 * time.Millisecond, }