ssh/tailssh: use context.WithCancelCause

It was using a custom implmentation of the context.WithCancelCause,
replace usage with stdlib.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/7518/head
Maisem Ali 2 years ago committed by Maisem Ali
parent a2be1aabfa
commit e69682678f

@ -1,63 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tailssh
import (
"context"
"sync"
"time"
)
// sshContext is the context.Context implementation we use for SSH
// that adds a CloseWithError method. Otherwise it's just a normalish
// Context.
type sshContext struct {
underlying context.Context
cancel context.CancelFunc // cancels underlying
mu sync.Mutex
closed bool
err error
}
func newSSHContext(ctx context.Context) *sshContext {
ctx, cancel := context.WithCancel(ctx)
return &sshContext{underlying: ctx, cancel: cancel}
}
func (ctx *sshContext) CloseWithError(err error) {
ctx.mu.Lock()
defer ctx.mu.Unlock()
if ctx.closed {
return
}
ctx.closed = true
ctx.err = err
ctx.cancel()
}
func (ctx *sshContext) Err() error {
ctx.mu.Lock()
defer ctx.mu.Unlock()
return ctx.err
}
func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() }
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) }
// userVisibleError is a wrapper around an error that implements
// SSHTerminationError, so msg is written to their session.
type userVisibleError struct {
msg string
error
}
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
// SSHTerminationError is implemented by errors that terminate an SSH
// session and should be written to user's sessions.
type SSHTerminationError interface {
error
SSHTerminationMessage() string
}

@ -787,7 +787,8 @@ type sshSession struct {
sharedID string // ID that's shared with control sharedID string // ID that's shared with control
logf logger.Logf logf logger.Logf
ctx *sshContext // implements context.Context ctx context.Context
cancelCtx context.CancelCauseFunc
conn *conn conn *conn
agentListener net.Listener // non-nil if agent-forwarding requested+allowed agentListener net.Listener // non-nil if agent-forwarding requested+allowed
@ -812,10 +813,12 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
func (c *conn) newSSHSession(s ssh.Session) *sshSession { func (c *conn) newSSHSession(s ssh.Session) *sshSession {
sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5)) sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5))
c.logf("starting session: %v", sharedID) c.logf("starting session: %v", sharedID)
ctx, cancel := context.WithCancelCause(s.Context())
return &sshSession{ return &sshSession{
Session: s, Session: s,
sharedID: sharedID, sharedID: sharedID,
ctx: newSSHContext(s.Context()), ctx: ctx,
cancelCtx: cancel,
conn: c, conn: c,
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
} }
@ -844,7 +847,7 @@ func (c *conn) checkStillValid() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
for _, s := range c.sessions { for _, s := range c.sessions {
s.ctx.CloseWithError(userVisibleError{ s.cancelCtx(userVisibleError{
fmt.Sprintf("Access revoked.\r\n"), fmt.Sprintf("Access revoked.\r\n"),
context.Canceled, context.Canceled,
}) })
@ -897,7 +900,7 @@ func (ss *sshSession) killProcessOnContextDone() {
// Either the process has already exited, in which case this does nothing. // Either the process has already exited, in which case this does nothing.
// Or, the process is still running in which case this will kill it. // Or, the process is still running in which case this will kill it.
ss.exitOnce.Do(func() { ss.exitOnce.Do(func() {
err := ss.ctx.Err() err := context.Cause(ss.ctx)
if serr, ok := err.(SSHTerminationError); ok { if serr, ok := err.(SSHTerminationError); ok {
msg := serr.SSHTerminationMessage() msg := serr.SSHTerminationMessage()
if msg != "" { if msg != "" {
@ -997,7 +1000,7 @@ var recordSSH = envknob.RegisterBool("TS_DEBUG_LOG_SSH")
func (ss *sshSession) run() { func (ss *sshSession) run() {
metricActiveSessions.Add(1) metricActiveSessions.Add(1)
defer metricActiveSessions.Add(-1) defer metricActiveSessions.Add(-1)
defer ss.ctx.CloseWithError(errSessionDone) defer ss.cancelCtx(errSessionDone)
if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached { if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n") fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
@ -1011,7 +1014,7 @@ func (ss *sshSession) run() {
if ss.conn.finalAction.SessionDuration != 0 { if ss.conn.finalAction.SessionDuration != 0 {
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() { t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
ss.ctx.CloseWithError(userVisibleError{ ss.cancelCtx(userVisibleError{
fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration), fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
context.DeadlineExceeded, context.DeadlineExceeded,
}) })
@ -1066,7 +1069,7 @@ func (ss *sshSession) run() {
defer ss.stdin.Close() defer ss.stdin.Close()
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil { if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
logf("stdin copy: %v", err) logf("stdin copy: %v", err)
ss.ctx.CloseWithError(err) ss.cancelCtx(err)
} }
}() }()
var openOutputStreams atomic.Int32 var openOutputStreams atomic.Int32
@ -1080,7 +1083,7 @@ func (ss *sshSession) run() {
_, err := io.Copy(rec.writer("o", ss), ss.stdout) _, err := io.Copy(rec.writer("o", ss), ss.stdout)
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
logf("stdout copy: %v", err) logf("stdout copy: %v", err)
ss.ctx.CloseWithError(err) ss.cancelCtx(err)
} }
if openOutputStreams.Add(-1) == 0 { if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite() ss.CloseWrite()
@ -1489,3 +1492,19 @@ var (
metricSFTP = clientmetric.NewCounter("ssh_sftp_requests") metricSFTP = clientmetric.NewCounter("ssh_sftp_requests")
metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests") metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests")
) )
// userVisibleError is a wrapper around an error that implements
// SSHTerminationError, so msg is written to their session.
type userVisibleError struct {
msg string
error
}
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
// SSHTerminationError is implemented by errors that terminate an SSH
// session and should be written to user's sessions.
type SSHTerminationError interface {
error
SSHTerminationMessage() string
}

Loading…
Cancel
Save