mirror of https://github.com/tailscale/tailscale/
Merge 34dc74f4eb into f36eb81e61
commit
163d50b1e3
@ -0,0 +1,462 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build linux && ts_swtpm
|
||||||
|
|
||||||
|
package tpm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-tpm/tpm2"
|
||||||
|
"github.com/google/go-tpm/tpm2/transport/linuxtpm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// swtpmBinary is the name of the swtpm executable to use
|
||||||
|
const swtpmBinary = "swtpm"
|
||||||
|
const example32ByteKey = "12345678901234567890123456789012"
|
||||||
|
|
||||||
|
type swtpm struct {
|
||||||
|
dataDir string
|
||||||
|
deviceName string
|
||||||
|
devicePath string
|
||||||
|
pidFile string
|
||||||
|
opts *swtpmOptions
|
||||||
|
t testing.TB
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSWTPM(t testing.TB, opts ...swtpmOption) *swtpm {
|
||||||
|
t.Helper()
|
||||||
|
options := &swtpmOptions{
|
||||||
|
version: "2.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
dataDir := t.TempDir()
|
||||||
|
|
||||||
|
suffix := make([]byte, 8)
|
||||||
|
if _, err := crand.Read(suffix); err != nil {
|
||||||
|
t.Fatalf("failed to generate random suffix: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// use unique per-test vtpm device names to avoid conflicts
|
||||||
|
deviceName := fmt.Sprintf("vtpm-%d-%s", time.Now().UnixNano(), hex.EncodeToString(suffix))
|
||||||
|
devicePath := filepath.Join("/dev", deviceName)
|
||||||
|
pidFile := filepath.Join(dataDir, "swtpm.pid")
|
||||||
|
|
||||||
|
s := &swtpm{
|
||||||
|
dataDir: dataDir,
|
||||||
|
deviceName: deviceName,
|
||||||
|
devicePath: devicePath,
|
||||||
|
pidFile: pidFile,
|
||||||
|
opts: options,
|
||||||
|
t: t,
|
||||||
|
}
|
||||||
|
s.t.Logf("created swtpm with device %s, data dir %s", deviceName, dataDir)
|
||||||
|
|
||||||
|
if err := s.start(); err != nil {
|
||||||
|
t.Fatalf("failed to start swtpm: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
s.stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSetup initializes the TPM state using swtpm_setup
|
||||||
|
func (s *swtpm) runSetup() error {
|
||||||
|
args := []string{
|
||||||
|
"--tpmstate", s.dataDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.opts.version {
|
||||||
|
case "1.2":
|
||||||
|
// TPM 1.2 is the default for swtpm_setup, no flag needed
|
||||||
|
case "2.0":
|
||||||
|
args = append(args, "--tpm2")
|
||||||
|
default:
|
||||||
|
s.t.Fatalf("unsupported swtpm version for setup: %q", s.opts.version)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullArgs := append([]string{"swtpm_setup"}, args...)
|
||||||
|
cmd := exec.Command("sudo", fullArgs...)
|
||||||
|
s.t.Logf("running swtpm_setup with args: %v", args)
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("swtpm_setup failed: %w: %s", err, output)
|
||||||
|
} else {
|
||||||
|
s.t.Logf("swtpm_setup output: %s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// start launches the swtpm process with the configured options
|
||||||
|
func (s *swtpm) start() error {
|
||||||
|
// init state if requested
|
||||||
|
if s.opts.withSetup {
|
||||||
|
if err := s.runSetup(); err != nil {
|
||||||
|
return fmt.Errorf("swtpm_setup failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{
|
||||||
|
"cuse",
|
||||||
|
"--tpmstate", fmt.Sprintf("dir=%s", s.dataDir),
|
||||||
|
"--name", s.deviceName,
|
||||||
|
"--pid", fmt.Sprintf("file=%s", s.pidFile),
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.opts.version {
|
||||||
|
case "1.2":
|
||||||
|
case "2.0":
|
||||||
|
args = append(args, "--tpm2")
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported swtpm version: %s", s.opts.version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// when using swtpm_setup, we need to tell swtpm to send TPM2_Startup(CLEAR)
|
||||||
|
if s.opts.withSetup {
|
||||||
|
args = append(args, "--flags", "startup-clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.opts.flags != nil {
|
||||||
|
args = append(args, s.opts.flags...)
|
||||||
|
}
|
||||||
|
|
||||||
|
fullArgs := append([]string{swtpmBinary}, args...)
|
||||||
|
cmd := exec.Command("sudo", fullArgs...)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start swtpm with args %v: %w: %s", args, err, out)
|
||||||
|
} else {
|
||||||
|
s.t.Logf("swtpm started with args %v output: %s", args, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.t.Logf("waiting for swtpm device at %s", s.devicePath)
|
||||||
|
if err := s.waitForDevice(); err != nil {
|
||||||
|
s.stop()
|
||||||
|
return fmt.Errorf("swtpm device not available: %w", err)
|
||||||
|
} else {
|
||||||
|
s.t.Logf("swtpm device available at %s", s.devicePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForDevice waits for the swtpm character device to be created
|
||||||
|
func (s *swtpm) waitForDevice() error {
|
||||||
|
timeout := time.After(5 * time.Second)
|
||||||
|
ticker := time.NewTicker(50 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
return fmt.Errorf("timeout waiting for device at %s", s.devicePath)
|
||||||
|
case <-ticker.C:
|
||||||
|
if _, err := os.Stat(s.devicePath); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop terminates the swtpm process and cleans up resources
|
||||||
|
func (s *swtpm) stop() {
|
||||||
|
pidBytes, err := os.ReadFile(s.pidFile)
|
||||||
|
if err != nil {
|
||||||
|
s.t.Logf("failed to read swtpm pid file %s: %v", s.pidFile, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var pid int
|
||||||
|
pid, err = strconv.Atoi(string(bytes.TrimSpace(pidBytes)))
|
||||||
|
if err != nil {
|
||||||
|
s.t.Logf("failed to parse swtpm pid %q: %v", string(pidBytes), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var process *os.Process
|
||||||
|
process, err = os.FindProcess(pid)
|
||||||
|
if err != nil {
|
||||||
|
s.t.Logf("failed to find swtpm process with pid %d: %v", pid, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := process.Signal(syscall.SIGTERM); err != nil {
|
||||||
|
s.t.Logf("failed to send SIGTERM to swtpm PID %d: %v", pid, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.t.Logf("sent SIGTERM to swtpm PID %d", pid)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := process.Wait()
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
s.t.Logf("swtpm PID %d exited with error: %v", pid, err)
|
||||||
|
} else {
|
||||||
|
s.t.Logf("swtpm PID %d exited after SIGTERM", pid)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
s.t.Fatalf("timed out waiting for process %d to terminate", pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DevicePath returns the path to the swtpm character device
|
||||||
|
func (s *swtpm) DevicePath() string {
|
||||||
|
return s.devicePath
|
||||||
|
}
|
||||||
|
|
||||||
|
type swtpmOptions struct {
|
||||||
|
version string
|
||||||
|
flags []string
|
||||||
|
withSetup bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type swtpmOption func(*swtpmOptions)
|
||||||
|
|
||||||
|
// withSetup enables TPM initialization via swtpm_setup
|
||||||
|
func withSetup() swtpmOption {
|
||||||
|
return func(o *swtpmOptions) {
|
||||||
|
o.withSetup = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// withSwtpmVersion sets the TPM version (either "1.2" or "2.0")
|
||||||
|
func withSwtpmVersion(version string) swtpmOption {
|
||||||
|
return func(o *swtpmOptions) {
|
||||||
|
o.version = version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTPM12 configures swtpm to use TPM 1.2
|
||||||
|
func withTPM12() swtpmOption {
|
||||||
|
return withSwtpmVersion("1.2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTPM20 configures swtpm to use TPM 2.0 (default)
|
||||||
|
func withTPM20() swtpmOption {
|
||||||
|
return withSwtpmVersion("2.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSWTPMAvailable checks if swtpm is available and errors if not
|
||||||
|
func checkSWTPMAvailable(t testing.TB) {
|
||||||
|
t.Helper()
|
||||||
|
p, err := exec.LookPath(swtpmBinary)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("swtpm binary not found in PATH: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure version 0.10.1
|
||||||
|
cmd := exec.Command(p, "--version")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to execute swtpm --version: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(output, []byte("TPM emulator version 0.10.1")) {
|
||||||
|
t.Fatalf("swtpm version is not compatible: %s", output)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure that we can run swtpm with sudo non-interactive (necessary to create CUSE devices)
|
||||||
|
cmd = exec.Command("sudo", "-n", p, "cuse", "--help")
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
t.Fatalf("swtpm cannot be run with sudo without password prompt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSWTPM_Integration(t *testing.T) {
|
||||||
|
checkSWTPMAvailable(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
opts []swtpmOption
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "broken-1.2-no-setup",
|
||||||
|
opts: []swtpmOption{withTPM12()},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "broken-2.0-no-setup",
|
||||||
|
opts: []swtpmOption{withTPM20()},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "working-2.0",
|
||||||
|
opts: []swtpmOption{withTPM20(), withSetup()},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
swtpm := newSWTPM(t, tt.opts...)
|
||||||
|
devicePath := swtpm.DevicePath()
|
||||||
|
|
||||||
|
if _, err := os.Stat(devicePath); err != nil {
|
||||||
|
t.Fatalf("swtpm device does not exist at %s: %v", devicePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tpmDev, err := linuxtpm.Open(devicePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("linuxtpm.Open(%s) failed: %v", devicePath, err)
|
||||||
|
}
|
||||||
|
defer tpmDev.Close()
|
||||||
|
|
||||||
|
err = withSRK(t.Logf, tpmDev, func(srk tpm2.AuthHandle) error {
|
||||||
|
t.Logf("Successfully loaded SRK with handle: %v", srk.Handle)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if tt.wantErr != (err != nil) {
|
||||||
|
t.Errorf("withSRK() error = %v, wantErr = %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSWTPM_SealUnseal(t *testing.T) {
|
||||||
|
checkSWTPMAvailable(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
opts []swtpmOption
|
||||||
|
data []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1.2-fail-no-setup",
|
||||||
|
opts: []swtpmOption{withTPM12()},
|
||||||
|
data: []byte(example32ByteKey),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1.2-fail-32-byte-key",
|
||||||
|
opts: []swtpmOption{withTPM12(), withSetup()},
|
||||||
|
data: []byte(example32ByteKey),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2.0-seal-fail-no-setup",
|
||||||
|
opts: []swtpmOption{withTPM20()},
|
||||||
|
data: []byte("test data"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2.0-seal-unseal-32-byte-key",
|
||||||
|
opts: []swtpmOption{withTPM20(), withSetup()},
|
||||||
|
data: []byte(example32ByteKey),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
swtpm := newSWTPM(t, tt.opts...)
|
||||||
|
devicePath := swtpm.DevicePath()
|
||||||
|
|
||||||
|
tpmDev, err := linuxtpm.Open(devicePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("linuxtpm.Open(%s) failed: %v", devicePath, err)
|
||||||
|
}
|
||||||
|
defer tpmDev.Close()
|
||||||
|
|
||||||
|
sealed, err := tpmSealWithTPM(t.Logf, tpmDev, tt.data)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("tpmSealWithTPM() expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tpmSealWithTPM() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sealed == nil {
|
||||||
|
t.Fatal("tpmSealWithTPM() returned nil sealed data")
|
||||||
|
}
|
||||||
|
if len(sealed.Private) == 0 {
|
||||||
|
t.Error("sealed.Private is empty")
|
||||||
|
}
|
||||||
|
if len(sealed.Public) == 0 {
|
||||||
|
t.Error("sealed.Public is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
unsealed, err := tpmUnsealWithTPM(t.Logf, tpmDev, sealed)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tpmUnsealWithTPM() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(unsealed, tt.data) {
|
||||||
|
t.Errorf("unsealed data mismatch:\ngot: %q\nwant: %q", unsealed, tt.data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSWTPM_SealUnsealCrossDevice(t *testing.T) {
|
||||||
|
checkSWTPMAvailable(t)
|
||||||
|
|
||||||
|
swtpm1 := newSWTPM(t, withTPM20(), withSetup())
|
||||||
|
tpmDev1, err := linuxtpm.Open(swtpm1.DevicePath())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("linuxtpm.Open(%s) failed: %v", swtpm1.DevicePath(), err)
|
||||||
|
}
|
||||||
|
defer tpmDev1.Close()
|
||||||
|
|
||||||
|
logf := func(format string, args ...any) {
|
||||||
|
t.Logf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
testData := []byte("TPM1 secret data")
|
||||||
|
sealed, err := tpmSealWithTPM(logf, tpmDev1, testData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tpmSealWithTPM() on first device failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// round trip on the same TPM
|
||||||
|
unsealed, err := tpmUnsealWithTPM(logf, tpmDev1, sealed)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tpmUnsealWithTPM() on first device failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(unsealed, testData) {
|
||||||
|
t.Errorf("unsealed data mismatch on first device:\ngot: %q\nwant: %q", unsealed, testData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a second device
|
||||||
|
swtpm2 := newSWTPM(t, withTPM20(), withSetup())
|
||||||
|
tpmDev2, err := linuxtpm.Open(swtpm2.DevicePath())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("linuxtpm.Open(%s) failed: %v", swtpm2.DevicePath(), err)
|
||||||
|
}
|
||||||
|
defer tpmDev2.Close()
|
||||||
|
|
||||||
|
// confirm we cannot unseal with the second TPM
|
||||||
|
_, err = tpmUnsealWithTPM(logf, tpmDev2, sealed)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("tpmUnsealWithTPM() on second device should have failed but succeeded")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue