control/controlhttp: don't assume port 80 upgrade response will work

Just because we get an HTTP upgrade response over port 80, don't
assume we'll be able to do bi-di Noise over it. There might be a MITM
corp proxy or anti-virus/firewall interfering. Do a bit more work to
validate the connection before proceeding to give up on the TLS port
443 dial.

Updates #4557 (probably fixes)

Change-Id: I0e1bcc195af21ad3d360ffe79daead730dfd86f1
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4561/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent 488e63979e
commit 1237000efe

@ -70,7 +70,6 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
return nil, err
}
a := &dialParams{
ctx: ctx,
host: host,
httpPort: port,
httpsPort: "443",
@ -80,11 +79,10 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
proxyFunc: tshttpproxy.ProxyFromEnvironment,
dialer: dialer,
}
return a.dial()
return a.dial(ctx)
}
type dialParams struct {
ctx context.Context
host string
httpPort string
httpsPort string
@ -95,14 +93,24 @@ type dialParams struct {
dialer dnscache.DialContextFunc
// For tests only
insecureTLS bool
insecureTLS bool
testFallbackDelay time.Duration
}
func (a *dialParams) dial() (*controlbase.Conn, error) {
// httpsFallbackDelay is how long we'll wait for a.httpPort to work before
// starting to try a.httpsPort.
func (a *dialParams) httpsFallbackDelay() time.Duration {
if v := a.testFallbackDelay; v != 0 {
return v
}
return 500 * time.Millisecond
}
func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
// Create one shared context used by both port 80 and port 443 dials.
// If port 80 is still in flight when 443 returns, this deferred cancel
// will stop the port 80 dial.
ctx, cancel := context.WithCancel(a.ctx)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS,
@ -118,26 +126,20 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
Host: net.JoinHostPort(a.host, a.httpsPort),
Path: serverUpgradePath,
}
type tryURLRes struct {
u *url.URL
conn net.Conn
cont controlbase.HandshakeContinuation
u *url.URL // input (the URL conn+err are for/from)
conn *controlbase.Conn // result (mutually exclusive with err)
err error
}
ch := make(chan tryURLRes) // must be unbuffered
try := func(u *url.URL) {
res := tryURLRes{u: u}
var init []byte
init, res.cont, res.err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
if res.err == nil {
res.conn, res.err = a.tryURL(ctx, u, init)
}
cbConn, err := a.dialURL(ctx, u)
select {
case ch <- res:
case ch <- tryURLRes{u, cbConn, err}:
case <-ctx.Done():
if res.conn != nil {
res.conn.Close()
if cbConn != nil {
cbConn.Close()
}
}
}
@ -147,7 +149,7 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
// In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
// to dial port 443 if port 80 doesn't either succeed or fail quickly.
try443Timer := time.AfterFunc(500*time.Millisecond, func() { try(u443) })
try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) })
defer try443Timer.Stop()
var err80, err443 error
@ -157,12 +159,7 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err())
case res := <-ch:
if res.err == nil {
ret, err := res.cont(ctx, res.conn)
if err != nil {
res.conn.Close()
return nil, err
}
return ret, nil
return res.conn, nil
}
switch res.u {
case u80:
@ -187,10 +184,28 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
}
}
// tryURL connects to u, and tries to upgrade it to a net.Conn.
// dialURL attempts to connect to the given URL.
func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
if err != nil {
return nil, err
}
netConn, err := a.tryURLUpgrade(ctx, u, init)
if err != nil {
return nil, err
}
cbConn, err := cont(ctx, netConn)
if err != nil {
netConn.Close()
return nil, err
}
return cbConn, nil
}
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
//
// Only the provided ctx is used, not a.ctx.
func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
dns := &dnscache.Resolver{
Forward: dnscache.Get().Forward,
LookupIPFallback: dnsfallback.Lookup,

@ -17,6 +17,7 @@ import (
"strconv"
"sync"
"testing"
"time"
"tailscale.com/control/controlbase"
"tailscale.com/net/socks5"
@ -24,16 +25,28 @@ import (
"tailscale.com/types/key"
)
type httpTestParam struct {
name string
proxy proxy
// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
// 101 switching protocols.
makeHTTPHangAfterUpgrade bool
}
func TestControlHTTP(t *testing.T) {
tests := []struct {
name string
proxy proxy
}{
tests := []httpTestParam{
// direct connection
{
name: "no_proxy",
proxy: nil,
},
// direct connection but port 80 is MITM'ed and broken
{
name: "port80_broken_mitm",
proxy: nil,
makeHTTPHangAfterUpgrade: true,
},
// SOCKS5
{
name: "socks5",
@ -97,12 +110,13 @@ func TestControlHTTP(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
testControlHTTP(t, test.proxy)
testControlHTTP(t, test)
})
}
}
func testControlHTTP(t *testing.T, proxy proxy) {
func testControlHTTP(t *testing.T, param httpTestParam) {
proxy := param.proxy
client, server := key.NewMachine(), key.NewMachine()
const testProtocolVersion = 1
@ -133,7 +147,11 @@ func testControlHTTP(t *testing.T, proxy proxy) {
t.Fatalf("HTTPS listen: %v", err)
}
httpServer := &http.Server{Handler: handler}
var httpHandler http.Handler = handler
if param.makeHTTPHangAfterUpgrade {
httpHandler = http.HandlerFunc(brokenMITMHandler)
}
httpServer := &http.Server{Handler: httpHandler}
go httpServer.Serve(httpLn)
defer httpServer.Close()
@ -144,19 +162,24 @@ func testControlHTTP(t *testing.T, proxy proxy) {
go httpsServer.ServeTLS(httpsLn, "", "")
defer httpsServer.Close()
//ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//defer cancel()
ctx := context.Background()
const debugTimeout = false
if debugTimeout {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
}
a := dialParams{
ctx: context.Background(), //ctx,
host: "localhost",
httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
machineKey: client,
controlKey: server.Public(),
version: testProtocolVersion,
insecureTLS: true,
dialer: new(tsdial.Dialer).SystemDial,
host: "localhost",
httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
machineKey: client,
controlKey: server.Public(),
version: testProtocolVersion,
insecureTLS: true,
dialer: new(tsdial.Dialer).SystemDial,
testFallbackDelay: 50 * time.Millisecond,
}
if proxy != nil {
@ -175,7 +198,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
}
}
conn, err := a.dial()
conn, err := a.dial(ctx)
if err != nil {
t.Fatalf("dialing controlhttp: %v", err)
}
@ -217,6 +240,7 @@ type proxy interface {
type socksProxy struct {
sync.Mutex
closed bool
proxy socks5.Server
ln net.Listener
clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
@ -232,7 +256,14 @@ func (s *socksProxy) Start(t *testing.T) (url string) {
}
s.ln = ln
s.clientConnAddrs = map[string]bool{}
s.proxy.Logf = t.Logf
s.proxy.Logf = func(format string, a ...any) {
s.Lock()
defer s.Unlock()
if s.closed {
return
}
t.Logf(format, a...)
}
s.proxy.Dialer = s.dialAndRecord
go s.proxy.Serve(ln)
return fmt.Sprintf("socks5://%s", ln.Addr().String())
@ -241,6 +272,10 @@ func (s *socksProxy) Start(t *testing.T) (url string) {
func (s *socksProxy) Close() {
s.Lock()
defer s.Unlock()
if s.closed {
return
}
s.closed = true
s.ln.Close()
}
@ -400,3 +435,11 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
Certificates: []tls.Certificate{cert},
}
}
func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Upgrade", upgradeHeaderValue)
w.Header().Set("Connection", "upgrade")
w.WriteHeader(http.StatusSwitchingProtocols)
w.(http.Flusher).Flush()
<-r.Context().Done()
}

Loading…
Cancel
Save