diff --git a/ssh/tailssh/auditd_linux.go b/ssh/tailssh/auditd_linux.go
new file mode 100644
index 000000000..e9f551d9e
--- /dev/null
+++ b/ssh/tailssh/auditd_linux.go
@@ -0,0 +1,186 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux && !android
+
+package tailssh
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"os"
+	"syscall"
+
+	"golang.org/x/sys/unix"
+	"tailscale.com/types/logger"
+)
+
+const (
+	auditUserLogin = 1112 // audit message type for user login (from linux/audit.h)
+	netlinkAudit   = 9    // AF_NETLINK protocol number for audit (from linux/netlink.h)
+	nlmFRequest    = 0x01 // netlink message flag: request (from linux/netlink.h)
+
+	// maxAuditMessageLength is the maximum length of an audit message payload.
+	// This is derived from MAX_AUDIT_MESSAGE_LENGTH (8970) in the Linux kernel
+	// (linux/audit.h), minus overhead for the netlink header and safety margin.
+	maxAuditMessageLength = 8192
+)
+
+// hasAuditWriteCap checks if the process has CAP_AUDIT_WRITE in its effective capability set.
+// It returns false if the capability set cannot be read.
+func hasAuditWriteCap() bool {
+	var hdr unix.CapUserHeader
+	var data [2]unix.CapUserData
+
+	hdr.Version = unix.LINUX_CAPABILITY_VERSION_3
+	hdr.Pid = int32(os.Getpid())
+
+	if err := unix.Capget(&hdr, &data[0]); err != nil {
+		return false
+	}
+
+	const capBit = uint32(1 << (unix.CAP_AUDIT_WRITE % 32))
+	const capIdx = unix.CAP_AUDIT_WRITE / 32
+	return (data[capIdx].Effective & capBit) != 0
+}
+
+// buildAuditNetlinkMessage constructs a netlink audit message.
+// The payload is truncated to maxAuditMessageLength bytes and the returned
+// buffer is zero-padded to a multiple of syscall.NLMSG_ALIGNTO.
+// This is separated from sendAuditMessage to allow testing the message format
+// without requiring CAP_AUDIT_WRITE or a netlink socket.
+func buildAuditNetlinkMessage(msgType uint16, message string) ([]byte, error) {
+	msgBytes := []byte(message)
+	if len(msgBytes) > maxAuditMessageLength {
+		msgBytes = msgBytes[:maxAuditMessageLength]
+	}
+	msgLen := len(msgBytes)
+
+	totalLen := syscall.NLMSG_HDRLEN + msgLen
+	alignedLen := (totalLen + syscall.NLMSG_ALIGNTO - 1) & ^(syscall.NLMSG_ALIGNTO - 1)
+
+	nlh := syscall.NlMsghdr{
+		Len:   uint32(totalLen),
+		Type:  msgType,
+		Flags: nlmFRequest,
+		Seq:   1,
+		Pid:   uint32(os.Getpid()),
+	}
+
+	buf := bytes.NewBuffer(make([]byte, 0, alignedLen))
+	if err := binary.Write(buf, binary.NativeEndian, nlh); err != nil {
+		return nil, err
+	}
+	buf.Write(msgBytes)
+
+	for buf.Len() < alignedLen {
+		buf.WriteByte(0)
+	}
+
+	return buf.Bytes(), nil
+}
+
+// sendAuditMessage sends a message to the audit subsystem using raw netlink.
+// It is a no-op if the process lacks CAP_AUDIT_WRITE.
+// It logs errors but does not return them.
+func sendAuditMessage(logf logger.Logf, msgType uint16, message string) {
+	if !hasAuditWriteCap() {
+		return
+	}
+
+	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, netlinkAudit)
+	if err != nil {
+		logf("auditd: failed to create netlink socket: %v", err)
+		return
+	}
+	defer syscall.Close(fd)
+
+	// Bind with Pid 0 so that the kernel assigns a unique port ID. Binding to
+	// os.Getpid() would fail (EADDRINUSE) whenever another audit netlink socket
+	// in this process already holds that port ID, e.g. when two SSH logins are
+	// being audited concurrently.
+	bindAddr := &syscall.SockaddrNetlink{
+		Family: syscall.AF_NETLINK,
+		Pid:    0,
+		Groups: 0,
+	}
+
+	if err := syscall.Bind(fd, bindAddr); err != nil {
+		logf("auditd: failed to bind netlink socket: %v", err)
+		return
+	}
+
+	// Pid 0 addresses the kernel.
+	kernelAddr := &syscall.SockaddrNetlink{
+		Family: syscall.AF_NETLINK,
+		Pid:    0,
+		Groups: 0,
+	}
+
+	msgBytes, err := buildAuditNetlinkMessage(msgType, message)
+	if err != nil {
+		logf("auditd: failed to build audit message: %v", err)
+		return
+	}
+
+	if err := syscall.Sendto(fd, msgBytes, 0, kernelAddr); err != nil {
+		logf("auditd: failed to send audit message: %v", err)
+		return
+	}
+}
+
+// logSSHLogin logs an SSH login event to auditd with whois information.
+// It does nothing if c, c.info, or c.localUser is nil.
+func logSSHLogin(logf logger.Logf, c *conn) {
+	if c == nil || c.info == nil || c.localUser == nil {
+		return
+	}
+
+	exePath := c.srv.tailscaledPath
+	if exePath == "" {
+		exePath = "tailscaled"
+	}
+
+	srcIP := c.info.src.Addr().String()
+	srcPort := c.info.src.Port()
+	dstIP := c.info.dst.Addr().String()
+	dstPort := c.info.dst.Port()
+
+	tailscaleUser := c.info.uprof.LoginName
+	tailscaleUserID := c.info.uprof.ID
+	tailscaleDisplayName := c.info.uprof.DisplayName
+	nodeName := c.info.node.Name()
+	nodeID := c.info.node.ID()
+
+	localUser := c.localUser.Username
+	localUID := c.localUser.Uid
+	localGID := c.localUser.Gid
+
+	hostname, err := os.Hostname()
+	if err != nil {
+		hostname = "unknown"
+	}
+
+	// use principally the same format as ssh / PAM, which come from the audit userspace, i.e.
+	// https://github.com/linux-audit/audit-userspace/blob/b6f8c208435038df113a9795e3e202720aee6b70/lib/audit_logging.c#L515
+	msg := fmt.Sprintf(
+		"op=login acct=%s uid=%s gid=%s "+
+			"src=%s src_port=%d dst=%s dst_port=%d "+
+			"hostname=%q exe=%q terminal=ssh res=success "+
+			"ts_user=%q ts_user_id=%d ts_display_name=%q ts_node=%q ts_node_id=%d",
+		localUser, localUID, localGID,
+		srcIP, srcPort, dstIP, dstPort,
+		hostname, exePath,
+		tailscaleUser, tailscaleUserID, tailscaleDisplayName, nodeName, nodeID,
+	)
+
+	sendAuditMessage(logf, auditUserLogin, msg)
+
+	logf("audit: SSH login: user=%s uid=%s from=%s ts_user=%s node=%s",
+		localUser, localUID, srcIP, tailscaleUser, nodeName)
+}
+
+func init() {
+	hookSSHLoginSuccess.Set(logSSHLogin)
+}
diff --git a/ssh/tailssh/auditd_linux_test.go b/ssh/tailssh/auditd_linux_test.go
new file mode 100644
index 000000000..93f544291
--- /dev/null
+++ b/ssh/tailssh/auditd_linux_test.go
@@ -0,0 +1,190 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux && !android
+
+package tailssh
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/binary"
+	"fmt"
+	"os"
+	"os/exec"
+	"strings"
+	"syscall"
+	"testing"
+	"time"
+)
+
+// maybeWithSudo returns a command with context that may be prefixed with sudo if not running as root.
+func maybeWithSudo(ctx context.Context, name string, args ...string) *exec.Cmd {
+	if os.Geteuid() == 0 {
+		return exec.CommandContext(ctx, name, args...)
+	}
+	sudoArgs := append([]string{name}, args...)
+	return exec.CommandContext(ctx, "sudo", sudoArgs...)
+}
+
+// TestBuildAuditNetlinkMessage verifies the netlink header fields, alignment,
+// and payload (including truncation) of messages from buildAuditNetlinkMessage.
+func TestBuildAuditNetlinkMessage(t *testing.T) {
+	testCases := []struct {
+		name     string
+		msgType  uint16
+		message  string
+		wantType uint16
+	}{
+		{
+			name:     "simple-message",
+			msgType:  auditUserLogin,
+			message:  "op=login acct=test",
+			wantType: auditUserLogin,
+		},
+		{
+			name:     "message-with-quoted-fields",
+			msgType:  auditUserLogin,
+			message:  `op=login hostname="test-host" exe="/usr/bin/tailscaled" ts_user="user@example.com" ts_node="node.tail-scale.ts.net"`,
+			wantType: auditUserLogin,
+		},
+		{
+			name:     "message-with-special-chars",
+			msgType:  auditUserLogin,
+			message:  `op=login hostname="host with spaces" ts_user="user name@example.com" ts_display_name="User \"Quote\" Name"`,
+			wantType: auditUserLogin,
+		},
+		{
+			// Longer than maxAuditMessageLength so that truncation is exercised.
+			name:     "long-message-truncated",
+			msgType:  auditUserLogin,
+			message:  strings.Repeat("a", maxAuditMessageLength+100),
+			wantType: auditUserLogin,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			msg, err := buildAuditNetlinkMessage(tc.msgType, tc.message)
+			if err != nil {
+				t.Fatalf("buildAuditNetlinkMessage failed: %v", err)
+			}
+
+			if len(msg) < syscall.NLMSG_HDRLEN {
+				t.Fatalf("message too short: got %d bytes, want at least %d", len(msg), syscall.NLMSG_HDRLEN)
+			}
+
+			var nlh syscall.NlMsghdr
+			buf := bytes.NewReader(msg[:syscall.NLMSG_HDRLEN])
+			if err := binary.Read(buf, binary.NativeEndian, &nlh); err != nil {
+				t.Fatalf("failed to parse netlink header: %v", err)
+			}
+
+			if nlh.Type != tc.wantType {
+				t.Errorf("message type: got %d, want %d", nlh.Type, tc.wantType)
+			}
+
+			if nlh.Flags != nlmFRequest {
+				t.Errorf("flags: got 0x%x, want 0x%x", nlh.Flags, nlmFRequest)
+			}
+
+			if len(msg)%syscall.NLMSG_ALIGNTO != 0 {
+				t.Errorf("message not aligned: len=%d, alignment=%d", len(msg), syscall.NLMSG_ALIGNTO)
+			}
+
+			payloadLen := int(nlh.Len) - syscall.NLMSG_HDRLEN
+			if payloadLen < 0 {
+				t.Fatalf("invalid payload length: %d", payloadLen)
+			}
+
+			payload := msg[syscall.NLMSG_HDRLEN : syscall.NLMSG_HDRLEN+payloadLen]
+
+			expectedMsg := tc.message
+			if len(expectedMsg) > maxAuditMessageLength {
+				expectedMsg = expectedMsg[:maxAuditMessageLength]
+			}
+			if string(payload) != expectedMsg {
+				t.Errorf("payload mismatch:\ngot: %q\nwant: %q", string(payload), expectedMsg)
+			}
+
+			expectedLen := syscall.NLMSG_HDRLEN + len(payload)
+			if int(nlh.Len) != expectedLen {
+				t.Errorf("length field: got %d, want %d", nlh.Len, expectedLen)
+			}
+		})
+	}
+}
+
+// TestAuditIntegration sends a real audit message and looks for it in the
+// journald audit log. It is skipped unless CAP_AUDIT_WRITE, journalctl, and
+// journald's audit transport are all available.
+func TestAuditIntegration(t *testing.T) {
+	if !hasAuditWriteCap() {
+		t.Skip("skipping: CAP_AUDIT_WRITE not in effective capability set")
+	}
+
+	if _, err := exec.LookPath("journalctl"); err != nil {
+		t.Skip("skipping: journalctl not available")
+	}
+
+	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+	defer cancel()
+
+	checkCmd := maybeWithSudo(ctx, "journalctl", "--field", "_TRANSPORT")
+	var out bytes.Buffer
+	checkCmd.Stdout = &out
+	if err := checkCmd.Run(); err != nil {
+		t.Skipf("skipping: cannot query journalctl transports: %v", err)
+	}
+	if !strings.Contains(out.String(), "audit") {
+		t.Skip("skipping: journald not configured for audit messages, try: systemctl enable systemd-journald-audit.socket && systemctl restart systemd-journald")
+	}
+
+	testID := fmt.Sprintf("tailscale-test-%d", time.Now().UnixNano())
+	testMsg := fmt.Sprintf("op=test-audit test_id=%s res=success", testID)
+
+	followCmd := maybeWithSudo(ctx, "journalctl", "-f", "_TRANSPORT=audit", "--no-pager")
+
+	stdout, err := followCmd.StdoutPipe()
+	if err != nil {
+		t.Fatalf("failed to get stdout pipe: %v", err)
+	}
+
+	if err := followCmd.Start(); err != nil {
+		t.Fatalf("failed to start journalctl: %v", err)
+	}
+	defer func() {
+		// Kill journalctl and reap it so the test does not leave a zombie process behind.
+		followCmd.Process.Kill()
+		followCmd.Wait()
+	}()
+
+	testLogf := func(format string, args ...any) {
+		t.Logf(format, args...)
+	}
+	sendAuditMessage(testLogf, auditUserLogin, testMsg)
+
+	bs := bufio.NewScanner(stdout)
+	found := false
+	for bs.Scan() {
+		line := bs.Text()
+		if strings.Contains(line, testID) {
+			t.Logf("found audit log entry: %s", line)
+			found = true
+			break
+		}
+	}
+
+	if err := bs.Err(); err != nil && ctx.Err() == nil {
+		t.Fatalf("error reading journalctl output: %v", err)
+	}
+
+	if !found {
+		if ctx.Err() == context.DeadlineExceeded {
+			t.Errorf("timeout waiting for audit message with test_id=%s", testID)
+		} else {
+			t.Errorf("audit message with test_id=%s not found in journald audit log", testID)
+		}
+	}
+}
diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go
index 7d12ab45f..91e1779bf 100644
--- a/ssh/tailssh/tailssh.go
+++ b/ssh/tailssh/tailssh.go
@@ -31,6 +31,7 @@ import (
 
 	gossh "golang.org/x/crypto/ssh"
 	"tailscale.com/envknob"
+	"tailscale.com/feature"
 	"tailscale.com/ipn/ipnlocal"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/net/tsdial"
@@ -56,6 +57,10 @@ var (
 	// authentication methods that may proceed), which results in the SSH
 	// server immediately disconnecting the client.
 	errTerminal = &gossh.PartialSuccessError{}
+
+	// hookSSHLoginSuccess is called after successful SSH authentication.
+	// It is set by platform-specific code (e.g., auditd_linux.go).
+	hookSSHLoginSuccess feature.Hook[func(logf logger.Logf, c *conn)]
 )
 
 const (
@@ -647,6 +652,11 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
 	ss := c.newSSHSession(s)
 	ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username)
 	ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username)
+
+	if f, ok := hookSSHLoginSuccess.GetOk(); ok {
+		f(c.srv.logf, c)
+	}
+
 	ss.run()
 }
 