control/controlhttp: start port 443 fallback sooner if 80's stuck

Fixes #4544

Change-Id: I39877e71915ad48c6668351c45cd8e33e2f5dbae
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4551/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 637cc1b5fc
commit e38d3dfc76

@ -30,6 +30,7 @@ import (
"net/http"
"net/http/httptrace"
"net/url"
"time"
"tailscale.com/control/controlbase"
"tailscale.com/net/dnscache"
@ -98,48 +99,98 @@ type dialParams struct {
}
func (a *dialParams) dial() (*controlbase.Conn, error) {
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
if err != nil {
return nil, err
}
// Create one shared context used by both port 80 and port 443 dials.
// If port 80 is still in flight when 443 returns, this deferred cancel
// will stop the port 80 dial.
ctx, cancel := context.WithCancel(a.ctx)
defer cancel()
u := &url.URL{
// u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS,
// respectively, in order to do the HTTP upgrade to a net.Conn over which
// we'll speak Noise.
u80 := &url.URL{
Scheme: "http",
Host: net.JoinHostPort(a.host, a.httpPort),
Path: serverUpgradePath,
}
conn, httpErr := a.tryURL(u, init)
if httpErr == nil {
ret, err := cont(a.ctx, conn)
if err != nil {
conn.Close()
return nil, err
u443 := &url.URL{
Scheme: "https",
Host: net.JoinHostPort(a.host, a.httpsPort),
Path: serverUpgradePath,
}
return ret, nil
type tryURLRes struct {
u *url.URL
conn net.Conn
cont controlbase.HandshakeContinuation
err error
}
ch := make(chan tryURLRes) // must be unbuffered
// Connecting over plain HTTP failed, assume it's an HTTP proxy
// being difficult and see if we can get through over HTTPS.
u.Scheme = "https"
u.Host = net.JoinHostPort(a.host, a.httpsPort)
init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
if err != nil {
return nil, err
try := func(u *url.URL) {
res := tryURLRes{u: u}
var init []byte
init, res.cont, res.err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
if res.err == nil {
res.conn, res.err = a.tryURL(ctx, u, init)
}
select {
case ch <- res:
case <-ctx.Done():
if res.conn != nil {
res.conn.Close()
}
}
conn, tlsErr := a.tryURL(u, init)
if tlsErr == nil {
ret, err := cont(a.ctx, conn)
}
// Start the plaintext HTTP attempt first.
go try(u80)
// In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
// to dial port 443 if port 80 doesn't either succeed or fail quickly.
try443Timer := time.AfterFunc(500*time.Millisecond, func() { try(u443) })
defer try443Timer.Stop()
var err80, err443 error
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err())
case res := <-ch:
if res.err == nil {
ret, err := res.cont(ctx, res.conn)
if err != nil {
conn.Close()
res.conn.Close()
return nil, err
}
return ret, nil
}
return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", httpErr, tlsErr)
switch res.u {
case u80:
// Connecting over plain HTTP failed; assume it's an HTTP proxy
// being difficult and see if we can get through over HTTPS.
err80 = res.err
// Stop the fallback timer and run it immediately. We don't use
// Timer.Reset(0) here because on AfterFuncs, that can run it
// again.
if try443Timer.Stop() {
go try(u443)
} // else we lost the race and it started already which is what we want
case u443:
err443 = res.err
default:
panic("invalid")
}
if err80 != nil && err443 != nil {
return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", err80, err443)
}
}
}
}
func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) {
// tryURL connects to u, and tries to upgrade it to a net.Conn.
//
// Only the provided ctx is used, not a.ctx.
func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
dns := &dnscache.Resolver{
Forward: dnscache.Get().Forward,
LookupIPFallback: dnsfallback.Lookup,
@ -189,7 +240,7 @@ func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) {
connCh <- info.Conn
},
}
ctx := httptrace.WithClientTrace(a.ctx, &trace)
ctx = httptrace.WithClientTrace(ctx, &trace)
req := &http.Request{
Method: "POST",
URL: u,

Loading…
Cancel
Save