diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 5430b02fa..3865a50ba 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -17,6 +17,7 @@ import ( "io" "net" "net/http" + "net/http/httptrace" "net/netip" "net/url" "os" @@ -42,6 +43,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" + "tailscale.com/util/multierr" "tailscale.com/version/distro" ) @@ -79,33 +81,11 @@ type server struct { // mu protects the following mu sync.Mutex - httpc *http.Client // for calling out to peers. activeConns map[*conn]bool // set; value is always true fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL shutdownCalled bool } -// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to -// dial connections. This is used to make requests to the session recording -// server to upload session recordings. -func (srv *server) sessionRecordingClient() *http.Client { - srv.mu.Lock() - defer srv.mu.Unlock() - if srv.httpc != nil { - return srv.httpc - } - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - return srv.lb.Dialer().UserDial(ctx, network, addr) - } - srv.httpc = &http.Client{ - Transport: tr, - } - return srv.httpc -} - func (srv *server) now() time.Time { if srv != nil && srv.timeNow != nil { return srv.timeNow() @@ -1078,7 +1058,7 @@ func (ss *sshSession) run() { if err != nil { var uve userVisibleError if errors.As(err, &uve) { - fmt.Fprintf(ss, "%s\r\n", uve) + fmt.Fprintf(ss, "%s\r\n", uve.SSHTerminationMessage()) } else { fmt.Fprintf(ss, "can't start new recording\r\n") } @@ -1086,7 +1066,9 @@ func (ss *sshSession) run() { ss.Exit(1) return } - defer rec.Close() + if rec != nil { + defer rec.Close() + } } } @@ -1169,15 +1151,16 @@ func (ss *sshSession) run() { // If the final action has a non-empty 
list of recorders, that list is // returned. Otherwise, the list of recorders from the initial action // is returned. -func (ss *sshSession) recorders() []netip.AddrPort { +func (ss *sshSession) recorders() ([]netip.AddrPort, *tailcfg.SSHRecorderFailureAction) { if len(ss.conn.finalAction.Recorders) > 0 { - return ss.conn.finalAction.Recorders + return ss.conn.finalAction.Recorders, ss.conn.finalAction.OnRecordingFailure } - return ss.conn.action0.Recorders + return ss.conn.action0.Recorders, ss.conn.action0.OnRecordingFailure } func (ss *sshSession) shouldRecord() bool { - return len(ss.recorders()) > 0 + recs, _ := ss.recorders() + return len(recs) > 0 } type sshConnInfo struct { @@ -1409,16 +1392,123 @@ type CastHeader struct { LocalUser string `json:"localUser"` } +// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to +// dial connections. This is used to make requests to the session recording +// server to upload session recordings. +// It uses the provided dialCtx to dial connections, and limits a single dial +// to 5 seconds. +func (ss *sshSession) sessionRecordingClient(dialCtx context.Context) (*http.Client, error) { + dialer := ss.conn.srv.lb.Dialer() + if dialer == nil { + return nil, errors.New("no peer API transport") + } + tr := dialer.PeerAPITransport().Clone() + dialContextFn := tr.DialContext + + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + go func() { + select { + case <-perAttemptCtx.Done(): + case <-dialCtx.Done(): + cancel() + } + }() + return dialContextFn(perAttemptCtx, network, addr) + } + return &http.Client{ + Transport: tr, + }, nil +} + +// connectToRecorder connects to the recorder at any of the provided addresses. +// It returns the first successful response, or a multierr if all attempts fail. 
+// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func (ss *sshSession) connectToRecorder(ctx context.Context, recs []netip.AddrPort) (io.WriteCloser, <-chan error, error) { + if len(recs) == 0 { + return nil, nil, errors.New("no recorders configured") + } + + // We use a special context for dialing the recorder, so that we can + // limit the time we spend dialing to 30 seconds and still have an + // unbounded context for the upload. + dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + defer dialCancel() + hc, err := ss.sessionRecordingClient(dialCtx) + if err != nil { + return nil, nil, err + } + var errs []error + for _, ap := range recs { + // We dial the recorder and wait for it to send a 100-continue + // response before returning from this function. This ensures that + // the recorder is ready to accept the recording. + + // got100 is closed when we receive the 100-continue response. + got100 := make(chan struct{}) + // Attach the trace to a per-attempt context: reassigning ctx would stack the hooks of earlier attempts onto every later request. + reqCtx := httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + Got100Continue: func() { + close(got100) + }, + }) + + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(reqCtx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr) + if err != nil { + errs = append(errs, fmt.Errorf("recording: error starting recording: %w", err)) + continue + } + // We set the Expect header to 100-continue, so that the recorder + // will send a 100-continue response before it starts reading the + // request body. + req.Header.Set("Expect", "100-continue") + + // errChan is used to indicate the result of the request. 
+ errChan := make(chan error, 1) + go func() { + resp, err := hc.Do(req) + if err != nil { + errChan <- fmt.Errorf("recording: error starting recording: %w", err) + return + } + // Close the response body once the request has finished so the transport can release the connection. + defer resp.Body.Close() + if resp.StatusCode != 200 { + errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + return + } + errChan <- nil + }() + select { + case <-got100: + case err := <-errChan: + // If we get an error before we get the 100-continue response, + // we need to try another recorder. + if err == nil { + // If the error is nil, we got a 200 response, which + // is unexpected as we haven't sent any data yet. + err = errors.New("recording: unexpected EOF") + } + errs = append(errs, err) + continue + } + return pw, errChan, nil + } + return nil, nil, multierr.New(errs...) +} + // startNewRecording starts a new SSH session recording. +// It may return a nil recording if recording is not available. func (ss *sshSession) startNewRecording() (_ *recording, err error) { - recorders := ss.recorders() + recorders, onFailure := ss.recorders() if len(recorders) == 0 { return nil, errors.New("no recorders configured") } - recorder := recorders[0] - if len(recorders) > 1 { - ss.logf("warning: multiple recorders configured, using first one: %v", recorder) - } var w ssh.Window if ptyReq, _, isPtyReq := ss.Pty(); isPtyReq { @@ -1436,51 +1526,43 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { start: now, } - pr, pw := io.Pipe() - // We want to use a background context for uploading and not ss.ctx. // ss.ctx is closed when the session closes, but we don't want to break the upload at that time. // Instead we want to wait for the session to close the writer when it finishes. 
ctx := context.Background() - req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", recorder.Addr(), recorder.Port()), pr) + wc, errChan, err := ss.connectToRecorder(ctx, recorders) if err != nil { - pr.Close() - pw.Close() - return nil, err + // TODO(catzkorn): notify control here. + if onFailure != nil && onFailure.RejectSessionWithMessage != "" { + ss.logf("recording: error starting recording (rejecting session): %v", err) + return nil, userVisibleError{ + error: err, + msg: onFailure.RejectSessionWithMessage, + } + } + ss.logf("recording: error starting recording (failing open): %v", err) + return nil, nil } - // We want to wait for the server to respond with 100 Continue to notifiy us - // that it's ready to receive data. We do this to block the session from - // starting until the server is ready to receive data. - // It also allows the server to reject the request before we start sending - // data. - req.Header.Set("Expect", "100-continue") + go func() { - defer pw.Close() - ss.logf("starting asciinema recording to %s", recorder) - hc := ss.conn.srv.sessionRecordingClient() - resp, err := hc.Do(req) - if err != nil { - err := fmt.Errorf("recording: error sending recording: %w", err) - ss.logf("%v", err) - ss.cancelCtx(userVisibleError{ - msg: "recording: error sending recording", - error: err, - }) + err := <-errChan + if err == nil { + // Success. return } - defer resp.Body.Close() - defer ss.cancelCtx(errors.New("recording: done")) - if resp.StatusCode != http.StatusOK { - err := fmt.Errorf("recording: server responded with %s", resp.Status) - ss.logf("%v", err) + // TODO(catzkorn): notify control here. 
+ if onFailure != nil && onFailure.TerminateSessionWithMessage != "" { + ss.logf("recording: error uploading recording (closing session): %v", err) ss.cancelCtx(userVisibleError{ - msg: "recording server responded with: " + resp.Status, error: err, + msg: onFailure.TerminateSessionWithMessage, }) + return } + ss.logf("recording: error uploading recording (failing open): %v", err) }() - rec.out = pw + rec.out = wc ch := CastHeader{ Version: 2, @@ -1515,7 +1597,7 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { return nil, err } j = append(j, '\n') - if _, err := pw.Write(j); err != nil { + if _, err := rec.out.Write(j); err != nil { if errors.Is(err, io.ErrClosedPipe) && ss.ctx.Err() != nil { // If we got an io.ErrClosedPipe, it's likely because // the recording server closed the connection on us. Return diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index d262093ec..c0935d24b 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -240,7 +240,7 @@ var ( ) func (ts *localState) Dialer() *tsdial.Dialer { - return nil + return &tsdial.Dialer{} } func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) { @@ -338,8 +338,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { defer recordingServer.Close() s := &server{ - logf: t.Logf, - httpc: recordingServer.Client(), + logf: t.Logf, lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -348,6 +347,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { Recorders: []netip.AddrPort{ netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), }, + OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{ + RejectSessionWithMessage: "session rejected", + TerminateSessionWithMessage: "session terminated", + }, }, ), }, @@ -374,7 +377,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { w.WriteHeader(http.StatusForbidden) }, sshCommand: "echo hello", - wantClientOutput: "recording: 
server responded with 403 Forbidden\r\n", + wantClientOutput: "session rejected\r\n", clientOutputMustNotContain: []string{"hello"}, }, @@ -386,7 +389,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { w.WriteHeader(http.StatusInternalServerError) }, sshCommand: "echo hello && sleep 1 && echo world", - wantClientOutput: "\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n", + wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n", clientOutputMustNotContain: []string{"world"}, }, @@ -440,6 +443,103 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { } } +func TestMultipleRecorders(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) + } + done := make(chan struct{}) + recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(done) + io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + })) + defer recordingServer.Close() + badRecorder, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + badRecorderAddr := badRecorder.Addr().String() + badRecorder.Close() + + badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer badRecordingServer500.Close() + + badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer badRecordingServer200.Close() + + s := &server{ + logf: t.Logf, + lb: &localState{ + sshEnabled: true, + matchingRule: newSSHRule( + &tailcfg.SSHAction{ + Accept: true, + Recorders: []netip.AddrPort{ + netip.MustParseAddrPort(badRecorderAddr), + netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()), + netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()), + netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), + }, + 
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{ + RejectSessionWithMessage: "session rejected", + TerminateSessionWithMessage: "session terminated", + }, + }, + ), + }, + } + defer s.Shutdown() + + src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) + sc, dc := memnet.NewTCPConn(src, dst, 1024) + + const sshUser = "alice" + cfg := &gossh.ClientConfig{ + User: sshUser, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + if err != nil { + t.Errorf("client: %v", err) + return + } + client := gossh.NewClient(c, chans, reqs) + defer client.Close() + session, err := client.NewSession() + if err != nil { + t.Errorf("client: %v", err) + return + } + defer session.Close() + t.Logf("client established session") + out, err := session.CombinedOutput("echo Ran echo!") + if err != nil { + t.Errorf("client: %v", err) + } + if string(out) != "Ran echo!\n" { + t.Errorf("client: unexpected output: %q", out) + } + }() + if err := s.HandleSSHConn(dc); err != nil { + t.Errorf("unexpected error: %v", err) + } + wg.Wait() + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for recording") + } +} + // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session // when the client is not interactive (i.e. no PTY). // It starts a local SSH server and a recording server. 
The recording server @@ -464,8 +564,7 @@ func TestSSHRecordingNonInteractive(t *testing.T) { defer recordingServer.Close() s := &server{ - logf: logger.Discard, - httpc: recordingServer.Client(), + logf: logger.Discard, lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -474,6 +573,10 @@ func TestSSHRecordingNonInteractive(t *testing.T) { Recorders: []netip.AddrPort{ must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())), }, + OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{ + RejectSessionWithMessage: "session rejected", + TerminateSessionWithMessage: "session terminated", + }, }, ), }, diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 95057d490..b77193f2b 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -97,7 +97,8 @@ type CapabilityVersion int // - 58: 2023-03-10: Client retries lite map updates before restarting map poll. // - 59: 2023-03-16: Client understands Peers[].SelfNodeV4MasqAddrForThisPeer // - 60: 2023-04-06: Client understands IsWireGuardOnly -const CurrentCapabilityVersion CapabilityVersion = 60 +// - 61: 2023-04-18: Client understands SSHAction.OnRecordingFailure +const CurrentCapabilityVersion CapabilityVersion = 61 type StableID string