ssh/tailssh: close sshContext on context cancellation

This was preventing tailscaled from shutting down properly if there were
active sessions in certain states (e.g. waiting in check mode).

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/5886/head
Maisem Ali 2 years ago committed by Maisem Ali
parent 8fe04b035c
commit f172fc42f7

@ -5,6 +5,7 @@
package tailssh package tailssh
import ( import (
"context"
"sync" "sync"
"time" "time"
) )
@ -13,14 +14,16 @@ import (
// that adds a CloseWithError method. Otherwise it's just a normalish // that adds a CloseWithError method. Otherwise it's just a normalish
// Context. // Context.
type sshContext struct { type sshContext struct {
mu sync.Mutex underlying context.Context
closed bool cancel context.CancelFunc // cancels underlying
done chan struct{} mu sync.Mutex
err error closed bool
err error
} }
func newSSHContext() *sshContext { func newSSHContext(ctx context.Context) *sshContext {
return &sshContext{done: make(chan struct{})} ctx, cancel := context.WithCancel(ctx)
return &sshContext{underlying: ctx, cancel: cancel}
} }
func (ctx *sshContext) CloseWithError(err error) { func (ctx *sshContext) CloseWithError(err error) {
@ -31,7 +34,7 @@ func (ctx *sshContext) CloseWithError(err error) {
} }
ctx.closed = true ctx.closed = true
ctx.err = err ctx.err = err
close(ctx.done) ctx.cancel()
} }
func (ctx *sshContext) Err() error { func (ctx *sshContext) Err() error {
@ -40,9 +43,9 @@ func (ctx *sshContext) Err() error {
return ctx.err return ctx.err
} }
func (ctx *sshContext) Done() <-chan struct{} { return ctx.done } func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() }
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return } func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
func (ctx *sshContext) Value(any) any { return nil } func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) }
// userVisibleError is a wrapper around an error that implements // userVisibleError is a wrapper around an error that implements
// SSHTerminationError, so msg is written to their session. // SSHTerminationError, so msg is written to their session.

@ -770,7 +770,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession {
return &sshSession{ return &sshSession{
Session: s, Session: s,
sharedID: sharedID, sharedID: sharedID,
ctx: newSSHContext(), ctx: newSSHContext(s.Context()),
conn: c, conn: c,
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
} }

Loading…
Cancel
Save