From c2a7f17f2b378897f4545ad6f43891f150423487 Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Mon, 18 Nov 2024 09:55:54 -0800 Subject: [PATCH] sessionrecording: implement v2 recording endpoint support (#14105) The v2 endpoint supports HTTP/2 bidirectional streaming and acks for received bytes. This is used to detect when a recorder disappears to more quickly terminate the session. Updates https://github.com/tailscale/corp/issues/24023 Signed-off-by: Andrew Lytvynov --- k8s-operator/sessionrecording/hijacker.go | 2 +- .../sessionrecording/hijacker_test.go | 4 +- sessionrecording/connect.go | 320 ++++++++++++++---- sessionrecording/connect_test.go | 189 +++++++++++ ssh/tailssh/tailssh.go | 13 +- ssh/tailssh/tailssh_test.go | 61 ++-- 6 files changed, 500 insertions(+), 89 deletions(-) create mode 100644 sessionrecording/connect_test.go diff --git a/k8s-operator/sessionrecording/hijacker.go b/k8s-operator/sessionrecording/hijacker.go index f8ef951d4..43aa14e61 100644 --- a/k8s-operator/sessionrecording/hijacker.go +++ b/k8s-operator/sessionrecording/hijacker.go @@ -102,7 +102,7 @@ type Hijacker struct { // connection succeeds. In case of success, returns a list with a single // successful recording attempt and an error channel. If the connection errors // after having been established, an error is sent down the channel. -type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) +type RecorderDialFn func(context.Context, []netip.AddrPort, sessionrecording.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) // Hijack hijacks a 'kubectl exec' session and configures for the session // contents to be sent to a recorder. diff --git a/k8s-operator/sessionrecording/hijacker_test.go b/k8s-operator/sessionrecording/hijacker_test.go index 440d9c942..e166ce63b 100644 --- a/k8s-operator/sessionrecording/hijacker_test.go +++ b/k8s-operator/sessionrecording/hijacker_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/netip" "net/url" @@ -20,6 +19,7 @@ import ( "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" "tailscale.com/k8s-operator/sessionrecording/fakes" + "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstest" @@ -80,7 +80,7 @@ func Test_Hijacker(t *testing.T) { h := &Hijacker{ connectToRecorder: func(context.Context, []netip.AddrPort, - func(context.Context, string, string) (net.Conn, error), + sessionrecording.DialFunc, ) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) { if tt.failRecorderConnect { err = errors.New("test") diff --git a/sessionrecording/connect.go b/sessionrecording/connect.go index db966ba2c..94761393f 100644 --- a/sessionrecording/connect.go +++ b/sessionrecording/connect.go @@ -7,6 +7,8 @@ package sessionrecording import ( "context" + "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -14,12 +16,33 @@ import ( "net/http" "net/http/httptrace" "net/netip" + "sync/atomic" "time" + "golang.org/x/net/http2" "tailscale.com/tailcfg" + "tailscale.com/util/httpm" "tailscale.com/util/multierr" ) +const ( + // Timeout for an individual DialFunc call for a single recorder address. + perDialAttemptTimeout = 5 * time.Second + // Timeout for the V2 API HEAD probe request (supportsV2). + http2ProbeTimeout = 10 * time.Second + // Maximum timeout for trying all available recorders, including V2 API + // probes and dial attempts. + allDialAttemptsTimeout = 30 * time.Second +) + +// uploadAckWindow is the period of time to wait for an ackFrame from recorder +// before terminating the connection. This is a variable to allow overriding it +// in tests. +var uploadAckWindow = 30 * time.Second + +// DialFunc is a function for dialing the recorder. +type DialFunc func(ctx context.Context, network, host string) (net.Conn, error) + // ConnectToRecorder connects to the recorder at any of the provided addresses. // It returns the first successful response, or a multierr if all attempts fail. // @@ -32,19 +55,15 @@ import ( // attempts are in order the recorder(s) was attempted. If successful a // successful connection is made, the last attempt in the slice is the // attempt for connected recorder. -func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { +func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { if len(recs) == 0 { return nil, nil, nil, errors.New("no recorders configured") } // We use a special context for dialing the recorder, so that we can // limit the time we spend dialing to 30 seconds and still have an // unbounded context for the upload. - dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + dialCtx, dialCancel := context.WithTimeout(ctx, allDialAttemptsTimeout) defer dialCancel() - hc, err := SessionRecordingClientForDialer(dialCtx, dial) - if err != nil { - return nil, nil, nil, err - } var errs []error var attempts []*tailcfg.SSHRecordingAttempt @@ -54,74 +73,230 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con } attempts = append(attempts, attempt) - // We dial the recorder and wait for it to send a 100-continue - // response before returning from this function. This ensures that - // the recorder is ready to accept the recording. - - // got100 is closed when we receive the 100-continue response. - got100 := make(chan struct{}) - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - Got100Continue: func() { - close(got100) - }, - }) - - pr, pw := io.Pipe() - req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr) + var pw io.WriteCloser + var errChan <-chan error + var err error + hc := clientHTTP2(dialCtx, dial) + // We need to probe V2 support using a separate HEAD request. Sending + // an HTTP/2 POST request to a HTTP/1 server will just "hang" until the + // request body is closed (instead of returning a 404 as one would + // expect). Sending a HEAD request without a body does not have that + // problem. + if supportsV2(ctx, hc, ap) { + pw, errChan, err = connectV2(ctx, hc, ap) + } else { + pw, errChan, err = connectV1(ctx, clientHTTP1(dialCtx, dial), ap) + } if err != nil { - err = fmt.Errorf("recording: error starting recording: %w", err) + err = fmt.Errorf("recording: error starting recording on %q: %w", ap, err) attempt.FailureMessage = err.Error() errs = append(errs, err) continue } - // We set the Expect header to 100-continue, so that the recorder - // will send a 100-continue response before it starts reading the - // request body. - req.Header.Set("Expect", "100-continue") + return pw, attempts, errChan, nil + } + return nil, attempts, nil, multierr.New(errs...) +} - // errChan is used to indicate the result of the request. - errChan := make(chan error, 1) - go func() { - resp, err := hc.Do(req) - if err != nil { - errChan <- fmt.Errorf("recording: error starting recording: %w", err) +// supportsV2 checks whether a recorder instance supports the /v2/record +// endpoint. +func supportsV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) bool { + ctx, cancel := context.WithTimeout(ctx, http2ProbeTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, httpm.HEAD, fmt.Sprintf("http://%s/v2/record", ap), nil) + if err != nil { + return false + } + resp, err := hc.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == http.StatusOK && resp.ProtoMajor > 1 +} + +// connectV1 connects to the legacy /record endpoint on the recorder. It is +// used for backwards-compatibility with older tsrecorder instances. +// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func connectV1(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + // We dial the recorder and wait for it to send a 100-continue + // response before returning from this function. This ensures that + // the recorder is ready to accept the recording. + + // got100 is closed when we receive the 100-continue response. + got100 := make(chan struct{}) + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + Got100Continue: func() { + close(got100) + }, + }) + + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/record", ap), pr) + if err != nil { + return nil, nil, err + } + // We set the Expect header to 100-continue, so that the recorder + // will send a 100-continue response before it starts reading the + // request body. + req.Header.Set("Expect", "100-continue") + + // errChan is used to indicate the result of the request. + errChan := make(chan error, 1) + go func() { + defer close(errChan) + resp, err := hc.Do(req) + if err != nil { + errChan <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + return + } + }() + select { + case <-got100: + return pw, errChan, nil + case err := <-errChan: + // If we get an error before we get the 100-continue response, + // we need to try another recorder. + if err == nil { + // If the error is nil, we got a 200 response, which + // is unexpected as we haven't sent any data yet. + err = errors.New("recording: unexpected EOF") + } + return nil, nil, err + } +} + +// connectV2 connects to the /v2/record endpoint on the recorder over HTTP/2. +// It explicitly tracks ack frames sent in the response and terminates the +// connection if sent recording data is un-acked for uploadAckWindow. +// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func connectV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + pr, pw := io.Pipe() + upload := &readCounter{r: pr} + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/record", ap), upload) + if err != nil { + return nil, nil, err + } + + // With HTTP/2, hc.Do will not block while the request body is being sent. + // It will return immediately and allow us to consume the response body at + // the same time. + resp, err := hc.Do(req) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, nil, fmt.Errorf("recording: unexpected status: %v", resp.Status) + } + + errChan := make(chan error, 1) + acks := make(chan int64) + // Read acks from the response and send them to the acks channel. + go func() { + defer close(errChan) + defer close(acks) + defer resp.Body.Close() + defer pw.Close() + dec := json.NewDecoder(resp.Body) + for { + var frame v2ResponseFrame + if err := dec.Decode(&frame); err != nil { + if !errors.Is(err, io.EOF) { + errChan <- fmt.Errorf("recording: unexpected error receiving acks: %w", err) + } return } - if resp.StatusCode != 200 { - errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + if frame.Error != "" { + errChan <- fmt.Errorf("recording: received error from the recorder: %q", frame.Error) return } - errChan <- nil - }() - select { - case <-got100: - case err := <-errChan: - // If we get an error before we get the 100-continue response, - // we need to try another recorder. - if err == nil { - // If the error is nil, we got a 200 response, which - // is unexpected as we haven't sent any data yet. - err = errors.New("recording: unexpected EOF") + select { + case acks <- frame.Ack: + case <-ctx.Done(): + return } - attempt.FailureMessage = err.Error() - errs = append(errs, err) - continue // try the next recorder } - return pw, attempts, errChan, nil - } - return nil, attempts, nil, multierr.New(errs...) + }() + // Track acks from the acks channel. + go func() { + // Hack for tests: some tests modify uploadAckWindow and reset it when + // the test ends. This can race with t.Reset call below. Making a copy + // here is a lazy workaround to not wait for this goroutine to exit in + // the test cases. + uploadAckWindow := uploadAckWindow + // This timer fires if we didn't receive an ack for too long. + t := time.NewTimer(uploadAckWindow) + defer t.Stop() + for { + select { + case <-t.C: + // Close the pipe which terminates the connection and cleans up + // other goroutines. Note that tsrecorder will send us ack + // frames even if there is no new data to ack. This helps + // detect broken recorder connection if the session is idle. + pr.CloseWithError(errNoAcks) + resp.Body.Close() + return + case _, ok := <-acks: + if !ok { + // acks channel closed means that the goroutine reading them + // finished, which means that the request has ended. + return + } + // TODO(awly): limit how far behind the received acks can be. This + // should handle scenarios where a session suddenly dumps a lot of + // output. + t.Reset(uploadAckWindow) + case <-ctx.Done(): + return + } + } + }() + + return pw, errChan, nil } -// SessionRecordingClientForDialer returns an http.Client that uses a clone of -// the provided Dialer's PeerTransport to dial connections. This is used to make -// requests to the session recording server to upload session recordings. It -// uses the provided dialCtx to dial connections, and limits a single dial to 5 -// seconds. -func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) { - tr := http.DefaultTransport.(*http.Transport).Clone() +var errNoAcks = errors.New("did not receive ack frames from the recorder in 30s") + +type v2ResponseFrame struct { + // Ack is the number of bytes received from the client so far. The bytes + // are not guaranteed to be durably stored yet. + Ack int64 `json:"ack,omitempty"` + // Error is an error encountered while storing the recording. Error is only + // ever set as the last frame in the response. + Error string `json:"error,omitempty"` +} +// readCounter is an io.Reader that counts how many bytes were read. +type readCounter struct { + r io.Reader + sent atomic.Int64 +} + +func (u *readCounter) Read(buf []byte) (int, error) { + n, err := u.r.Read(buf) + u.sent.Add(int64(n)) + return n, err +} + +// clientHTTP1 returns a claassic http.Client with a per-dial context. It uses +// dialCtx and adds a 5s timeout to it. +func clientHTTP1(dialCtx context.Context, dial DialFunc) *http.Client { + tr := http.DefaultTransport.(*http.Transport).Clone() tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) defer cancel() go func() { select { @@ -132,7 +307,32 @@ func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context. }() return dial(perAttemptCtx, network, addr) } + return &http.Client{Transport: tr} +} + +// clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c +// requests (HTTP/2 over plaintext). Unfortunately the same client does not +// work for HTTP/1 so we need to split these up. +func clientHTTP2(dialCtx context.Context, dial DialFunc) *http.Client { return &http.Client{ - Transport: tr, - }, nil + Transport: &http2.Transport{ + // Allow "http://" scheme in URLs. + AllowHTTP: true, + // Pretend like we're using TLS, but actually use the provided + // DialFunc underneath. This is necessary to convince the transport + // to actually dial. + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) + defer cancel() + go func() { + select { + case <-perAttemptCtx.Done(): + case <-dialCtx.Done(): + cancel() + } + }() + return dial(perAttemptCtx, network, addr) + }, + }, + } } diff --git a/sessionrecording/connect_test.go b/sessionrecording/connect_test.go new file mode 100644 index 000000000..c0fcf6d40 --- /dev/null +++ b/sessionrecording/connect_test.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sessionrecording + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func TestConnectToRecorder(t *testing.T) { + tests := []struct { + desc string + http2 bool + // setup returns a recorder server mux, and a channel which sends the + // hash of the recording uploaded to it. The channel is expected to + // fire only once. + setup func(t *testing.T) (*http.ServeMux, <-chan []byte) + wantErr bool + }{ + { + desc: "v1 recorder", + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + uploadHash <- hash.Sum(nil) + }) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder", + http2: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + body := &readCounter{r: r.Body} + hash := sha256.New() + ctx, cancel := context.WithCancel(r.Context()) + go func() { + defer cancel() + if _, err := io.Copy(hash, body); err != nil { + t.Error(err) + } + }() + + // Send acks for received bytes. + tick := time.NewTicker(time.Millisecond) + defer tick.Stop() + enc := json.NewEncoder(w) + outer: + for { + select { + case <-ctx.Done(): + break outer + case <-tick.C: + if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil { + t.Errorf("writing ack frame: %v", err) + break outer + } + } + } + + uploadHash <- hash.Sum(nil) + }) + // Probing HEAD endpoint which always returns 200 OK. + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder no acks", + http2: true, + wantErr: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + // Make the client no-ack timeout quick for the test. + oldAckWindow := uploadAckWindow + uploadAckWindow = 100 * time.Millisecond + t.Cleanup(func() { uploadAckWindow = oldAckWindow }) + + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + // Consume the whole request body but don't send any acks + // back. + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + // Goes in the channel buffer, non-blocking. + uploadHash <- hash.Sum(nil) + + // Block until the parent test case ends to prevent the + // request termination. We want to exercise the ack + // tracking logic specifically. + ctx, cancel := context.WithCancel(r.Context()) + t.Cleanup(cancel) + <-ctx.Done() + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + mux, uploadHash := tt.setup(t) + + srv := httptest.NewUnstartedServer(mux) + if tt.http2 { + // Wire up h2c-compatible HTTP/2 server. This is optional + // because the v1 recorder didn't support HTTP/2 and we try to + // mimic that. + h2s := &http2.Server{} + srv.Config.Handler = h2c.NewHandler(mux, h2s) + if err := http2.ConfigureServer(srv.Config, h2s); err != nil { + t.Errorf("configuring HTTP/2 support in server: %v", err) + } + } + srv.Start() + t.Cleanup(srv.Close) + + d := new(net.Dialer) + + ctx := context.Background() + w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(srv.Listener.Addr().String())}, d.DialContext) + if err != nil { + t.Fatalf("ConnectToRecorder: %v", err) + } + + // Send some random data and hash it to compare with the recorded + // data hash. + hash := sha256.New() + const numBytes = 1 << 20 // 1MB + if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil { + t.Fatalf("writing recording data: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("closing recording stream: %v", err) + } + if err := <-errc; err != nil && !tt.wantErr { + t.Fatalf("error from the channel: %v", err) + } else if err == nil && tt.wantErr { + t.Fatalf("did not receive expected error from the channel") + } + + if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) { + t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv) + } + }) + } +} diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 9ade1847e..7cb99c381 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -1170,7 +1170,7 @@ func (ss *sshSession) run() { if err != nil && !errors.Is(err, io.EOF) { isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) if !isErrBecauseProcessExited { - logf("stdout copy: %v, %T", err) + logf("stdout copy: %v", err) ss.cancelCtx(err) } } @@ -1520,9 +1520,14 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { go func() { err := <-errChan if err == nil { - // Success. - ss.logf("recording: finished uploading recording") - return + select { + case <-ss.ctx.Done(): + // Success. + ss.logf("recording: finished uploading recording") + return + default: + err = errors.New("recording upload ended before the SSH session") + } } if onFailure != nil && onFailure.NotifyURL != "" && len(attempts) > 0 { lastAttempt := attempts[len(attempts)-1] diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 7ce0aeea3..ad9cb1e57 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -33,6 +33,8 @@ import ( "time" gossh "github.com/tailscale/golang-x-crypto/ssh" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/memnet" @@ -481,10 +483,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { } var handler http.HandlerFunc - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { handler(w, r) - })) - defer recordingServer.Close() + }) s := &server{ logf: t.Logf, @@ -533,9 +534,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { { name: "upload-fails-after-starting", handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() r.Body.Read(make([]byte, 1)) time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusInternalServerError) }, sshCommand: "echo hello && sleep 1 && echo world", wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n", @@ -548,6 +550,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + s.logf = t.Logf tstest.Replace(t, &handler, tt.handler) sc, dc := memnet.NewTCPConn(src, dst, 1024) var wg sync.WaitGroup @@ -597,12 +600,12 @@ func TestMultipleRecorders(t *testing.T) { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } done := make(chan struct{}) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer close(done) - io.ReadAll(r.Body) w.WriteHeader(http.StatusOK) - })) - defer recordingServer.Close() + w.(http.Flusher).Flush() + io.ReadAll(r.Body) + }) badRecorder, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) @@ -610,15 +613,9 @@ func TestMultipleRecorders(t *testing.T) { badRecorderAddr := badRecorder.Addr().String() badRecorder.Close() - badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) - })) - defer badRecordingServer500.Close() - - badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - })) - defer badRecordingServer200.Close() + badRecordingServer500 := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) s := &server{ logf: t.Logf, @@ -630,7 +627,6 @@ func TestMultipleRecorders(t *testing.T) { Recorders: []netip.AddrPort{ netip.MustParseAddrPort(badRecorderAddr), netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()), - netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()), netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), }, OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{ @@ -701,19 +697,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) { } var recording []byte ctx, cancel := context.WithTimeout(context.Background(), time.Second) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer cancel() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + var err error recording, err = io.ReadAll(r.Body) if err != nil { t.Error(err) return } - })) - defer recordingServer.Close() + }) s := &server{ - logf: logger.Discard, + logf: t.Logf, lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -1299,3 +1297,22 @@ func TestStdOsUserUserAssumptions(t *testing.T) { t.Errorf("os/user.User has %v fields; this package assumes %v", got, want) } } + +func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(http.ResponseWriter, *http.Request) { + t.Errorf("v1 recording endpoint called") + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + mux.HandleFunc("POST /v2/record", handleRecord) + + h2s := &http2.Server{} + srv := httptest.NewUnstartedServer(h2c.NewHandler(mux, h2s)) + if err := http2.ConfigureServer(srv.Config, h2s); err != nil { + t.Errorf("configuring HTTP/2 support in recording server: %v", err) + } + srv.Start() + t.Cleanup(srv.Close) + return srv +}