diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 4aaaeafdb..9b751554e 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -232,11 +232,11 @@ func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *ssh } } cmd.Dir = lu.HomeDir + cmd.Env = append(cmd.Env, s.Environ()...) cmd.Env = append(cmd.Env, envForUser(lu)...) if ptyReq.Term != "" { cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) } - // TODO(bradfitz,maisem): also blend in user's s.Environ() logf("Running: %q", cmd.Args) var toCmd io.WriteCloser var fromCmd io.ReadCloser diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index be0b4febf..401febc52 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -8,12 +8,15 @@ package tailssh import ( + "bytes" "context" "errors" "fmt" "net" + "os" "os/exec" "os/user" + "strings" "testing" "time" @@ -25,6 +28,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/util/lineread" "tailscale.com/wgengine" ) @@ -231,12 +235,78 @@ func TestSSH(t *testing.T) { } }() - got, err := exec.Command("ssh", - "-p", fmt.Sprint(port), - "-o", "StrictHostKeyChecking=no", - "user@127.0.0.1", "env").CombinedOutput() - if err != nil { - t.Fatal(err) + execSSH := func(args ...string) *exec.Cmd { + cmd := exec.Command("ssh", + "-p", fmt.Sprint(port), + "-o", "StrictHostKeyChecking=no", + "user@127.0.0.1") + cmd.Args = append(cmd.Args, args...) + return cmd } - t.Logf("Got: %s", got) + + t.Run("env", func(t *testing.T) { + cmd := execSSH("env") + cmd.Env = append(os.Environ(), "LANG=foo") + got, err := cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + m := parseEnv(got) + if got := m["USER"]; got == "" || got != u.Username { + t.Errorf("USER = %q; want %q", got, u.Username) + } + if got := m["HOME"]; got == "" || got != u.HomeDir { + t.Errorf("HOME = %q; want %q", got, u.HomeDir) + } + if got := m["PWD"]; got == "" || got != u.HomeDir { + t.Errorf("PWD = %q; want %q", got, u.HomeDir) + } + if got := m["SHELL"]; got == "" { + t.Errorf("no SHELL") + } + if got, want := m["LANG"], "foo"; got != want { + t.Errorf("LANG = %q; want %q", got, want) + } + t.Logf("got: %+v", m) + }) + + t.Run("stdout_stderr", func(t *testing.T) { + cmd := execSSH("sh", "-c", "echo foo; echo bar >&2") + var outBuf, errBuf bytes.Buffer + cmd.Stdout = &outBuf + cmd.Stderr = &errBuf + if err := cmd.Run(); err != nil { + t.Fatal(err) + } + t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes()) + // TODO: figure out why these aren't right. should be + // "foo\n" and "bar\n", not "\n" and "bar\n". + }) + + t.Run("stdin", func(t *testing.T) { + cmd := execSSH("cat") + var outBuf bytes.Buffer + cmd.Stdout = &outBuf + const str = "foo\nbar\n" + cmd.Stdin = strings.NewReader(str) + if err := cmd.Run(); err != nil { + t.Fatal(err) + } + if got := outBuf.String(); got != str { + t.Errorf("got %q; want %q", got, str) + } + }) +} + +func parseEnv(out []byte) map[string]string { + e := map[string]string{} + lineread.Reader(bytes.NewReader(out), func(line []byte) error { + i := bytes.IndexByte(line, '=') + if i == -1 { + return nil + } + e[string(line[:i])] = string(line[i+1:]) + return nil + }) + return e }