mirror of https://github.com/tailscale/tailscale/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
581 lines
13 KiB
Go
581 lines
13 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package ipnext
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"testing"
|
|
|
|
"tailscale.com/ipn"
|
|
"tailscale.com/tsd"
|
|
"tailscale.com/tstime"
|
|
"tailscale.com/types/logger"
|
|
)
|
|
|
|
// mockExtension implements Extension for testing
|
|
type mockExtension struct {
|
|
name string
|
|
initErr error
|
|
shutdownErr error
|
|
initCalled bool
|
|
shutdownCalled bool
|
|
}
|
|
|
|
func (m *mockExtension) Name() string { return m.name }
|
|
|
|
func (m *mockExtension) Init(Host) error {
|
|
m.initCalled = true
|
|
return m.initErr
|
|
}
|
|
|
|
func (m *mockExtension) Shutdown() error {
|
|
m.shutdownCalled = true
|
|
return m.shutdownErr
|
|
}
|
|
|
|
// mockSafeBackend implements SafeBackend for testing
|
|
type mockSafeBackend struct{}
|
|
|
|
func (m *mockSafeBackend) Sys() *tsd.System { return nil }
|
|
func (m *mockSafeBackend) Clock() tstime.Clock { return nil }
|
|
func (m *mockSafeBackend) TailscaleVarRoot() string { return "/tmp" }
|
|
|
|
// TestDefinition_Name tests Definition.Name()
|
|
func TestDefinition_Name(t *testing.T) {
|
|
d := &Definition{name: "test-extension"}
|
|
if got := d.Name(); got != "test-extension" {
|
|
t.Errorf("Name() = %q, want %q", got, "test-extension")
|
|
}
|
|
}
|
|
|
|
// TestDefinition_MakeExtension tests successful extension creation
|
|
func TestDefinition_MakeExtension(t *testing.T) {
|
|
ext := &mockExtension{name: "test"}
|
|
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return ext, nil
|
|
}
|
|
|
|
d := &Definition{
|
|
name: "test",
|
|
newFn: newFn,
|
|
}
|
|
|
|
logf := logger.Discard
|
|
sb := &mockSafeBackend{}
|
|
|
|
got, err := d.MakeExtension(logf, sb)
|
|
if err != nil {
|
|
t.Fatalf("MakeExtension() error = %v", err)
|
|
}
|
|
|
|
if got != ext {
|
|
t.Error("MakeExtension() returned wrong extension")
|
|
}
|
|
}
|
|
|
|
// TestDefinition_MakeExtension_NameMismatch tests name validation
|
|
func TestDefinition_MakeExtension_NameMismatch(t *testing.T) {
|
|
ext := &mockExtension{name: "wrong-name"}
|
|
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return ext, nil
|
|
}
|
|
|
|
d := &Definition{
|
|
name: "expected-name",
|
|
newFn: newFn,
|
|
}
|
|
|
|
logf := logger.Discard
|
|
sb := &mockSafeBackend{}
|
|
|
|
_, err := d.MakeExtension(logf, sb)
|
|
if err == nil {
|
|
t.Fatal("MakeExtension() should error on name mismatch")
|
|
}
|
|
|
|
wantErr := `extension name mismatch: registered "expected-name"; actual "wrong-name"`
|
|
if err.Error() != wantErr {
|
|
t.Errorf("error = %q, want %q", err.Error(), wantErr)
|
|
}
|
|
}
|
|
|
|
// TestDefinition_MakeExtension_NewFnError tests error propagation
|
|
func TestDefinition_MakeExtension_NewFnError(t *testing.T) {
|
|
expectedErr := errors.New("creation failed")
|
|
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return nil, expectedErr
|
|
}
|
|
|
|
d := &Definition{
|
|
name: "test",
|
|
newFn: newFn,
|
|
}
|
|
|
|
logf := logger.Discard
|
|
sb := &mockSafeBackend{}
|
|
|
|
_, err := d.MakeExtension(logf, sb)
|
|
if !errors.Is(err, expectedErr) {
|
|
t.Errorf("MakeExtension() error = %v, want %v", err, expectedErr)
|
|
}
|
|
}
|
|
|
|
// TestRegisterExtension_Panic_NilFunc tests nil function panic
|
|
func TestRegisterExtension_Panic_NilFunc(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Error("RegisterExtension() should panic with nil function")
|
|
} else {
|
|
got := fmt.Sprint(r)
|
|
want := `ipnext: newExt is nil: "test"`
|
|
if got != want {
|
|
t.Errorf("panic message = %q, want %q", got, want)
|
|
}
|
|
}
|
|
// Reset extensions map after test
|
|
extensions = extensions[:0]
|
|
}()
|
|
|
|
RegisterExtension("test", nil)
|
|
}
|
|
|
|
// TestRegisterExtension_Panic_Duplicate tests duplicate name panic
|
|
func TestRegisterExtension_Panic_Duplicate(t *testing.T) {
|
|
defer func() {
|
|
// Reset extensions map after test
|
|
extensions = extensions[:0]
|
|
}()
|
|
|
|
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return &mockExtension{name: "test"}, nil
|
|
}
|
|
|
|
// First registration should succeed
|
|
RegisterExtension("test", newFn)
|
|
|
|
// Second registration should panic
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Error("RegisterExtension() should panic on duplicate")
|
|
} else {
|
|
got := fmt.Sprint(r)
|
|
want := `ipnext: duplicate extension name "test"`
|
|
if got != want {
|
|
t.Errorf("panic message = %q, want %q", got, want)
|
|
}
|
|
}
|
|
}()
|
|
|
|
RegisterExtension("test", newFn)
|
|
}
|
|
|
|
// TestRegisterExtension_Success tests successful registration
|
|
func TestRegisterExtension_Success(t *testing.T) {
|
|
defer func() {
|
|
extensions = extensions[:0]
|
|
}()
|
|
|
|
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return &mockExtension{name: "test"}, nil
|
|
}
|
|
|
|
RegisterExtension("test", newFn)
|
|
|
|
if !extensions.Contains("test") {
|
|
t.Error("extension not registered")
|
|
}
|
|
|
|
def, ok := extensions.Get("test")
|
|
if !ok {
|
|
t.Fatal("failed to get registered extension")
|
|
}
|
|
|
|
if def.name != "test" {
|
|
t.Errorf("registered name = %q, want %q", def.name, "test")
|
|
}
|
|
}
|
|
|
|
// TestExtensions_Iterator tests Extensions() iteration
|
|
func TestExtensions_Iterator(t *testing.T) {
|
|
defer func() {
|
|
extensions = extensions[:0]
|
|
}()
|
|
|
|
newFn := func(name string) NewExtensionFn {
|
|
return func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return &mockExtension{name: name}, nil
|
|
}
|
|
}
|
|
|
|
RegisterExtension("ext1", newFn("ext1"))
|
|
RegisterExtension("ext2", newFn("ext2"))
|
|
RegisterExtension("ext3", newFn("ext3"))
|
|
|
|
count := 0
|
|
seen := make(map[string]bool)
|
|
|
|
for def := range Extensions() {
|
|
count++
|
|
seen[def.name] = true
|
|
}
|
|
|
|
if count != 3 {
|
|
t.Errorf("Extensions() count = %d, want 3", count)
|
|
}
|
|
|
|
for _, name := range []string{"ext1", "ext2", "ext3"} {
|
|
if !seen[name] {
|
|
t.Errorf("extension %q not seen in iteration", name)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestExtensions_Order tests iteration order preservation
|
|
func TestExtensions_Order(t *testing.T) {
|
|
defer func() {
|
|
extensions = extensions[:0]
|
|
}()
|
|
|
|
newFn := func(name string) NewExtensionFn {
|
|
return func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return &mockExtension{name: name}, nil
|
|
}
|
|
}
|
|
|
|
RegisterExtension("first", newFn("first"))
|
|
RegisterExtension("second", newFn("second"))
|
|
RegisterExtension("third", newFn("third"))
|
|
|
|
var order []string
|
|
for def := range Extensions() {
|
|
order = append(order, def.name)
|
|
}
|
|
|
|
want := []string{"first", "second", "third"}
|
|
if len(order) != len(want) {
|
|
t.Fatalf("order length = %d, want %d", len(order), len(want))
|
|
}
|
|
|
|
for i, name := range want {
|
|
if order[i] != name {
|
|
t.Errorf("order[%d] = %q, want %q", i, order[i], name)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestDefinitionForTest tests test helper
|
|
func TestDefinitionForTest(t *testing.T) {
|
|
ext := &mockExtension{name: "test-ext"}
|
|
def := DefinitionForTest(ext)
|
|
|
|
if def.name != "test-ext" {
|
|
t.Errorf("name = %q, want %q", def.name, "test-ext")
|
|
}
|
|
|
|
logf := logger.Discard
|
|
sb := &mockSafeBackend{}
|
|
|
|
got, err := def.MakeExtension(logf, sb)
|
|
if err != nil {
|
|
t.Fatalf("MakeExtension() error = %v", err)
|
|
}
|
|
|
|
if got != ext {
|
|
t.Error("MakeExtension() returned wrong extension")
|
|
}
|
|
}
|
|
|
|
// TestDefinitionWithErrForTest tests error test helper
|
|
func TestDefinitionWithErrForTest(t *testing.T) {
|
|
expectedErr := errors.New("test error")
|
|
def := DefinitionWithErrForTest("error-ext", expectedErr)
|
|
|
|
if def.name != "error-ext" {
|
|
t.Errorf("name = %q, want %q", def.name, "error-ext")
|
|
}
|
|
|
|
logf := logger.Discard
|
|
sb := &mockSafeBackend{}
|
|
|
|
_, err := def.MakeExtension(logf, sb)
|
|
if !errors.Is(err, expectedErr) {
|
|
t.Errorf("MakeExtension() error = %v, want %v", err, expectedErr)
|
|
}
|
|
}
|
|
|
|
// TestSkipExtension_Error tests SkipExtension error
|
|
func TestSkipExtension_Error(t *testing.T) {
|
|
if SkipExtension == nil {
|
|
t.Fatal("SkipExtension should not be nil")
|
|
}
|
|
|
|
want := "skipping extension"
|
|
if SkipExtension.Error() != want {
|
|
t.Errorf("SkipExtension.Error() = %q, want %q", SkipExtension.Error(), want)
|
|
}
|
|
}
|
|
|
|
// TestSkipExtension_Wrapped tests wrapped SkipExtension
|
|
func TestSkipExtension_Wrapped(t *testing.T) {
|
|
wrapped := fmt.Errorf("platform not supported: %w", SkipExtension)
|
|
|
|
if !errors.Is(wrapped, SkipExtension) {
|
|
t.Error("wrapped error should be SkipExtension")
|
|
}
|
|
}
|
|
|
|
// TestMockExtension_Interface tests mock implements Extension
|
|
func TestMockExtension_Interface(t *testing.T) {
|
|
var _ Extension = (*mockExtension)(nil)
|
|
}
|
|
|
|
// TestMockExtension_Init tests Init tracking
|
|
func TestMockExtension_Init(t *testing.T) {
|
|
ext := &mockExtension{name: "test"}
|
|
|
|
if ext.initCalled {
|
|
t.Error("initCalled should be false initially")
|
|
}
|
|
|
|
err := ext.Init(nil)
|
|
if err != nil {
|
|
t.Errorf("Init() error = %v", err)
|
|
}
|
|
|
|
if !ext.initCalled {
|
|
t.Error("initCalled should be true after Init()")
|
|
}
|
|
}
|
|
|
|
// TestMockExtension_InitError tests Init error
|
|
func TestMockExtension_InitError(t *testing.T) {
|
|
expectedErr := errors.New("init failed")
|
|
ext := &mockExtension{
|
|
name: "test",
|
|
initErr: expectedErr,
|
|
}
|
|
|
|
err := ext.Init(nil)
|
|
if !errors.Is(err, expectedErr) {
|
|
t.Errorf("Init() error = %v, want %v", err, expectedErr)
|
|
}
|
|
|
|
if !ext.initCalled {
|
|
t.Error("initCalled should be true even on error")
|
|
}
|
|
}
|
|
|
|
// TestMockExtension_Shutdown tests Shutdown tracking
|
|
func TestMockExtension_Shutdown(t *testing.T) {
|
|
ext := &mockExtension{name: "test"}
|
|
|
|
if ext.shutdownCalled {
|
|
t.Error("shutdownCalled should be false initially")
|
|
}
|
|
|
|
err := ext.Shutdown()
|
|
if err != nil {
|
|
t.Errorf("Shutdown() error = %v", err)
|
|
}
|
|
|
|
if !ext.shutdownCalled {
|
|
t.Error("shutdownCalled should be true after Shutdown()")
|
|
}
|
|
}
|
|
|
|
// TestMockExtension_ShutdownError tests Shutdown error
|
|
func TestMockExtension_ShutdownError(t *testing.T) {
|
|
expectedErr := errors.New("shutdown failed")
|
|
ext := &mockExtension{
|
|
name: "test",
|
|
shutdownErr: expectedErr,
|
|
}
|
|
|
|
err := ext.Shutdown()
|
|
if !errors.Is(err, expectedErr) {
|
|
t.Errorf("Shutdown() error = %v, want %v", err, expectedErr)
|
|
}
|
|
|
|
if !ext.shutdownCalled {
|
|
t.Error("shutdownCalled should be true even on error")
|
|
}
|
|
}
|
|
|
|
// TestMockSafeBackend_Interface tests mock implements SafeBackend
|
|
func TestMockSafeBackend_Interface(t *testing.T) {
|
|
var _ SafeBackend = (*mockSafeBackend)(nil)
|
|
}
|
|
|
|
// TestMockSafeBackend_Methods tests SafeBackend methods
|
|
func TestMockSafeBackend_Methods(t *testing.T) {
|
|
sb := &mockSafeBackend{}
|
|
|
|
if sb.Sys() != nil {
|
|
t.Error("Sys() should return nil")
|
|
}
|
|
|
|
if sb.Clock() != nil {
|
|
t.Error("Clock() should return nil")
|
|
}
|
|
|
|
if sb.TailscaleVarRoot() != "/tmp" {
|
|
t.Errorf("TailscaleVarRoot() = %q, want /tmp", sb.TailscaleVarRoot())
|
|
}
|
|
}
|
|
|
|
// TestHooks_ZeroValue tests Hooks zero value
|
|
func TestHooks_ZeroValue(t *testing.T) {
|
|
var h Hooks
|
|
|
|
// Verify all hooks are zero-valued and usable
|
|
_ = h.BackendStateChange
|
|
_ = h.ProfileStateChange
|
|
_ = h.BackgroundProfileResolvers
|
|
_ = h.AuditLoggers
|
|
_ = h.NewControlClient
|
|
_ = h.OnSelfChange
|
|
_ = h.MutateNotifyLocked
|
|
_ = h.SetPeerStatus
|
|
_ = h.ShouldUploadServices
|
|
}
|
|
|
|
// TestProfileStateChangeCallback_Type tests callback signature
|
|
func TestProfileStateChangeCallback_Type(t *testing.T) {
|
|
var callback ProfileStateChangeCallback = func(p ipn.LoginProfileView, pr ipn.PrefsView, sameNode bool) {
|
|
// Callback implementation
|
|
_ = p
|
|
_ = pr
|
|
_ = sameNode
|
|
}
|
|
|
|
if callback == nil {
|
|
t.Error("callback should not be nil")
|
|
}
|
|
|
|
// Test calling the callback
|
|
callback(ipn.LoginProfileView{}, ipn.PrefsView{}, true)
|
|
}
|
|
|
|
// TestNewExtensionFn_Type tests function type
|
|
func TestNewExtensionFn_Type(t *testing.T) {
|
|
var fn NewExtensionFn = func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return &mockExtension{name: "test"}, nil
|
|
}
|
|
|
|
if fn == nil {
|
|
t.Error("fn should not be nil")
|
|
}
|
|
|
|
ext, err := fn(logger.Discard, &mockSafeBackend{})
|
|
if err != nil {
|
|
t.Fatalf("fn() error = %v", err)
|
|
}
|
|
|
|
if ext.Name() != "test" {
|
|
t.Errorf("extension name = %q, want %q", ext.Name(), "test")
|
|
}
|
|
}
|
|
|
|
// TestAuditLogProvider_Type tests provider type
|
|
func TestAuditLogProvider_Type(t *testing.T) {
|
|
var provider AuditLogProvider = func() ipnauth.AuditLogFunc {
|
|
return func(*ipnauth.AuditLogEntry) error {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if provider == nil {
|
|
t.Error("provider should not be nil")
|
|
}
|
|
|
|
fn := provider()
|
|
if fn == nil {
|
|
t.Error("audit log func should not be nil")
|
|
}
|
|
|
|
err := fn(&ipnauth.AuditLogEntry{})
|
|
if err != nil {
|
|
t.Errorf("audit log func error = %v", err)
|
|
}
|
|
}
|
|
|
|
// TestProfileResolver_Type tests resolver type
|
|
func TestProfileResolver_Type(t *testing.T) {
|
|
var resolver ProfileResolver = func(ps ProfileStore) ipn.LoginProfileView {
|
|
return ps.CurrentProfile()
|
|
}
|
|
|
|
if resolver == nil {
|
|
t.Error("resolver should not be nil")
|
|
}
|
|
}
|
|
|
|
// TestExtensions_EmptyMap tests empty extensions map
|
|
func TestExtensions_EmptyMap(t *testing.T) {
|
|
defer func() {
|
|
extensions = extensions[:0]
|
|
}()
|
|
|
|
// Reset to empty
|
|
extensions = extensions[:0]
|
|
|
|
count := 0
|
|
for range Extensions() {
|
|
count++
|
|
}
|
|
|
|
if count != 0 {
|
|
t.Errorf("empty Extensions() should yield 0 items, got %d", count)
|
|
}
|
|
}
|
|
|
|
// TestDefinition_NilNewFn tests nil newFn handling
|
|
func TestDefinition_NilNewFn(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
// MakeExtension might panic on nil newFn
|
|
t.Logf("panic (expected): %v", r)
|
|
}
|
|
}()
|
|
|
|
d := &Definition{
|
|
name: "test",
|
|
newFn: nil,
|
|
}
|
|
|
|
// This should panic or error
|
|
_, err := d.MakeExtension(logger.Discard, &mockSafeBackend{})
|
|
if err == nil {
|
|
t.Error("MakeExtension() with nil newFn should fail")
|
|
}
|
|
}
|
|
|
|
// TestMultipleExtensions_Registration tests multiple extensions
|
|
func TestMultipleExtensions_Registration(t *testing.T) {
|
|
defer func() {
|
|
extensions = extensions[:0]
|
|
}()
|
|
|
|
names := []string{"ext-a", "ext-b", "ext-c", "ext-d", "ext-e"}
|
|
|
|
for _, name := range names {
|
|
n := name // capture
|
|
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
|
|
return &mockExtension{name: n}, nil
|
|
}
|
|
RegisterExtension(n, newFn)
|
|
}
|
|
|
|
if extensions.Len() != 5 {
|
|
t.Errorf("extensions count = %d, want 5", extensions.Len())
|
|
}
|
|
|
|
for _, name := range names {
|
|
if !extensions.Contains(name) {
|
|
t.Errorf("extension %q not registered", name)
|
|
}
|
|
}
|
|
}
|