ssh/tailssh: flesh out env, support non-pty commands

Updates #3802

Change-Id: I7022460117542a5424919144828bf571c7c19ec0
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4021/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 7d897229d9
commit 3c2cd854be

@ -150,7 +150,7 @@ func (srv *server) handleSSH(s ssh.Session) {
} }
action, localUser, ok := evalSSHPolicy(pol, sctx) action, localUser, ok := evalSSHPolicy(pol, sctx)
if ok && action.Message != "" { if ok && action.Message != "" {
io.WriteString(s, action.Message) io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1))
} }
if !ok || action.Reject { if !ok || action.Reject {
logf("ssh: access denied for %q from %v", uprof.LoginName, srcIP) logf("ssh: access denied for %q from %v", uprof.LoginName, srcIP)
@ -160,62 +160,102 @@ func (srv *server) handleSSH(s ssh.Session) {
if !action.Accept || action.HoldAndDelegate != "" { if !action.Accept || action.HoldAndDelegate != "" {
fmt.Fprintf(s, "TODO: other SSHAction outcomes") fmt.Fprintf(s, "TODO: other SSHAction outcomes")
s.Exit(1) s.Exit(1)
} }
if !isPty { lu, err := user.Lookup(localUser)
fmt.Fprintf(s, "TODO scp etc\n") if err != nil {
logf("ssh: user Lookup %q: %v", localUser, err)
s.Exit(1) s.Exit(1)
return
} }
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
var cmd *exec.Cmd var cmd *exec.Cmd
if os.Getuid() != 0 { if euid := os.Geteuid(); euid != 0 {
u, err := user.Current() if lu.Uid != fmt.Sprint(euid) {
if err != nil { logf("ssh: can't switch to user %q from process euid %v", localUser, euid)
logf("failed to get current user: %v", err) fmt.Fprintf(s, "can't switch user\n")
s.Exit(1) s.Exit(1)
return return
} }
if u.Username != localUser { cmd = exec.Command(loginShell(lu.Uid))
fmt.Fprintf(s, "can't switch user\n") } else {
if rawCmd := s.RawCommand(); rawCmd != "" {
cmd = exec.Command("/usr/bin/env", "su", "-c", rawCmd, localUser)
cmd.Dir = lu.HomeDir
cmd.Env = append(cmd.Env, envForUser(lu)...)
// TODO: and Env for PATH, SSH_CONNECTION, SSH_CLIENT, XDG_SESSION_TYPE, XDG_*, etc
} else {
cmd = exec.Command("/usr/bin/env", "su", "-", localUser)
}
}
if ptyReq.Term != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
}
// TODO(bradfitz,maisem): also blend in user's s.Environ()
logf("Running: %q", cmd.Args)
var toCmd io.WriteCloser
var fromCmd io.ReadCloser
if isPty {
f, err := pty.StartWithSize(cmd, &pty.Winsize{
Rows: uint16(ptyReq.Window.Width),
Cols: uint16(ptyReq.Window.Height),
})
if err != nil {
logf("running shell: %v", err)
s.Exit(1) s.Exit(1)
return return
} }
cmd = exec.Command(loginShell(u.Uid)) defer f.Close()
toCmd = f
fromCmd = f
go func() {
for win := range winCh {
setWinsize(f, win.Width, win.Height)
}
}()
} else { } else {
cmd = exec.Command("/usr/bin/env", "su", "-", localUser) stdin, stdout, stderr, err := startWithStdPipes(cmd)
} if err != nil {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) logf("ssh: start error: %f", err)
f, err := pty.Start(cmd) s.Exit(1)
if err != nil { return
logf("running shell: %v", err) }
s.Exit(1) fromCmd, toCmd = stdout, stdin
return go func() { io.Copy(s.Stderr(), stderr) }()
} }
if action.SesssionDuration != 0 { if action.SesssionDuration != 0 {
t := time.AfterFunc(action.SesssionDuration, func() { t := time.AfterFunc(action.SesssionDuration, func() {
logf("terminating SSH session from %v after max duration", srcIP) logf("terminating SSH session from %v after max duration", srcIP)
cmd.Process.Kill() cmd.Process.Kill()
f.Close()
}) })
defer t.Stop() defer t.Stop()
} }
defer f.Close()
go func() { go func() {
for win := range winCh { _, err := io.Copy(toCmd, s) // stdin
setWinsize(f, win.Width, win.Height) logf("ssh: stdin copy: %v", err)
} toCmd.Close()
}() }()
go func() { go func() {
io.Copy(f, s) // stdin _, err := io.Copy(s, fromCmd) // stdout
logf("ssh: stdout copy: %v", err)
}() }()
io.Copy(s, f) // stdout
cmd.Process.Kill() err = cmd.Wait()
if err := cmd.Wait(); err != nil { if err == nil {
s.Exit(1) logf("ssh: Wait: ok")
s.Exit(0)
return
}
if ee, ok := err.(*exec.ExitError); ok {
code := ee.ProcessState.ExitCode()
logf("ssh: Wait: code=%v", code)
s.Exit(code)
return
} }
s.Exit(0)
logf("ssh: Wait: %v", err)
s.Exit(1)
return return
} }
@ -327,3 +367,37 @@ func loginShell(uid string) string {
} }
return "/bin/bash" return "/bin/bash"
} }
func startWithStdPipes(cmd *exec.Cmd) (stdin io.WriteCloser, stdout, stderr io.ReadCloser, err error) {
defer func() {
if err != nil {
for _, c := range []io.Closer{stdin, stdout, stderr} {
if c != nil {
c.Close()
}
}
}
}()
stdin, err = cmd.StdinPipe()
if err != nil {
return
}
stdout, err = cmd.StdoutPipe()
if err != nil {
return
}
stderr, err = cmd.StderrPipe()
if err != nil {
return
}
err = cmd.Start()
return
}
func envForUser(u *user.User) []string {
return []string{
fmt.Sprintf("SHELL=" + loginShell(u.Uid)),
fmt.Sprintf("USER=" + u.Username),
fmt.Sprintf("HOME=" + u.HomeDir),
}
}

Loading…
Cancel
Save