// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package sessionrecording import ( "bytes" "context" "crypto/rand" "crypto/sha256" "encoding/json" "io" "net" "net/http" "net/http/httptest" "net/netip" "testing" "time" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) func TestConnectToRecorder(t *testing.T) { tests := []struct { desc string http2 bool // setup returns a recorder server mux, and a channel which sends the // hash of the recording uploaded to it. The channel is expected to // fire only once. setup func(t *testing.T) (*http.ServeMux, <-chan []byte) wantErr bool }{ { desc: "v1 recorder", setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { uploadHash := make(chan []byte, 1) mux := http.NewServeMux() mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { hash := sha256.New() if _, err := io.Copy(hash, r.Body); err != nil { t.Error(err) } uploadHash <- hash.Sum(nil) }) return mux, uploadHash }, }, { desc: "v2 recorder", http2: true, setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { uploadHash := make(chan []byte, 1) mux := http.NewServeMux() mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { t.Error("received request to v1 endpoint") http.Error(w, "not found", http.StatusNotFound) }) mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { // Force the status to send to unblock the client waiting // for it. w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() body := &readCounter{r: r.Body} hash := sha256.New() ctx, cancel := context.WithCancel(r.Context()) go func() { defer cancel() if _, err := io.Copy(hash, body); err != nil { t.Error(err) } }() // Send acks for received bytes. tick := time.NewTicker(time.Millisecond) defer tick.Stop() enc := json.NewEncoder(w) outer: for { select { case <-ctx.Done(): break outer case <-tick.C: if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil { t.Errorf("writing ack frame: %v", err) break outer } } } uploadHash <- hash.Sum(nil) }) // Probing HEAD endpoint which always returns 200 OK. mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) return mux, uploadHash }, }, { desc: "v2 recorder no acks", http2: true, wantErr: true, setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { // Make the client no-ack timeout quick for the test. oldAckWindow := uploadAckWindow uploadAckWindow = 100 * time.Millisecond t.Cleanup(func() { uploadAckWindow = oldAckWindow }) uploadHash := make(chan []byte, 1) mux := http.NewServeMux() mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { t.Error("received request to v1 endpoint") http.Error(w, "not found", http.StatusNotFound) }) mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { // Force the status to send to unblock the client waiting // for it. w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() // Consume the whole request body but don't send any acks // back. hash := sha256.New() if _, err := io.Copy(hash, r.Body); err != nil { t.Error(err) } // Goes in the channel buffer, non-blocking. uploadHash <- hash.Sum(nil) // Block until the parent test case ends to prevent the // request termination. We want to exercise the ack // tracking logic specifically. ctx, cancel := context.WithCancel(r.Context()) t.Cleanup(cancel) <-ctx.Done() }) mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) return mux, uploadHash }, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { mux, uploadHash := tt.setup(t) srv := httptest.NewUnstartedServer(mux) if tt.http2 { // Wire up h2c-compatible HTTP/2 server. This is optional // because the v1 recorder didn't support HTTP/2 and we try to // mimic that. h2s := &http2.Server{} srv.Config.Handler = h2c.NewHandler(mux, h2s) if err := http2.ConfigureServer(srv.Config, h2s); err != nil { t.Errorf("configuring HTTP/2 support in server: %v", err) } } srv.Start() t.Cleanup(srv.Close) d := new(net.Dialer) ctx := context.Background() w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(srv.Listener.Addr().String())}, d.DialContext) if err != nil { t.Fatalf("ConnectToRecorder: %v", err) } // Send some random data and hash it to compare with the recorded // data hash. hash := sha256.New() const numBytes = 1 << 20 // 1MB if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil { t.Fatalf("writing recording data: %v", err) } if err := w.Close(); err != nil { t.Fatalf("closing recording stream: %v", err) } if err := <-errc; err != nil && !tt.wantErr { t.Fatalf("error from the channel: %v", err) } else if err == nil && tt.wantErr { t.Fatalf("did not receive expected error from the channel") } if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) { t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv) } }) } }