control/controlhttp: move Dial options into options struct (#5661)

This turns 'dialParams' into something more like net.Dialer, where
configuration fields are public on the struct.

Split out of #5648

Change-Id: I0c56fd151dc5489c3c94fb40d18fd639e06473bc
Signed-off-by: Andrew Dunham <andrew@tailscale.com>
pull/5667/head
Andrew Dunham 2 years ago committed by GitHub
parent 5623ef0271
commit 9b71008ef2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -489,7 +489,15 @@ func runTS2021(ctx context.Context, args []string) error {
return c, err return c, err
} }
conn, err := controlhttp.Dial(ctx, ts2021Args.host, "80", "443", machinePrivate, keys.PublicKey, uint16(ts2021Args.version), dialFunc) conn, err := (&controlhttp.Dialer{
Hostname: ts2021Args.host,
HTTPPort: "80",
HTTPSPort: "443",
MachineKey: machinePrivate,
ControlKey: keys.PublicKey,
ProtocolVersion: uint16(ts2021Args.version),
Dialer: dialFunc,
}).Dial(ctx)
log.Printf("controlhttp.Dial = %p, %v", conn, err) log.Printf("controlhttp.Dial = %p, %v", conn, err)
if err != nil { if err != nil {
return err return err

@ -165,7 +165,15 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
// thousand version numbers before getting to this point. // thousand version numbers before getting to this point.
panic("capability version is too high to fit in the wire protocol") panic("capability version is too high to fit in the wire protocol")
} }
conn, err := controlhttp.Dial(ctx, nc.host, nc.httpPort, nc.httpsPort, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion), nc.dialer.SystemDial) conn, err := (&controlhttp.Dialer{
Hostname: nc.host,
HTTPPort: nc.httpPort,
HTTPSPort: nc.httpsPort,
MachineKey: nc.privKey,
ControlKey: nc.serverPubKey,
ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion),
Dialer: nc.dialer.SystemDial,
}).Dial(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -40,57 +40,49 @@ import (
"tailscale.com/net/netutil" "tailscale.com/net/netutil"
"tailscale.com/net/tlsdial" "tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy" "tailscale.com/net/tshttpproxy"
"tailscale.com/types/key"
) )
// Dial connects to the HTTP server at host:httpPort, requests to switch to the var stdDialer net.Dialer
// Tailscale control protocol, and returns an established control
// Dial connects to the HTTP server at this Dialer's Host:HTTPPort, requests to
// switch to the Tailscale control protocol, and returns an established control
// protocol connection. // protocol connection.
// //
// If Dial fails to connect using addr, it also tries to tunnel over // If Dial fails to connect using HTTP, it also tries to tunnel over TLS to the
// TLS to host:httpsPort as a compatibility fallback. // Dialer's Host:HTTPSPort as a compatibility fallback.
// //
// The provided ctx is only used for the initial connection, until // The provided ctx is only used for the initial connection, until
// Dial returns. It does not affect the connection once established. // Dial returns. It does not affect the connection once established.
func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) { func (a *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
a := &dialParams{ if a.Hostname == "" {
host: host, return nil, errors.New("required Dialer.Hostname empty")
httpPort: httpPort,
httpsPort: httpsPort,
machineKey: machineKey,
controlKey: controlKey,
version: protocolVersion,
proxyFunc: tshttpproxy.ProxyFromEnvironment,
dialer: dialer,
} }
return a.dial(ctx) return a.dial(ctx)
} }
type dialParams struct { func (a *Dialer) logf(format string, args ...any) {
host string if a.Logf != nil {
httpPort string a.Logf(format, args...)
httpsPort string }
machineKey key.MachinePrivate }
controlKey key.MachinePublic
version uint16
proxyFunc func(*http.Request) (*url.URL, error) // or nil
dialer dnscache.DialContextFunc
// For tests only func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) {
insecureTLS bool if a.proxyFunc != nil {
testFallbackDelay time.Duration return a.proxyFunc
}
return tshttpproxy.ProxyFromEnvironment
} }
// httpsFallbackDelay is how long we'll wait for a.httpPort to work before // httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before
// starting to try a.httpsPort. // starting to try a.HTTPSPort.
func (a *dialParams) httpsFallbackDelay() time.Duration { func (a *Dialer) httpsFallbackDelay() time.Duration {
if v := a.testFallbackDelay; v != 0 { if v := a.testFallbackDelay; v != 0 {
return v return v
} }
return 500 * time.Millisecond return 500 * time.Millisecond
} }
func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) { func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
// Create one shared context used by both port 80 and port 443 dials. // Create one shared context used by both port 80 and port 443 dials.
// If port 80 is still in flight when 443 returns, this deferred cancel // If port 80 is still in flight when 443 returns, this deferred cancel
// will stop the port 80 dial. // will stop the port 80 dial.
@ -102,12 +94,12 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
// we'll speak Noise. // we'll speak Noise.
u80 := &url.URL{ u80 := &url.URL{
Scheme: "http", Scheme: "http",
Host: net.JoinHostPort(a.host, a.httpPort), Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPPort, "80")),
Path: serverUpgradePath, Path: serverUpgradePath,
} }
u443 := &url.URL{ u443 := &url.URL{
Scheme: "https", Scheme: "https",
Host: net.JoinHostPort(a.host, a.httpsPort), Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")),
Path: serverUpgradePath, Path: serverUpgradePath,
} }
@ -169,8 +161,8 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
} }
// dialURL attempts to connect to the given URL. // dialURL attempts to connect to the given URL.
func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) { func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -189,26 +181,34 @@ func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
// //
// Only the provided ctx is used, not a.ctx. // Only the provided ctx is used, not a.ctx.
func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) { func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
dns := &dnscache.Resolver{ dns := &dnscache.Resolver{
Forward: dnscache.Get().Forward, Forward: dnscache.Get().Forward,
LookupIPFallback: dnsfallback.Lookup, LookupIPFallback: dnsfallback.Lookup,
UseLastGood: true, UseLastGood: true,
} }
var dialer dnscache.DialContextFunc
if a.Dialer != nil {
dialer = a.Dialer
} else {
dialer = stdDialer.DialContext
}
tr := http.DefaultTransport.(*http.Transport).Clone() tr := http.DefaultTransport.(*http.Transport).Clone()
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
tr.Proxy = a.proxyFunc tr.Proxy = a.getProxyFunc()
tshttpproxy.SetTransportGetProxyConnectHeader(tr) tshttpproxy.SetTransportGetProxyConnectHeader(tr)
tr.DialContext = dnscache.Dialer(a.dialer, dns) tr.DialContext = dnscache.Dialer(dialer, dns)
// Disable HTTP2, since h2 can't do protocol switching. // Disable HTTP2, since h2 can't do protocol switching.
tr.TLSClientConfig.NextProtos = []string{} tr.TLSClientConfig.NextProtos = []string{}
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
tr.TLSClientConfig = tlsdial.Config(a.host, tr.TLSClientConfig) tr.TLSClientConfig = tlsdial.Config(a.Hostname, tr.TLSClientConfig)
if a.insecureTLS { if a.insecureTLS {
tr.TLSClientConfig.InsecureSkipVerify = true tr.TLSClientConfig.InsecureSkipVerify = true
tr.TLSClientConfig.VerifyConnection = nil tr.TLSClientConfig.VerifyConnection = nil
} }
tr.DialTLSContext = dnscache.TLSDialer(a.dialer, dns, tr.TLSClientConfig) tr.DialTLSContext = dnscache.TLSDialer(dialer, dns, tr.TLSClientConfig)
tr.DisableCompression = true tr.DisableCompression = true
// (mis)use httptrace to extract the underlying net.Conn from the // (mis)use httptrace to extract the underlying net.Conn from the

@ -7,27 +7,31 @@ package controlhttp
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"errors"
"net" "net"
"net/url" "net/url"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/net/dnscache"
"tailscale.com/types/key"
) )
// Variant of Dial that tunnels the request over WebSockets, since we cannot do // Variant of Dial that tunnels the request over WebSockets, since we cannot do
// bi-directional communication over an HTTP connection when in JS. // bi-directional communication over an HTTP connection when in JS.
func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) { func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
init, cont, err := controlbase.ClientDeferred(machineKey, controlKey, protocolVersion) if d.Hostname == "" {
return nil, errors.New("required Dialer.Hostname empty")
}
init, cont, err := controlbase.ClientDeferred(d.MachineKey, d.ControlKey, d.ProtocolVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
wsScheme := "wss" wsScheme := "wss"
host := d.Hostname
if host == "localhost" { if host == "localhost" {
wsScheme = "ws" wsScheme = "ws"
host = net.JoinHostPort(host, httpPort) host = net.JoinHostPort(host, strDef(d.HTTPPort, "80"))
} }
wsURL := &url.URL{ wsURL := &url.URL{
Scheme: wsScheme, Scheme: wsScheme,
@ -52,5 +56,4 @@ func Dial(ctx context.Context, host string, httpPort string, httpsPort string, m
return nil, err return nil, err
} }
return cbConn, nil return cbConn, nil
} }

@ -4,6 +4,16 @@
package controlhttp package controlhttp
import (
"net/http"
"net/url"
"time"
"tailscale.com/net/dnscache"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
const ( const (
// upgradeHeader is the value of the Upgrade HTTP header used to // upgradeHeader is the value of the Upgrade HTTP header used to
// indicate the Tailscale control protocol. // indicate the Tailscale control protocol.
@ -18,3 +28,58 @@ const (
// to do the protocol switch is located. // to do the protocol switch is located.
serverUpgradePath = "/ts2021" serverUpgradePath = "/ts2021"
) )
// Dialer contains configuration on how to dial the Tailscale control server.
type Dialer struct {
// Hostname is the hostname to connect to, with no port number.
//
// This field is required.
Hostname string
// MachineKey contains the current machine's private key.
//
// This field is required.
MachineKey key.MachinePrivate
// ControlKey contains the expected public key for the control server.
//
// This field is required.
ControlKey key.MachinePublic
// ProtocolVersion is the expected protocol version to negotiate.
//
// This field is required.
ProtocolVersion uint16
// HTTPPort is the port number to use when making a HTTP connection.
//
// If not specified, this defaults to port 80.
HTTPPort string
// HTTPSPort is the port number to use when making a HTTPS connection.
//
// If not specified, this defaults to port 443.
HTTPSPort string
// Dialer is the dialer used to make outbound connections.
//
// If not specified, this defaults to net.Dialer.DialContext.
Dialer dnscache.DialContextFunc
// Logf, if set, is a logging function to use; if unset, logs are
// dropped.
Logf logger.Logf
proxyFunc func(*http.Request) (*url.URL, error) // or nil
// For tests only
insecureTLS bool
testFallbackDelay time.Duration
}
func strDef(v1, v2 string) string {
if v1 != "" {
return v1
}
return v2
}

@ -170,15 +170,16 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
defer cancel() defer cancel()
} }
a := dialParams{ a := &Dialer{
host: "localhost", Hostname: "localhost",
httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), HTTPPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), HTTPSPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
machineKey: client, MachineKey: client,
controlKey: server.Public(), ControlKey: server.Public(),
version: testProtocolVersion, ProtocolVersion: testProtocolVersion,
Dialer: new(tsdial.Dialer).SystemDial,
Logf: t.Logf,
insecureTLS: true, insecureTLS: true,
dialer: new(tsdial.Dialer).SystemDial,
testFallbackDelay: 50 * time.Millisecond, testFallbackDelay: 50 * time.Millisecond,
} }

Loading…
Cancel
Save