diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index fc9260e77..51db71c7f 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -7,6 +7,7 @@ package tailssh import ( "bytes" + "context" "crypto/ed25519" "crypto/rand" "crypto/sha256" @@ -14,6 +15,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "net/http" "net/http/httptest" @@ -324,9 +326,101 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { } } +// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session +// when the client is not interactive (i.e. no PTY). +// It starts a local SSH server and a recording server. The recording server +// records the SSH session and returns it to the test. +// The test then verifies that the recording has a valid CastHeader, it does not +// validate the contents of the recording. +func TestSSHRecordingNonInteractive(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) + } + var recording []byte + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer cancel() + var err error + recording, err = ioutil.ReadAll(r.Body) + if err != nil { + t.Error(err) + return + } + w.WriteHeader(http.StatusOK) + })) + defer recordingServer.Close() + + state := &localState{ + sshEnabled: true, + matchingRule: newSSHRule( + &tailcfg.SSHAction{ + Accept: true, + Recorders: []netip.AddrPort{ + must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())), + }, + }, + ), + } + s := &server{ + logf: t.Logf, + httpc: recordingServer.Client(), + } + defer s.Shutdown() + + src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) + sc, dc := memnet.NewTCPConn(src, dst, 1024) + s.lb = state + + const sshUser = "alice" + cfg := &gossh.ClientConfig{ + User: sshUser, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + if err != nil { + t.Errorf("client: %v", err) + return + } + client := gossh.NewClient(c, chans, reqs) + defer client.Close() + session, err := client.NewSession() + if err != nil { + t.Errorf("client: %v", err) + return + } + defer session.Close() + t.Logf("client established session") + _, err = session.CombinedOutput("echo Ran echo!") + if err != nil { + t.Errorf("client: %v", err) + } + }() + if err := s.HandleSSHConn(dc); err != nil { + t.Errorf("unexpected error: %v", err) + } + wg.Wait() + + <-ctx.Done() // wait for recording to finish + var ch CastHeader + if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil { + t.Fatal(err) + } + if ch.SSHUser != sshUser { + t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser) + } + if ch.Command != "echo Ran echo!" { + t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!") + } +} + func TestSSHAuthFlow(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skip("Not running on Linux, skipping") + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } acceptRule := newSSHRule(&tailcfg.SSHAction{ Accept: true,