control/controlhttp: extract the last network connection

The same context we use for the HTTP request here might be re-used by
the dialer, which could result in `GotConn` being called multiple times.
We only care about the last one.

Fixes #13009

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
pull/13051/head
Anton Tolchanov 4 months ago committed by Anton Tolchanov
parent b3fc345aba
commit 7bac5dffcb

@ -46,6 +46,7 @@ import (
"tailscale.com/net/sockstats" "tailscale.com/net/sockstats"
"tailscale.com/net/tlsdial" "tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy" "tailscale.com/net/tshttpproxy"
"tailscale.com/syncs"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime" "tailscale.com/tstime"
"tailscale.com/util/multierr" "tailscale.com/util/multierr"
@ -497,11 +498,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
tr.DisableCompression = true tr.DisableCompression = true
// (mis)use httptrace to extract the underlying net.Conn from the // (mis)use httptrace to extract the underlying net.Conn from the
// transport. We make exactly 1 request using this transport, so // transport. The transport handles 101 Switching Protocols correctly,
// there will be exactly 1 GotConn call. Additionally, the // such that the Conn will not be reused or kept alive by the transport
// transport handles 101 Switching Protocols correctly, such that // once the response has been handed back from RoundTrip.
// the Conn will not be reused or kept alive by the transport once
// the response has been handed back from RoundTrip.
// //
// In theory, the machinery of net/http should make it such that // In theory, the machinery of net/http should make it such that
// the trace callback happens-before we get the response, but // the trace callback happens-before we get the response, but
@ -517,10 +516,16 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
// unexpected EOFs...), and we're bound to forget someday and // unexpected EOFs...), and we're bound to forget someday and
// introduce a protocol optimization at a higher level that starts // introduce a protocol optimization at a higher level that starts
// eagerly transmitting from the server. // eagerly transmitting from the server.
connCh := make(chan net.Conn, 1) var lastConn syncs.AtomicValue[net.Conn]
trace := httptrace.ClientTrace{ trace := httptrace.ClientTrace{
// Even though we only make a single HTTP request which should
// require a single connection, the context (with the attached
// trace configuration) might be used by our custom dialer to
// make other HTTP requests (e.g. BootstrapDNS). We only care
// about the last connection made, which should be the one to
// the control server.
GotConn: func(info httptrace.GotConnInfo) { GotConn: func(info httptrace.GotConnInfo) {
connCh <- info.Conn lastConn.Store(info.Conn)
}, },
} }
ctx = httptrace.WithClientTrace(ctx, &trace) ctx = httptrace.WithClientTrace(ctx, &trace)
@ -548,11 +553,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
// is still a read buffer attached to it within resp.Body. So, we // is still a read buffer attached to it within resp.Body. So, we
// must direct I/O through resp.Body, but we can still use the // must direct I/O through resp.Body, but we can still use the
// underlying net.Conn for stuff like deadlines. // underlying net.Conn for stuff like deadlines.
var switchedConn net.Conn switchedConn := lastConn.Load()
select {
case switchedConn = <-connCh:
default:
}
if switchedConn == nil { if switchedConn == nil {
resp.Body.Close() resp.Body.Close()
return nil, fmt.Errorf("httptrace didn't provide a connection") return nil, fmt.Errorf("httptrace didn't provide a connection")

@ -11,10 +11,12 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"net/http/httputil" "net/http/httputil"
"net/netip" "net/netip"
"net/url" "net/url"
"runtime" "runtime"
"slices"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
@ -41,6 +43,8 @@ type httpTestParam struct {
makeHTTPHangAfterUpgrade bool makeHTTPHangAfterUpgrade bool
doEarlyWrite bool doEarlyWrite bool
httpInDial bool
} }
func TestControlHTTP(t *testing.T) { func TestControlHTTP(t *testing.T) {
@ -120,6 +124,12 @@ func TestControlHTTP(t *testing.T) {
name: "early_write", name: "early_write",
doEarlyWrite: true, doEarlyWrite: true,
}, },
// Dialer needed to make another HTTP request along the way (e.g. to
// resolve the hostname via BootstrapDNS).
{
name: "http_request_in_dial",
httpInDial: true,
},
} }
for _, test := range tests { for _, test := range tests {
@ -217,6 +227,29 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
Clock: clock, Clock: clock,
} }
if param.httpInDial {
// Spin up a separate server to get a different port on localhost.
secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return }))
defer secondServer.Close()
prev := a.Dialer
a.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", secondServer.URL, nil)
if err != nil {
t.Errorf("http.NewRequest: %v", err)
}
r, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("http.Get: %v", err)
}
r.Body.Close()
return prev(ctx, network, addr)
}
}
if proxy != nil { if proxy != nil {
proxyEnv := proxy.Start(t) proxyEnv := proxy.Start(t)
defer proxy.Close() defer proxy.Close()
@ -238,6 +271,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
t.Fatalf("dialing controlhttp: %v", err) t.Fatalf("dialing controlhttp: %v", err)
} }
defer conn.Close() defer conn.Close()
si := <-sch si := <-sch
if si.conn != nil { if si.conn != nil {
defer si.conn.Close() defer si.conn.Close()
@ -266,6 +300,19 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
t.Errorf("early write = %q; want %q", buf, earlyWriteMsg) t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
} }
} }
// When no proxy is used, the RemoteAddr of the returned connection should match
// one of the listeners of the test server.
if proxy == nil {
var expectedAddrs []string
for _, ln := range []net.Listener{httpLn, httpsLn} {
expectedAddrs = append(expectedAddrs, fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port))
expectedAddrs = append(expectedAddrs, fmt.Sprintf("[::1]:%d", ln.Addr().(*net.TCPAddr).Port))
}
if !slices.Contains(expectedAddrs, conn.RemoteAddr().String()) {
t.Errorf("unexpected remote addr: %s, want %s", conn.RemoteAddr(), expectedAddrs)
}
}
} }
type serverResult struct { type serverResult struct {

Loading…
Cancel
Save