From 732605f961f32f0f428a8fb355fa5359d6f33ed2 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Wed, 19 Jun 2024 18:23:01 -0400 Subject: [PATCH] control/controlclient: move noiseConn to internal package So that it can be later used in the 'tailscale debug ts2021' function in the CLI, to aid in debugging captive portals/WAFs/etc. Updates #1634 Signed-off-by: Andrew Dunham Change-Id: Iec9423f5e7570f2c2c8218d27fc0902137e73909 --- cmd/tailscale/cli/debug.go | 27 ++++++ control/controlclient/noise.go | 151 +++-------------------------- internal/noiseconn/conn.go | 170 +++++++++++++++++++++++++++++++++ 3 files changed, 212 insertions(+), 136 deletions(-) create mode 100644 internal/noiseconn/conn.go diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index 6793e94cc..5cfc4aa34 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -834,6 +834,33 @@ func runTS2021(ctx context.Context, args []string) error { } log.Printf("final underlying conn: %v / %v", conn.LocalAddr(), conn.RemoteAddr()) + + // Make a /whois request to the server to verify that we can actually + // communicate over the newly-established connection. + whoisURL := "http://" + ts2021Args.host + "/machine/whois" + req, err = http.NewRequestWithContext(ctx, "GET", whoisURL, nil) + if err != nil { + return err + } + + // Use a fake http.Transport that just "dials" by returning the above + // conn. + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.ForceAttemptHTTP2 = true + tr.DialContext = func(context.Context, string, string) (net.Conn, error) { + return conn, nil + } + resp, err := tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip whois request: %w", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("reading whois response: %w", err) + } + + log.Printf("whois response: %q", body) return nil } diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 5a1c25e96..fdda96743 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -6,10 +6,8 @@ package controlclient import ( "bytes" "context" - "encoding/binary" "encoding/json" "errors" - "io" "math" "net/http" "net/url" @@ -17,9 +15,9 @@ import ( "time" "golang.org/x/net/http2" - "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp" "tailscale.com/health" + "tailscale.com/internal/noiseconn" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" @@ -32,113 +30,6 @@ import ( "tailscale.com/util/singleflight" ) -// noiseConn is a wrapper around controlbase.Conn. -// It allows attaching an ID to a connection to allow -// cleaning up references in the pool when the connection -// is closed. -type noiseConn struct { - *controlbase.Conn - id int - pool *NoiseClient - h2cc *http2.ClientConn - - readHeaderOnce sync.Once // guards init of reader field - reader io.Reader // (effectively Conn.Reader after header) - earlyPayloadReady chan struct{} // closed after earlyPayload is set (including set to nil) - earlyPayload *tailcfg.EarlyNoise - earlyPayloadErr error -} - -func (c *noiseConn) RoundTrip(r *http.Request) (*http.Response, error) { - return c.h2cc.RoundTrip(r) -} - -// getEarlyPayload waits for the early noise payload to arrive. -// It may return (nil, nil) if the server begins HTTP/2 without one. -func (c *noiseConn) getEarlyPayload(ctx context.Context) (*tailcfg.EarlyNoise, error) { - select { - case <-c.earlyPayloadReady: - return c.earlyPayload, c.earlyPayloadErr - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// The first 9 bytes from the server to client over Noise are either an HTTP/2 -// settings frame (a normal HTTP/2 setup) or, as we added later, an "early payload" -// header that's also 9 bytes long: 5 bytes (earlyPayloadMagic) followed by 4 bytes -// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise. -// The early payload is optional. Some servers may not send it. -const ( - hdrLen = 9 // http2 frame header size; also size of our early payload size header - earlyPayloadMagic = "\xff\xff\xffTS" -) - -// returnErrReader is an io.Reader that always returns an error. -type returnErrReader struct { - err error // the error to return -} - -func (r returnErrReader) Read([]byte) (int, error) { return 0, r.err } - -// Read is basically the same as controlbase.Conn.Read, but it first reads the -// "early payload" header from the server which may or may not be present, -// depending on the server. -func (c *noiseConn) Read(p []byte) (n int, err error) { - c.readHeaderOnce.Do(c.readHeader) - return c.reader.Read(p) -} - -// readHeader reads the optional "early payload" from the server that arrives -// after the Noise handshake but before the HTTP/2 session begins. -// -// readHeader is responsible for reading the header (if present), initializing -// c.earlyPayload, closing c.earlyPayloadReady, and initializing c.reader for -// future reads. -func (c *noiseConn) readHeader() { - defer close(c.earlyPayloadReady) - - setErr := func(err error) { - c.reader = returnErrReader{err} - c.earlyPayloadErr = err - } - - var hdr [hdrLen]byte - if _, err := io.ReadFull(c.Conn, hdr[:]); err != nil { - setErr(err) - return - } - if string(hdr[:len(earlyPayloadMagic)]) != earlyPayloadMagic { - // No early payload. We have to return the 9 bytes read we already - // consumed. - c.reader = io.MultiReader(bytes.NewReader(hdr[:]), c.Conn) - return - } - epLen := binary.BigEndian.Uint32(hdr[len(earlyPayloadMagic):]) - if epLen > 10<<20 { - setErr(errors.New("invalid early payload length")) - return - } - payBuf := make([]byte, epLen) - if _, err := io.ReadFull(c.Conn, payBuf); err != nil { - setErr(err) - return - } - if err := json.Unmarshal(payBuf, &c.earlyPayload); err != nil { - setErr(err) - return - } - c.reader = c.Conn -} - -func (c *noiseConn) Close() error { - if err := c.Conn.Close(); err != nil { - return err - } - c.pool.connClosed(c.id) - return nil -} - // NoiseClient provides a http.Client to connect to tailcontrol over // the ts2021 protocol. type NoiseClient struct { @@ -158,7 +49,7 @@ type NoiseClient struct { // sfDial ensures that two concurrent requests for a noise connection only // produce one shared one between the two callers. - sfDial singleflight.Group[struct{}, *noiseConn] + sfDial singleflight.Group[struct{}, *noiseconn.Conn] dialer *tsdial.Dialer dnsCache *dnscache.Resolver @@ -180,9 +71,9 @@ type NoiseClient struct { // mu only protects the following variables. mu sync.Mutex closed bool - last *noiseConn // or nil + last *noiseconn.Conn // or nil nextID int - connPool map[int]*noiseConn // active connections not yet closed; see noiseConn.Close + connPool map[int]*noiseconn.Conn // active connections not yet closed; see noiseconn.Conn.Close } // NoiseOpts contains options for the NewNoiseClient function. All fields are @@ -283,12 +174,12 @@ func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.Round if err != nil { return nil, nil, err } - earlyPayloadMaybeNil, err := conn.getEarlyPayload(ctx) + rt, earlyPayloadMaybeNil, err := conn.ReserveNewRequest(ctx) if err != nil { return nil, nil, err } - if conn.h2cc.ReserveNewRequest() { - return conn, earlyPayloadMaybeNil, nil + if rt != nil { + return rt, earlyPayloadMaybeNil, nil } } return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection") @@ -308,14 +199,14 @@ func (e contextErr) Unwrap() error { return e.err } -// getConn returns a noiseConn that can be used to make requests to the +// getConn returns a noiseconn.Conn that can be used to make requests to the // coordination server. It may return a cached connection or create a new one. // Dials are singleflighted, so concurrent calls to getConn may only dial once. // As such, context values may not be respected as there are no guarantees that // the context passed to getConn is the same as the context passed to dial. -func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { +func (nc *NoiseClient) getConn(ctx context.Context) (*noiseconn.Conn, error) { nc.mu.Lock() - if last := nc.last; last != nil && last.canTakeNewRequest() { + if last := nc.last; last != nil && last.CanTakeNewRequest() { nc.mu.Unlock() return last, nil } @@ -327,7 +218,7 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { // canceled. Instead, we have to additionally check that the context // which was canceled is our context and retry if our context is still // valid. - conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseConn, error) { + conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseconn.Conn, error) { c, err := nc.dial(ctx) if err != nil { if ctx.Err() != nil { @@ -395,7 +286,7 @@ func (nc *NoiseClient) Close() error { // dial opens a new connection to tailcontrol, fetching the server noise key // if not cached. -func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) { +func (nc *NoiseClient) dial(ctx context.Context) (*noiseconn.Conn, error) { nc.mu.Lock() connID := nc.nextID nc.nextID++ @@ -465,18 +356,10 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) { return nil, err } - ncc := &noiseConn{ - Conn: clientConn.Conn, - id: connID, - pool: nc, - earlyPayloadReady: make(chan struct{}), - } - - h2cc, err := nc.h2t.NewClientConn(ncc) + ncc, err := noiseconn.New(clientConn.Conn, nc.h2t, connID, nc.connClosed) if err != nil { return nil, err } - ncc.h2cc = h2cc nc.mu.Lock() if nc.closed { @@ -485,7 +368,7 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) { return nil, errors.New("noise client closed") } defer nc.mu.Unlock() - mak.Set(&nc.connPool, ncc.id, ncc) + mak.Set(&nc.connPool, connID, ncc) nc.last = ncc return ncc, nil } @@ -508,9 +391,5 @@ func (nc *NoiseClient) post(ctx context.Context, path string, nodeKey key.NodePu if err != nil { return nil, err } - return conn.h2cc.RoundTrip(req) -} - -func (c *noiseConn) canTakeNewRequest() bool { - return c.h2cc.CanTakeNewRequest() + return conn.RoundTrip(req) } diff --git a/internal/noiseconn/conn.go b/internal/noiseconn/conn.go new file mode 100644 index 000000000..d826dfb47 --- /dev/null +++ b/internal/noiseconn/conn.go @@ -0,0 +1,170 @@ +package noiseconn + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/net/http2" + "tailscale.com/control/controlbase" + "tailscale.com/tailcfg" +) + +// Conn is a wrapper around controlbase.Conn. +// It allows attaching an ID to a connection to allow +// cleaning up references in the pool when the connection +// is closed. +type Conn struct { + *controlbase.Conn + id int + onClose func(int) + h2cc *http2.ClientConn + + readHeaderOnce sync.Once // guards init of reader field + reader io.Reader // (effectively Conn.Reader after header) + earlyPayloadReady chan struct{} // closed after earlyPayload is set (including set to nil) + earlyPayload *tailcfg.EarlyNoise + earlyPayloadErr error +} + +// New creates a new Conn that wraps the given controlbase.Conn. +// +// h2t is the HTTP/2 transport to use for the connection; a new +// http2.ClientConn will be created that reads from the returned Conn. +// +// connID should be a unique ID for this connection. When the Conn is closed, +// the onClose function will be called with the connID if it is non-nil. +func New(conn *controlbase.Conn, h2t *http2.Transport, connID int, onClose func(int)) (*Conn, error) { + ncc := &Conn{ + Conn: conn, + id: connID, + onClose: onClose, + earlyPayloadReady: make(chan struct{}), + } + h2cc, err := h2t.NewClientConn(ncc) + if err != nil { + return nil, err + } + ncc.h2cc = h2cc + return ncc, nil +} + +// RoundTrip implements the http.RoundTripper interface. +func (c *Conn) RoundTrip(r *http.Request) (*http.Response, error) { + return c.h2cc.RoundTrip(r) +} + +// getEarlyPayload waits for the early noise payload to arrive. +// It may return (nil, nil) if the server begins HTTP/2 without one. +func (c *Conn) getEarlyPayload(ctx context.Context) (*tailcfg.EarlyNoise, error) { + select { + case <-c.earlyPayloadReady: + return c.earlyPayload, c.earlyPayloadErr + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// ReserveNewRequest will reserve a new concurrent request on the connection. +// It returns a non-nil http.RoundTripper if the reservation was successful, +// and any early Noise payload if present. If a reservation was not successful, +// it will return nil with no error. +func (c *Conn) ReserveNewRequest(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) { + earlyPayloadMaybeNil, err := c.getEarlyPayload(ctx) + if err != nil { + return nil, nil, err + } + if c.h2cc.ReserveNewRequest() { + return c, earlyPayloadMaybeNil, nil + } + return nil, nil, nil +} + +// The first 9 bytes from the server to client over Noise are either an HTTP/2 +// settings frame (a normal HTTP/2 setup) or, as we added later, an "early payload" +// header that's also 9 bytes long: 5 bytes (earlyPayloadMagic) followed by 4 bytes +// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise. +// The early payload is optional. Some servers may not send it. +const ( + hdrLen = 9 // http2 frame header size; also size of our early payload size header + earlyPayloadMagic = "\xff\xff\xffTS" +) + +// returnErrReader is an io.Reader that always returns an error. +type returnErrReader struct { + err error // the error to return +} + +func (r returnErrReader) Read([]byte) (int, error) { return 0, r.err } + +// Read is basically the same as controlbase.Conn.Read, but it first reads the +// "early payload" header from the server which may or may not be present, +// depending on the server. +func (c *Conn) Read(p []byte) (n int, err error) { + c.readHeaderOnce.Do(c.readHeader) + return c.reader.Read(p) +} + +// readHeader reads the optional "early payload" from the server that arrives +// after the Noise handshake but before the HTTP/2 session begins. +// +// readHeader is responsible for reading the header (if present), initializing +// c.earlyPayload, closing c.earlyPayloadReady, and initializing c.reader for +// future reads. +func (c *Conn) readHeader() { + defer close(c.earlyPayloadReady) + + setErr := func(err error) { + c.reader = returnErrReader{err} + c.earlyPayloadErr = err + } + + var hdr [hdrLen]byte + if _, err := io.ReadFull(c.Conn, hdr[:]); err != nil { + setErr(err) + return + } + if string(hdr[:len(earlyPayloadMagic)]) != earlyPayloadMagic { + // No early payload. We have to return the 9 bytes read we already + // consumed. + c.reader = io.MultiReader(bytes.NewReader(hdr[:]), c.Conn) + return + } + epLen := binary.BigEndian.Uint32(hdr[len(earlyPayloadMagic):]) + if epLen > 10<<20 { + setErr(errors.New("invalid early payload length")) + return + } + payBuf := make([]byte, epLen) + if _, err := io.ReadFull(c.Conn, payBuf); err != nil { + setErr(err) + return + } + if err := json.Unmarshal(payBuf, &c.earlyPayload); err != nil { + setErr(err) + return + } + c.reader = c.Conn +} + +// Close closes the connection. +func (c *Conn) Close() error { + if err := c.Conn.Close(); err != nil { + return err + } + if c.onClose != nil { + c.onClose(c.id) + } + return nil +} + +// CanTakeNewRequest reports whether the connection can take a new request, +// meaning it has not been closed or received or sent a GOAWAY. +func (c *Conn) CanTakeNewRequest() bool { + return c.h2cc.CanTakeNewRequest() +}