control/controlhttp: start port 443 fallback sooner if 80's stuck

Fixes #4544

Change-Id: I39877e71915ad48c6668351c45cd8e33e2f5dbae
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4551/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 637cc1b5fc
commit e38d3dfc76

@ -30,6 +30,7 @@ import (
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/url" "net/url"
"time"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
@ -98,48 +99,98 @@ type dialParams struct {
} }
func (a *dialParams) dial() (*controlbase.Conn, error) { func (a *dialParams) dial() (*controlbase.Conn, error) {
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) // Create one shared context used by both port 80 and port 443 dials.
if err != nil { // If port 80 is still in flight when 443 returns, this deferred cancel
return nil, err // will stop the port 80 dial.
} ctx, cancel := context.WithCancel(a.ctx)
defer cancel()
u := &url.URL{ // u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS,
// respectively, in order to do the HTTP upgrade to a net.Conn over which
// we'll speak Noise.
u80 := &url.URL{
Scheme: "http", Scheme: "http",
Host: net.JoinHostPort(a.host, a.httpPort), Host: net.JoinHostPort(a.host, a.httpPort),
Path: serverUpgradePath, Path: serverUpgradePath,
} }
conn, httpErr := a.tryURL(u, init) u443 := &url.URL{
if httpErr == nil { Scheme: "https",
ret, err := cont(a.ctx, conn) Host: net.JoinHostPort(a.host, a.httpsPort),
if err != nil { Path: serverUpgradePath,
conn.Close()
return nil, err
} }
return ret, nil type tryURLRes struct {
u *url.URL
conn net.Conn
cont controlbase.HandshakeContinuation
err error
} }
ch := make(chan tryURLRes) // must be unbuffered
// Connecting over plain HTTP failed, assume it's an HTTP proxy try := func(u *url.URL) {
// being difficult and see if we can get through over HTTPS. res := tryURLRes{u: u}
u.Scheme = "https" var init []byte
u.Host = net.JoinHostPort(a.host, a.httpsPort) init, res.cont, res.err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) if res.err == nil {
if err != nil { res.conn, res.err = a.tryURL(ctx, u, init)
return nil, err }
select {
case ch <- res:
case <-ctx.Done():
if res.conn != nil {
res.conn.Close()
}
} }
conn, tlsErr := a.tryURL(u, init) }
if tlsErr == nil {
ret, err := cont(a.ctx, conn) // Start the plaintext HTTP attempt first.
go try(u80)
// In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
// to dial port 443 if port 80 doesn't either succeed or fail quickly.
try443Timer := time.AfterFunc(500*time.Millisecond, func() { try(u443) })
defer try443Timer.Stop()
var err80, err443 error
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err())
case res := <-ch:
if res.err == nil {
ret, err := res.cont(ctx, res.conn)
if err != nil { if err != nil {
conn.Close() res.conn.Close()
return nil, err return nil, err
} }
return ret, nil return ret, nil
} }
switch res.u {
return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", httpErr, tlsErr) case u80:
// Connecting over plain HTTP failed; assume it's an HTTP proxy
// being difficult and see if we can get through over HTTPS.
err80 = res.err
// Stop the fallback timer and run it immediately. We don't use
// Timer.Reset(0) here because on AfterFuncs, that can run it
// again.
if try443Timer.Stop() {
go try(u443)
} // else we lost the race and it started already which is what we want
case u443:
err443 = res.err
default:
panic("invalid")
}
if err80 != nil && err443 != nil {
return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", err80, err443)
}
}
}
} }
func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) { // tryURL connects to u, and tries to upgrade it to a net.Conn.
//
// Only the provided ctx is used, not a.ctx.
func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
dns := &dnscache.Resolver{ dns := &dnscache.Resolver{
Forward: dnscache.Get().Forward, Forward: dnscache.Get().Forward,
LookupIPFallback: dnsfallback.Lookup, LookupIPFallback: dnsfallback.Lookup,
@ -189,7 +240,7 @@ func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) {
connCh <- info.Conn connCh <- info.Conn
}, },
} }
ctx := httptrace.WithClientTrace(a.ctx, &trace) ctx = httptrace.WithClientTrace(ctx, &trace)
req := &http.Request{ req := &http.Request{
Method: "POST", Method: "POST",
URL: u, URL: u,

Loading…
Cancel
Save