From c1445155ef7faadece3a8fb5343caad04ac09e81 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 27 Apr 2022 13:23:13 -0700 Subject: [PATCH] ssh/tailssh: handle Control-C during hold-and-delegate prompt Fixes #4549 Change-Id: Iafc61af5e08cd03564d39cf667e940b2417714cc Signed-off-by: Brad Fitzpatrick --- ssh/tailssh/ctxreader.go | 112 +++++++++++++++++++++++++++++++++++++++ ssh/tailssh/tailssh.go | 43 +++++++++++++-- 2 files changed, 152 insertions(+), 3 deletions(-) create mode 100644 ssh/tailssh/ctxreader.go diff --git a/ssh/tailssh/ctxreader.go b/ssh/tailssh/ctxreader.go new file mode 100644 index 000000000..ce0a03526 --- /dev/null +++ b/ssh/tailssh/ctxreader.go @@ -0,0 +1,112 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tailssh + +import ( + "context" + "io" + "sync" + + "tailscale.com/tempfork/gliderlabs/ssh" +) + +// readResult is a result from a io.Reader.Read call, +// as used by contextReader. +type readResult struct { + buf []byte // ownership passed on chan send + err error +} + +// contextReader wraps an io.Reader, providing a ReadContext method +// that can be aborted before yielding bytes. If it's aborted, subsequent +// reads can get those byte(s) later. +type contextReader struct { + r io.Reader + + // buffered is leftover data from a previous read call that wasn't entirely + // consumed. + buffered []byte + // readErr is a previous read error that was seen while filling buffered. It + // should be returned to the caller after bufffered is consumed. + readErr error + + mu sync.Mutex // guards ch only + + // ch is non-nil if a goroutine had been started and has a result to be + // read. The goroutine may be either still running or done and has + // send to the channel. + ch chan readResult +} + +// HasOutstandingRead reports whether there's an oustanding Read call that's +// either currently blocked in a Read or whose result hasn't been consumed. +func (w *contextReader) HasOutstandingRead() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.ch != nil +} + +func (w *contextReader) setChan(c chan readResult) { + w.mu.Lock() + defer w.mu.Unlock() + w.ch = c +} + +// ReadContext is like Read, but takes a context permitting the read to be canceled. +// +// If the context becomes done, the underlying Read call continues and its result +// will be given to the next caller to ReadContext. +func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + + n = copy(p, w.buffered) + if n > 0 { + w.buffered = w.buffered[n:] + if len(w.buffered) == 0 { + err = w.readErr + } + return n, err + } + + if w.ch == nil { + ch := make(chan readResult, 1) + w.setChan(ch) + go func() { + rbuf := make([]byte, len(p)) + n, err := w.r.Read(rbuf) + ch <- readResult{rbuf[:n], err} + }() + } + + select { + case <-ctx.Done(): + return 0, ctx.Err() + case rr := <-w.ch: + w.setChan(nil) + n = copy(p, rr.buf) + w.buffered = rr.buf[n:] + w.readErr = rr.err + if len(w.buffered) == 0 { + err = rr.err + } + return n, err + } +} + +// contextReaderSesssion implements ssh.Session, wrapping another +// ssh.Session but changing its Read method to use contextReader. +type contextReaderSesssion struct { + ssh.Session + cr *contextReader +} + +func (a contextReaderSesssion) Read(p []byte) (n int, err error) { + if a.cr.HasOutstandingRead() { + return a.cr.ReadContext(context.Background(), p) + } + return a.Session.Read(p) +} diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index e284c89ce..bed2c32fc 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -37,6 +37,7 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/logtail/backoff" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/logger" @@ -488,7 +489,8 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { // completed. It also handles SFTP requests. func (c *conn) handleConnPostSSHAuth(s ssh.Session) { sshUser := s.User() - action, err := c.resolveTerminalAction(s) + cr := &contextReader{r: s} + action, err := c.resolveTerminalAction(s, cr) if err != nil { c.logf("resolveTerminalAction: %v", err) io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n") @@ -501,6 +503,10 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) { return } + if cr.HasOutstandingRead() { + s = contextReaderSesssion{s, cr} + } + // Do this check after auth, but before starting the session. switch s.Subsystem() { case "sftp", "": @@ -522,8 +528,17 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) { // Any action with a Message in the chain will be printed to s. // // The returned SSHAction will be either Reject or Accept. -func (c *conn) resolveTerminalAction(s ssh.Session) (*tailcfg.SSHAction, error) { +func (c *conn) resolveTerminalAction(s ssh.Session, cr *contextReader) (*tailcfg.SSHAction, error) { action := c.action0 + + var awaitReadOnce sync.Once // to start Reads on cr + var sawInterrupt syncs.AtomicBool + var wg sync.WaitGroup + defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit + + ctx, cancel := context.WithCancel(s.Context()) + defer cancel() + // Loop processing/fetching Actions until one reaches a // terminal state (Accept, Reject, or invalid Action), or // until fetchSSHAction times out due to the context being @@ -541,10 +556,32 @@ func (c *conn) resolveTerminalAction(s ssh.Session) (*tailcfg.SSHAction, error) if url == "" { return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate") } + awaitReadOnce.Do(func() { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 1) + for { + n, err := cr.ReadContext(ctx, buf) + if err != nil { + return + } + if n > 0 && buf[0] == 0x03 { // Ctrl-C + sawInterrupt.Set(true) + s.Stderr().Write([]byte("Canceled.\r\n")) + s.Exit(1) + return + } + } + }() + }) url = c.expandDelegateURL(url) var err error - action, err = c.fetchSSHAction(s.Context(), url) + action, err = c.fetchSSHAction(ctx, url) if err != nil { + if sawInterrupt.Get() { + return nil, fmt.Errorf("aborted by user") + } return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err) } }