@ -823,12 +823,16 @@ type sshSession struct {
agentListener net . Listener // non-nil if agent-forwarding requested+allowed
agentListener net . Listener // non-nil if agent-forwarding requested+allowed
// initialized by launchProcess:
// initialized by launchProcess:
cmd * exec . Cmd
cmd * exec . Cmd
stdin io . WriteCloser
wrStdin io . WriteCloser
stdout io . ReadCloser
rdStdout io . ReadCloser
stderr io . Reader // nil for pty sessions
rdStderr io . ReadCloser // rdStderr is nil for pty sessions
ptyReq * ssh . Pty // non-nil for pty sessions
ptyReq * ssh . Pty // non-nil for pty sessions
tty * os . File // non-nil for pty sessions, must be closed after process exits
// childPipes is a list of pipes that need to be closed when the process exits.
// For pty sessions, this is the tty fd.
// For non-pty sessions, this is the stdin, stdout, stderr fds.
childPipes [ ] io . Closer
// We use this sync.Once to ensure that we only terminate the process once,
// We use this sync.Once to ensure that we only terminate the process once,
// either it exits itself or is terminated
// either it exits itself or is terminated
@ -1107,21 +1111,22 @@ func (ss *sshSession) run() {
var processDone atomic . Bool
var processDone atomic . Bool
go func ( ) {
go func ( ) {
defer ss . s tdin. Close ( )
defer ss . wrS tdin. Close ( )
if _ , err := io . Copy ( rec . writer ( "i" , ss . s tdin) , ss ) ; err != nil {
if _ , err := io . Copy ( rec . writer ( "i" , ss . wrS tdin) , ss ) ; err != nil {
logf ( "stdin copy: %v" , err )
logf ( "stdin copy: %v" , err )
ss . cancelCtx ( err )
ss . cancelCtx ( err )
}
}
} ( )
} ( )
outputDone := make ( chan struct { } )
var openOutputStreams atomic . Int32
var openOutputStreams atomic . Int32
if ss . s tderr != nil {
if ss . rdS tderr != nil {
openOutputStreams . Store ( 2 )
openOutputStreams . Store ( 2 )
} else {
} else {
openOutputStreams . Store ( 1 )
openOutputStreams . Store ( 1 )
}
}
go func ( ) {
go func ( ) {
defer ss . s tdout. Close ( )
defer ss . rdS tdout. Close ( )
_ , err := io . Copy ( rec . writer ( "o" , ss ) , ss . s tdout)
_ , err := io . Copy ( rec . writer ( "o" , ss ) , ss . rdS tdout)
if err != nil && ! errors . Is ( err , io . EOF ) {
if err != nil && ! errors . Is ( err , io . EOF ) {
isErrBecauseProcessExited := processDone . Load ( ) && errors . Is ( err , syscall . EIO )
isErrBecauseProcessExited := processDone . Load ( ) && errors . Is ( err , syscall . EIO )
if ! isErrBecauseProcessExited {
if ! isErrBecauseProcessExited {
@ -1131,32 +1136,41 @@ func (ss *sshSession) run() {
}
}
if openOutputStreams . Add ( - 1 ) == 0 {
if openOutputStreams . Add ( - 1 ) == 0 {
ss . CloseWrite ( )
ss . CloseWrite ( )
close ( outputDone )
}
}
} ( )
} ( )
// s tderr is nil for ptys.
// rdS tderr is nil for ptys.
if ss . s tderr != nil {
if ss . rdS tderr != nil {
go func ( ) {
go func ( ) {
_ , err := io . Copy ( ss . Stderr ( ) , ss . stderr )
defer ss . rdStderr . Close ( )
_ , err := io . Copy ( ss . Stderr ( ) , ss . rdStderr )
if err != nil {
if err != nil {
logf ( "stderr copy: %v" , err )
logf ( "stderr copy: %v" , err )
}
}
if openOutputStreams . Add ( - 1 ) == 0 {
if openOutputStreams . Add ( - 1 ) == 0 {
ss . CloseWrite ( )
ss . CloseWrite ( )
close ( outputDone )
}
}
} ( )
} ( )
}
}
if ss . tty != nil {
// If running a tty session, close the tty when the session is done.
defer ss . tty . Close ( )
}
err = ss . cmd . Wait ( )
err = ss . cmd . Wait ( )
processDone . Store ( true )
processDone . Store ( true )
// This will either make the SSH Termination goroutine be a no-op,
// This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the
// or itself will be a no-op because the process was killed by the
// aforementioned goroutine.
// aforementioned goroutine.
ss . exitOnce . Do ( func ( ) { } )
ss . exitOnce . Do ( func ( ) { } )
// Close the process-side of all pipes to signal the asynchronous
// io.Copy routines reading/writing from the pipes to terminate.
// Block for the io.Copy to finish before calling ss.Exit below.
closeAll ( ss . childPipes ... )
select {
case <- outputDone :
case <- ss . ctx . Done ( ) :
}
if err == nil {
if err == nil {
ss . logf ( "Session complete" )
ss . logf ( "Session complete" )
ss . Exit ( 0 )
ss . Exit ( 0 )
@ -1894,3 +1908,11 @@ type SSHTerminationError interface {
error
error
SSHTerminationMessage ( ) string
SSHTerminationMessage ( ) string
}
}
func closeAll ( cs ... io . Closer ) {
for _ , c := range cs {
if c != nil {
c . Close ( )
}
}
}