diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index a9f62a88b..7aa4403e7 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -1077,6 +1077,13 @@ func (ss *sshSession) run() { err := ss.launchProcess() if err != nil { logf("start failed: %v", err.Error()) + if errors.Is(err, context.Canceled) { + err := context.Cause(ss.ctx) + uve := userVisibleError{} + if errors.As(err, &uve) { + fmt.Fprintf(ss, "%s\r\n", uve) + } + } ss.Exit(1) return } @@ -1425,20 +1432,35 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { pw.Close() return nil, err } + // We want to wait for the server to respond with 100 Continue to notifiy us + // that it's ready to receive data. We do this to block the session from + // starting until the server is ready to receive data. + // It also allows the server to reject the request before we start sending + // data. + req.Header.Set("Expect", "100-continue") go func() { defer pw.Close() ss.logf("starting asciinema recording to %s", recorder) hc := ss.conn.srv.sessionRecordingClient() resp, err := hc.Do(req) if err != nil { - ss.cancelCtx(err) - ss.logf("recording: error sending recording to %s: %v", recorder, err) + err := fmt.Errorf("recording: error sending recording: %w", err) + ss.logf("%v", err) + ss.cancelCtx(userVisibleError{ + msg: "recording: error sending recording", + error: err, + }) return } defer resp.Body.Close() defer ss.cancelCtx(errors.New("recording: done")) if resp.StatusCode != http.StatusOK { - ss.logf("recording: error sending recording to %s: %v", recorder, resp.Status) + err := fmt.Errorf("recording: server responded with %s", resp.Status) + ss.logf("%v", err) + ss.cancelCtx(userVisibleError{ + msg: "recording server responded with: " + resp.Status, + error: err, + }) } }() diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 51db71c7f..443e0b3cb 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -326,6 +326,108 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { } } +func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) + } + + var handler http.HandlerFunc + recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler(w, r) + })) + defer recordingServer.Close() + + s := &server{ + logf: t.Logf, + httpc: recordingServer.Client(), + lb: &localState{ + sshEnabled: true, + matchingRule: newSSHRule( + &tailcfg.SSHAction{ + Accept: true, + Recorders: []netip.AddrPort{ + netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), + }, + }, + ), + }, + } + defer s.Shutdown() + + const sshUser = "alice" + cfg := &gossh.ClientConfig{ + User: sshUser, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + } + + tests := []struct { + name string + handler func(w http.ResponseWriter, r *http.Request) + sshCommand string + wantClientOutput string + }{ + { + name: "upload-denied", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + }, + sshCommand: "echo hello", + wantClientOutput: "recording: server responded with 403 Forbidden\r\n", + }, + { + name: "upload-fails-after-starting", + handler: func(w http.ResponseWriter, r *http.Request) { + r.Body.Read(make([]byte, 1)) + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusInternalServerError) + }, + sshCommand: "echo hello && sleep 1 && echo world", + wantClientOutput: "hello\n\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n", + }, + } + + src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tstest.Replace(t, &handler, tt.handler) + sc, dc := memnet.NewTCPConn(src, dst, 1024) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + if err != nil { + t.Errorf("client: %v", err) + return + } + client := gossh.NewClient(c, chans, reqs) + defer client.Close() + session, err := client.NewSession() + if err != nil { + t.Errorf("client: %v", err) + return + } + defer session.Close() + t.Logf("client established session") + got, err := session.CombinedOutput(tt.sshCommand) + if err != nil { + t.Logf("client got: %q: %v", got, err) + } else { + t.Errorf("client did not get kicked out: %q", got) + } + if string(got) != tt.wantClientOutput { + t.Errorf("client got %q, want %q", got, tt.wantClientOutput) + } + }() + if err := s.HandleSSHConn(dc); err != nil { + t.Errorf("unexpected error: %v", err) + } + wg.Wait() + }) + } +} + // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session // when the client is not interactive (i.e. no PTY). // It starts a local SSH server and a recording server. The recording server @@ -346,30 +448,28 @@ func TestSSHRecordingNonInteractive(t *testing.T) { t.Error(err) return } - w.WriteHeader(http.StatusOK) })) defer recordingServer.Close() - state := &localState{ - sshEnabled: true, - matchingRule: newSSHRule( - &tailcfg.SSHAction{ - Accept: true, - Recorders: []netip.AddrPort{ - must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())), - }, - }, - ), - } s := &server{ - logf: t.Logf, + logf: logger.Discard, httpc: recordingServer.Client(), + lb: &localState{ + sshEnabled: true, + matchingRule: newSSHRule( + &tailcfg.SSHAction{ + Accept: true, + Recorders: []netip.AddrPort{ + must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())), + }, + }, + ), + }, } defer s.Shutdown() src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) sc, dc := memnet.NewTCPConn(src, dst, 1024) - s.lb = state const sshUser = "alice" cfg := &gossh.ClientConfig{