From 3c2cd854beef3952eb883fe269cab4fe80f020a8 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 19 Feb 2022 15:37:13 -0800 Subject: [PATCH] ssh/tailssh: flesh out env, support non-pty commands Updates #3802 Change-Id: I7022460117542a5424919144828bf571c7c19ec0 Signed-off-by: Brad Fitzpatrick --- ssh/tailssh/tailssh.go | 136 +++++++++++++++++++++++++++++++---------- 1 file changed, 105 insertions(+), 31 deletions(-) diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index a17af5cee..061987305 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -150,7 +150,7 @@ func (srv *server) handleSSH(s ssh.Session) { } action, localUser, ok := evalSSHPolicy(pol, sctx) if ok && action.Message != "" { - io.WriteString(s, action.Message) + io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1)) } if !ok || action.Reject { logf("ssh: access denied for %q from %v", uprof.LoginName, srcIP) @@ -160,62 +160,102 @@ func (srv *server) handleSSH(s ssh.Session) { if !action.Accept || action.HoldAndDelegate != "" { fmt.Fprintf(s, "TODO: other SSHAction outcomes") s.Exit(1) - } - if !isPty { - fmt.Fprintf(s, "TODO scp etc\n") + lu, err := user.Lookup(localUser) + if err != nil { + logf("ssh: user Lookup %q: %v", localUser, err) s.Exit(1) - return } + + logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ()) var cmd *exec.Cmd - if os.Getuid() != 0 { - u, err := user.Current() - if err != nil { - logf("failed to get current user: %v", err) + if euid := os.Geteuid(); euid != 0 { + if lu.Uid != fmt.Sprint(euid) { + logf("ssh: can't switch to user %q from process euid %v", localUser, euid) + fmt.Fprintf(s, "can't switch user\n") s.Exit(1) return } - if u.Username != localUser { - fmt.Fprintf(s, "can't switch user\n") + cmd = exec.Command(loginShell(lu.Uid)) + } else { + if rawCmd := s.RawCommand(); rawCmd != "" { + cmd = exec.Command("/usr/bin/env", "su", "-c", rawCmd, localUser) + cmd.Dir = lu.HomeDir + cmd.Env = append(cmd.Env, envForUser(lu)...) + // TODO: and Env for PATH, SSH_CONNECTION, SSH_CLIENT, XDG_SESSION_TYPE, XDG_*, etc + } else { + cmd = exec.Command("/usr/bin/env", "su", "-", localUser) + } + } + if ptyReq.Term != "" { + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) + } + // TODO(bradfitz,maisem): also blend in user's s.Environ() + logf("Running: %q", cmd.Args) + var toCmd io.WriteCloser + var fromCmd io.ReadCloser + if isPty { + f, err := pty.StartWithSize(cmd, &pty.Winsize{ + Rows: uint16(ptyReq.Window.Width), + Cols: uint16(ptyReq.Window.Height), + }) + if err != nil { + logf("running shell: %v", err) s.Exit(1) return } - cmd = exec.Command(loginShell(u.Uid)) + defer f.Close() + toCmd = f + fromCmd = f + go func() { + for win := range winCh { + setWinsize(f, win.Width, win.Height) + } + }() } else { - cmd = exec.Command("/usr/bin/env", "su", "-", localUser) - } - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) - f, err := pty.Start(cmd) - if err != nil { - logf("running shell: %v", err) - s.Exit(1) - return + stdin, stdout, stderr, err := startWithStdPipes(cmd) + if err != nil { + logf("ssh: start error: %f", err) + s.Exit(1) + return + } + fromCmd, toCmd = stdout, stdin + go func() { io.Copy(s.Stderr(), stderr) }() } if action.SesssionDuration != 0 { t := time.AfterFunc(action.SesssionDuration, func() { logf("terminating SSH session from %v after max duration", srcIP) cmd.Process.Kill() - f.Close() }) defer t.Stop() } - defer f.Close() go func() { - for win := range winCh { - setWinsize(f, win.Width, win.Height) - } + _, err := io.Copy(toCmd, s) // stdin + logf("ssh: stdin copy: %v", err) + toCmd.Close() }() go func() { - io.Copy(f, s) // stdin + _, err := io.Copy(s, fromCmd) // stdout + logf("ssh: stdout copy: %v", err) }() - io.Copy(s, f) // stdout - cmd.Process.Kill() - if err := cmd.Wait(); err != nil { - s.Exit(1) + + err = cmd.Wait() + if err == nil { + logf("ssh: Wait: ok") + s.Exit(0) + return + } + if ee, ok := err.(*exec.ExitError); ok { + code := ee.ProcessState.ExitCode() + logf("ssh: Wait: code=%v", code) + s.Exit(code) + return } - s.Exit(0) + + logf("ssh: Wait: %v", err) + s.Exit(1) return } @@ -327,3 +367,37 @@ func loginShell(uid string) string { } return "/bin/bash" } + +func startWithStdPipes(cmd *exec.Cmd) (stdin io.WriteCloser, stdout, stderr io.ReadCloser, err error) { + defer func() { + if err != nil { + for _, c := range []io.Closer{stdin, stdout, stderr} { + if c != nil { + c.Close() + } + } + } + }() + stdin, err = cmd.StdinPipe() + if err != nil { + return + } + stdout, err = cmd.StdoutPipe() + if err != nil { + return + } + stderr, err = cmd.StderrPipe() + if err != nil { + return + } + err = cmd.Start() + return +} + +func envForUser(u *user.User) []string { + return []string{ + fmt.Sprintf("SHELL=" + loginShell(u.Uid)), + fmt.Sprintf("USER=" + u.Username), + fmt.Sprintf("HOME=" + u.HomeDir), + } +}