ssh/tailssh: add more SSH tests, blend in env from ssh session

Updates #3802

Change-Id: I568c661cacbb0524afcd8be9577457ddba611f19
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4035/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent 4686224e5a
commit 4b50977422

@ -232,11 +232,11 @@ func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *ssh
}
}
cmd.Dir = lu.HomeDir
cmd.Env = append(cmd.Env, s.Environ()...)
cmd.Env = append(cmd.Env, envForUser(lu)...)
if ptyReq.Term != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
}
// TODO(bradfitz,maisem): also blend in user's s.Environ()
logf("Running: %q", cmd.Args)
var toCmd io.WriteCloser
var fromCmd io.ReadCloser

@ -8,12 +8,15 @@
package tailssh
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"strings"
"testing"
"time"
@ -25,6 +28,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/types/logger"
"tailscale.com/util/lineread"
"tailscale.com/wgengine"
)
@ -231,12 +235,78 @@ func TestSSH(t *testing.T) {
}
}()
got, err := exec.Command("ssh",
"-p", fmt.Sprint(port),
"-o", "StrictHostKeyChecking=no",
"user@127.0.0.1", "env").CombinedOutput()
if err != nil {
t.Fatal(err)
execSSH := func(args ...string) *exec.Cmd {
cmd := exec.Command("ssh",
"-p", fmt.Sprint(port),
"-o", "StrictHostKeyChecking=no",
"user@127.0.0.1")
cmd.Args = append(cmd.Args, args...)
return cmd
}
t.Logf("Got: %s", got)
t.Run("env", func(t *testing.T) {
cmd := execSSH("env")
cmd.Env = append(os.Environ(), "LANG=foo")
got, err := cmd.CombinedOutput()
if err != nil {
t.Fatal(err)
}
m := parseEnv(got)
if got := m["USER"]; got == "" || got != u.Username {
t.Errorf("USER = %q; want %q", got, u.Username)
}
if got := m["HOME"]; got == "" || got != u.HomeDir {
t.Errorf("HOME = %q; want %q", got, u.HomeDir)
}
if got := m["PWD"]; got == "" || got != u.HomeDir {
t.Errorf("PWD = %q; want %q", got, u.HomeDir)
}
if got := m["SHELL"]; got == "" {
t.Errorf("no SHELL")
}
if got, want := m["LANG"], "foo"; got != want {
t.Errorf("LANG = %q; want %q", got, want)
}
t.Logf("got: %+v", m)
})
t.Run("stdout_stderr", func(t *testing.T) {
cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
var outBuf, errBuf bytes.Buffer
cmd.Stdout = &outBuf
cmd.Stderr = &errBuf
if err := cmd.Run(); err != nil {
t.Fatal(err)
}
t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
// TODO: figure out why these aren't right. should be
// "foo\n" and "bar\n", not "\n" and "bar\n".
})
t.Run("stdin", func(t *testing.T) {
cmd := execSSH("cat")
var outBuf bytes.Buffer
cmd.Stdout = &outBuf
const str = "foo\nbar\n"
cmd.Stdin = strings.NewReader(str)
if err := cmd.Run(); err != nil {
t.Fatal(err)
}
if got := outBuf.String(); got != str {
t.Errorf("got %q; want %q", got, str)
}
})
}
func parseEnv(out []byte) map[string]string {
e := map[string]string{}
lineread.Reader(bytes.NewReader(out), func(line []byte) error {
i := bytes.IndexByte(line, '=')
if i == -1 {
return nil
}
e[string(line[:i])] = string(line[i+1:])
return nil
})
return e
}

Loading…
Cancel
Save