// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause //go:build integrationtest // +build integrationtest package tailssh import ( "bufio" "bytes" "context" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "fmt" "io" "log" "net" "net/http" "net/netip" "os" "os/exec" "path/filepath" "runtime" "strings" "testing" "time" "github.com/bramvdbogaerde/go-scp" "github.com/google/go-cmp/cmp" "github.com/pkg/sftp" gossh "github.com/tailscale/golang-x-crypto/ssh" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" glider "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/util/set" ) // This file contains integration tests of the SSH functionality. These tests // exercise everything except for the authentication logic. // // The tests make the following assumptions about the environment: // // - OS is one of MacOS or Linux // - Test is being run as root (e.g. go test -tags integrationtest -c . && sudo ./tailssh.test -test.run TestIntegration) // - TAILSCALED_PATH environment variable points at tailscaled binary // - User "testuser" exists // - "testuser" is in groups "groupone" and "grouptwo" func TestMain(m *testing.M) { // Create our log file. file, err := os.OpenFile("/tmp/tailscalessh.log", os.O_CREATE|os.O_WRONLY, 0666) if err != nil { log.Fatal(err) } file.Close() // Tail our log file. cmd := exec.Command("tail", "-F", "/tmp/tailscalessh.log") r, err := cmd.StdoutPipe() if err != nil { return } scanner := bufio.NewScanner(r) go func() { for scanner.Scan() { line := scanner.Text() log.Println(line) } }() err = cmd.Start() if err != nil { return } defer func() { // tail -f has a default sleep interval of 1 second, so it takes a // moment for it to finish reading our log file after we've terminated. // So, wait a bit to let it catch up. time.Sleep(2 * time.Second) }() m.Run() } func TestIntegrationSSH(t *testing.T) { debugTest.Store(true) t.Cleanup(func() { debugTest.Store(false) }) homeDir := "/home/testuser" if runtime.GOOS == "darwin" { homeDir = "/Users/testuser" } tests := []struct { cmd string want []string forceV1Behavior bool skip bool }{ { cmd: "id", want: []string{"testuser", "groupone", "grouptwo"}, forceV1Behavior: false, }, { cmd: "id", want: []string{"testuser", "groupone", "grouptwo"}, forceV1Behavior: true, }, { cmd: "pwd", want: []string{homeDir}, skip: os.Getenv("SKIP_FILE_OPS") == "1" || !fallbackToSUAvailable(), forceV1Behavior: false, }, { cmd: "echo 'hello'", want: []string{"hello"}, skip: os.Getenv("SKIP_FILE_OPS") == "1" || !fallbackToSUAvailable(), forceV1Behavior: false, }, } for _, test := range tests { if test.skip { continue } // run every test both without and with a shell for _, shell := range []bool{false, true} { shellQualifier := "no_shell" if shell { shellQualifier = "shell" } versionQualifier := "v2" if test.forceV1Behavior { versionQualifier = "v1" } t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) { s := testSession(t, test.forceV1Behavior) if shell { err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{ ssh.ECHO: 1, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, }) if err != nil { t.Fatalf("unable to request PTY: %s", err) } err = s.Shell() if err != nil { t.Fatalf("unable to request shell: %s", err) } // Read the shell prompt s.read() } got := s.run(t, test.cmd, shell) for _, want := range test.want { if !strings.Contains(got, want) { t.Errorf("%q does not contain %q", got, want) } } }) } } } func TestIntegrationSFTP(t *testing.T) { debugTest.Store(true) t.Cleanup(func() { debugTest.Store(false) }) for _, forceV1Behavior := range []bool{false, true} { name := "v2" if forceV1Behavior { name = "v1" } t.Run(name, func(t *testing.T) { filePath := "/home/testuser/sftptest.dat" if forceV1Behavior || !fallbackToSUAvailable() { filePath = "/tmp/sftptest.dat" } wantText := "hello world" cl := testClient(t, forceV1Behavior) scl, err := sftp.NewClient(cl) if err != nil { t.Fatalf("can't get sftp client: %s", err) } file, err := scl.Create(filePath) if err != nil { t.Fatalf("can't create file: %s", err) } _, err = file.Write([]byte(wantText)) if err != nil { t.Fatalf("can't write to file: %s", err) } err = file.Close() if err != nil { t.Fatalf("can't close file: %s", err) } file, err = scl.OpenFile(filePath, os.O_RDONLY) if err != nil { t.Fatalf("can't open file: %s", err) } defer file.Close() gotText, err := io.ReadAll(file) if err != nil { t.Fatalf("can't read file: %s", err) } if diff := cmp.Diff(string(gotText), wantText); diff != "" { t.Fatalf("unexpected file contents (-got +want):\n%s", diff) } s := testSessionFor(t, cl) got := s.run(t, "ls -l "+filePath, false) if !strings.Contains(got, "testuser") { t.Fatalf("unexpected file owner user: %s", got) } else if !strings.Contains(got, "testuser") { t.Fatalf("unexpected file owner group: %s", got) } }) } } func TestIntegrationSCP(t *testing.T) { debugTest.Store(true) t.Cleanup(func() { debugTest.Store(false) }) for _, forceV1Behavior := range []bool{false, true} { name := "v2" if forceV1Behavior { name = "v1" } t.Run(name, func(t *testing.T) { filePath := "/home/testuser/scptest.dat" if !fallbackToSUAvailable() { filePath = "/tmp/scptest.dat" } wantText := "hello world" cl := testClient(t, forceV1Behavior) scl, err := scp.NewClientBySSH(cl) if err != nil { t.Fatalf("can't get sftp client: %s", err) } err = scl.Copy(context.Background(), strings.NewReader(wantText), filePath, "0644", int64(len(wantText))) if err != nil { t.Fatalf("can't create file: %s", err) } outfile, err := os.CreateTemp("", "") if err != nil { t.Fatalf("can't create temp file: %s", err) } err = scl.CopyFromRemote(context.Background(), outfile, filePath) if err != nil { t.Fatalf("can't copy file from remote: %s", err) } outfile.Close() gotText, err := os.ReadFile(outfile.Name()) if err != nil { t.Fatalf("can't read file: %s", err) } if diff := cmp.Diff(string(gotText), wantText); diff != "" { t.Fatalf("unexpected file contents (-got +want):\n%s", diff) } s := testSessionFor(t, cl) got := s.run(t, "ls -l "+filePath, false) if !strings.Contains(got, "testuser") { t.Fatalf("unexpected file owner user: %s", got) } else if !strings.Contains(got, "testuser") { t.Fatalf("unexpected file owner group: %s", got) } }) } } func TestSSHAgentForwarding(t *testing.T) { debugTest.Store(true) t.Cleanup(func() { debugTest.Store(false) }) // Create a client SSH key tmpDir, err := os.MkdirTemp("", "") if err != nil { t.Fatal(err) } t.Cleanup(func() { _ = os.RemoveAll(tmpDir) }) pkFile := filepath.Join(tmpDir, "pk") clientKey, clientKeyRSA := generateClientKey(t, pkFile) // Start upstream SSH server l, err := net.Listen("tcp", "127.0.0.1:") if err != nil { t.Fatalf("unable to listen for SSH: %s", err) } t.Cleanup(func() { _ = l.Close() }) // Run an SSH server that accepts connections from that client SSH key. gs := glider.Server{ Handler: func(s glider.Session) { io.WriteString(s, "Hello world\n") }, PublicKeyHandler: func(ctx glider.Context, key glider.PublicKey) error { // Note - this is not meant to be cryptographically secure, it's // just checking that SSH agent forwarding is forwarding the right // key. a := key.Marshal() b := clientKey.PublicKey().Marshal() if !bytes.Equal(a, b) { return errors.New("key mismatch") } return nil }, } go gs.Serve(l) // Run tailscale SSH server and connect to it username := "testuser" tailscaleAddr := testServer(t, username, false) tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{ HostKeyCallback: ssh.InsecureIgnoreHostKey(), }) if err != nil { t.Fatal(err) } t.Cleanup(func() { tcl.Close() }) s, err := tcl.NewSession() if err != nil { t.Fatal(err) } // Set up SSH agent forwarding on the client err = agent.RequestAgentForwarding(s) if err != nil { t.Fatal(err) } keyring := agent.NewKeyring() keyring.Add(agent.AddedKey{ PrivateKey: clientKeyRSA, }) err = agent.ForwardToAgent(tcl, keyring) if err != nil { t.Fatal(err) } // Attempt to SSH to the upstream test server using the forwarded SSH key // and run the "true" command. upstreamHost, upstreamPort, err := net.SplitHostPort(l.Addr().String()) if err != nil { t.Fatal(err) } o, err := s.CombinedOutput(fmt.Sprintf(`ssh -T -o StrictHostKeyChecking=no -p %s upstreamuser@%s "true"`, upstreamPort, upstreamHost)) if err != nil { t.Fatalf("unable to call true command: %s\n%s\n-------------------------", err, o) } } func fallbackToSUAvailable() bool { if runtime.GOOS != "linux" { return false } _, err := exec.LookPath("su") if err != nil { return false } // Some operating systems like Fedora seem to require login to be present // in order for su to work. _, err = exec.LookPath("login") return err == nil } type session struct { *ssh.Session stdin io.WriteCloser stdout io.ReadCloser stderr io.ReadCloser } func (s *session) run(t *testing.T, cmdString string, shell bool) string { t.Helper() if shell { _, err := s.stdin.Write([]byte(fmt.Sprintf("%s\n", cmdString))) if err != nil { t.Fatalf("unable to send command to shell: %s", err) } } else { err := s.Start(cmdString) if err != nil { t.Fatalf("unable to start command: %s", err) } } return s.read() } func (s *session) read() string { ch := make(chan []byte) go func() { for { b := make([]byte, 1) n, err := s.stdout.Read(b) if n > 0 { ch <- b } if err == io.EOF { return } } }() // Read first byte in blocking fashion. _got := <-ch // Read subsequent bytes in non-blocking fashion. readLoop: for { select { case b := <-ch: _got = append(_got, b...) case <-time.After(1 * time.Second): break readLoop } } return string(_got) } func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMethod) *ssh.Client { t.Helper() username := "testuser" addr := testServer(t, username, forceV1Behavior) cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ HostKeyCallback: ssh.InsecureIgnoreHostKey(), Auth: authMethods, }) if err != nil { t.Fatal(err) } t.Cleanup(func() { cl.Close() }) return cl } func testServer(t *testing.T, username string, forceV1Behavior bool) string { srv := &server{ lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior}, logf: log.Printf, tailscaledPath: os.Getenv("TAILSCALED_PATH"), timeNow: time.Now, } l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } t.Cleanup(func() { l.Close() }) go func() { for { conn, err := l.Accept() if err == nil { go srv.HandleSSHConn(&addressFakingConn{conn}) } } }() return l.Addr().String() } func testSession(t *testing.T, forceV1Behavior bool) *session { cl := testClient(t, forceV1Behavior) return testSessionFor(t, cl) } func testSessionFor(t *testing.T, cl *ssh.Client) *session { s, err := cl.NewSession() if err != nil { t.Fatal(err) } t.Cleanup(func() { s.Close() }) stdinReader, stdinWriter := io.Pipe() stdoutReader, stdoutWriter := io.Pipe() stderrReader, stderrWriter := io.Pipe() s.Stdin = stdinReader s.Stdout = io.MultiWriter(stdoutWriter, os.Stdout) s.Stderr = io.MultiWriter(stderrWriter, os.Stderr) return &session{ Session: s, stdin: stdinWriter, stdout: stdoutReader, stderr: stderrReader, } } func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.PrivateKey) { t.Helper() priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) } mk, err := x509.MarshalPKCS8PrivateKey(priv) if err != nil { t.Fatal(err) } privateKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) if privateKey == nil { t.Fatal("failed to encoded private key") } err = os.WriteFile(privateKeyFile, privateKey, 0600) if err != nil { t.Fatal(err) } signer, err := ssh.ParsePrivateKey(privateKey) if err != nil { t.Fatal(err) } return signer, priv } // testBackend implements ipnLocalBackend type testBackend struct { localUser string forceV1Behavior bool } func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) { var result []gossh.Signer var priv any var err error const keySize = 2048 priv, err = rsa.GenerateKey(rand.Reader, keySize) if err != nil { return nil, err } mk, err := x509.MarshalPKCS8PrivateKey(priv) if err != nil { return nil, err } hostKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) signer, err := gossh.ParsePrivateKey(hostKey) if err != nil { return nil, err } result = append(result, signer) return result, nil } func (tb *testBackend) ShouldRunSSH() bool { return true } func (tb *testBackend) NetMap() *netmap.NetworkMap { capMap := make(set.Set[tailcfg.NodeCapability]) if tb.forceV1Behavior { capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{} } return &netmap.NetworkMap{ SSHPolicy: &tailcfg.SSHPolicy{ Rules: []*tailcfg.SSHRule{ { Principals: []*tailcfg.SSHPrincipal{{Any: true}}, Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true}, SSHUsers: map[string]string{"*": tb.localUser}, }, }, }, AllCaps: capMap, } } func (tb *testBackend) WhoIs(_ string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { return (&tailcfg.Node{}).View(), tailcfg.UserProfile{ LoginName: tb.localUser + "@example.com", }, true } func (tb *testBackend) DoNoiseRequest(req *http.Request) (*http.Response, error) { return nil, nil } func (tb *testBackend) Dialer() *tsdial.Dialer { return nil } func (tb *testBackend) TailscaleVarRoot() string { return "" } func (tb *testBackend) NodeKey() key.NodePublic { return key.NodePublic{} } type addressFakingConn struct { net.Conn } func (conn *addressFakingConn) LocalAddr() net.Addr { return &net.TCPAddr{ IP: net.ParseIP("100.100.100.101"), Port: 22, } } func (conn *addressFakingConn) RemoteAddr() net.Addr { return &net.TCPAddr{ IP: net.ParseIP("100.100.100.102"), Port: 10002, } }