diff --git a/ipn/ipnext/ipnext_test.go b/ipn/ipnext/ipnext_test.go new file mode 100644 index 000000000..f0460db50 --- /dev/null +++ b/ipn/ipnext/ipnext_test.go @@ -0,0 +1,580 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnext + +import ( + "errors" + "fmt" + "testing" + + "tailscale.com/ipn" + "tailscale.com/tsd" + "tailscale.com/tstime" + "tailscale.com/types/logger" +) + +// mockExtension implements Extension for testing +type mockExtension struct { + name string + initErr error + shutdownErr error + initCalled bool + shutdownCalled bool +} + +func (m *mockExtension) Name() string { return m.name } + +func (m *mockExtension) Init(Host) error { + m.initCalled = true + return m.initErr +} + +func (m *mockExtension) Shutdown() error { + m.shutdownCalled = true + return m.shutdownErr +} + +// mockSafeBackend implements SafeBackend for testing +type mockSafeBackend struct{} + +func (m *mockSafeBackend) Sys() *tsd.System { return nil } +func (m *mockSafeBackend) Clock() tstime.Clock { return nil } +func (m *mockSafeBackend) TailscaleVarRoot() string { return "/tmp" } + +// TestDefinition_Name tests Definition.Name() +func TestDefinition_Name(t *testing.T) { + d := &Definition{name: "test-extension"} + if got := d.Name(); got != "test-extension" { + t.Errorf("Name() = %q, want %q", got, "test-extension") + } +} + +// TestDefinition_MakeExtension tests successful extension creation +func TestDefinition_MakeExtension(t *testing.T) { + ext := &mockExtension{name: "test"} + newFn := func(logger.Logf, SafeBackend) (Extension, error) { + return ext, nil + } + + d := &Definition{ + name: "test", + newFn: newFn, + } + + logf := logger.Discard + sb := &mockSafeBackend{} + + got, err := d.MakeExtension(logf, sb) + if err != nil { + t.Fatalf("MakeExtension() error = %v", err) + } + + if got != ext { + t.Error("MakeExtension() returned wrong extension") + } +} + +// TestDefinition_MakeExtension_NameMismatch tests name validation +func TestDefinition_MakeExtension_NameMismatch(t *testing.T) { + ext := &mockExtension{name: "wrong-name"} + newFn := func(logger.Logf, SafeBackend) (Extension, error) { + return ext, nil + } + + d := &Definition{ + name: "expected-name", + newFn: newFn, + } + + logf := logger.Discard + sb := &mockSafeBackend{} + + _, err := d.MakeExtension(logf, sb) + if err == nil { + t.Fatal("MakeExtension() should error on name mismatch") + } + + wantErr := `extension name mismatch: registered "expected-name"; actual "wrong-name"` + if err.Error() != wantErr { + t.Errorf("error = %q, want %q", err.Error(), wantErr) + } +} + +// TestDefinition_MakeExtension_NewFnError tests error propagation +func TestDefinition_MakeExtension_NewFnError(t *testing.T) { + expectedErr := errors.New("creation failed") + newFn := func(logger.Logf, SafeBackend) (Extension, error) { + return nil, expectedErr + } + + d := &Definition{ + name: "test", + newFn: newFn, + } + + logf := logger.Discard + sb := &mockSafeBackend{} + + _, err := d.MakeExtension(logf, sb) + if !errors.Is(err, expectedErr) { + t.Errorf("MakeExtension() error = %v, want %v", err, expectedErr) + } +} + +// TestRegisterExtension_Panic_NilFunc tests nil function panic +func TestRegisterExtension_Panic_NilFunc(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("RegisterExtension() should panic with nil function") + } else { + got := fmt.Sprint(r) + want := `ipnext: newExt is nil: "test"` + if got != want { + t.Errorf("panic message = %q, want %q", got, want) + } + } + // Reset extensions map after test + extensions = extensions[:0] + }() + + RegisterExtension("test", nil) +} + +// TestRegisterExtension_Panic_Duplicate tests duplicate name panic +func TestRegisterExtension_Panic_Duplicate(t *testing.T) { + defer func() { + // Reset extensions map after test + extensions = extensions[:0] + }() + + newFn := func(logger.Logf, SafeBackend) (Extension, error) { + return &mockExtension{name: "test"}, nil + } + + // First registration should succeed + RegisterExtension("test", newFn) + + // Second registration should panic + defer func() { + if r := recover(); r == nil { + t.Error("RegisterExtension() should panic on duplicate") + } else { + got := fmt.Sprint(r) + want := `ipnext: duplicate extension name "test"` + if got != want { + t.Errorf("panic message = %q, want %q", got, want) + } + } + }() + + RegisterExtension("test", newFn) +} + +// TestRegisterExtension_Success tests successful registration +func TestRegisterExtension_Success(t *testing.T) { + defer func() { + extensions = extensions[:0] + }() + + newFn := func(logger.Logf, SafeBackend) (Extension, error) { + return &mockExtension{name: "test"}, nil + } + + RegisterExtension("test", newFn) + + if !extensions.Contains("test") { + t.Error("extension not registered") + } + + def, ok := extensions.Get("test") + if !ok { + t.Fatal("failed to get registered extension") + } + + if def.name != "test" { + t.Errorf("registered name = %q, want %q", def.name, "test") + } +} + +// TestExtensions_Iterator tests Extensions() iteration +func TestExtensions_Iterator(t *testing.T) { + defer func() { + extensions = extensions[:0] + }() + + newFn := func(name string) NewExtensionFn { + return func(logger.Logf, SafeBackend) (Extension, error) { + return &mockExtension{name: name}, nil + } + } + + RegisterExtension("ext1", newFn("ext1")) + RegisterExtension("ext2", newFn("ext2")) + RegisterExtension("ext3", newFn("ext3")) + + count := 0 + seen := make(map[string]bool) + + for def := range Extensions() { + count++ + seen[def.name] = true + } + + if count != 3 { + t.Errorf("Extensions() count = %d, want 3", count) + } + + for _, name := range []string{"ext1", "ext2", "ext3"} { + if !seen[name] { + t.Errorf("extension %q not seen in iteration", name) + } + } +} + +// TestExtensions_Order tests iteration order preservation +func TestExtensions_Order(t *testing.T) { + defer func() { + extensions = extensions[:0] + }() + + newFn := func(name string) NewExtensionFn { + return func(logger.Logf, SafeBackend) (Extension, error) { + return &mockExtension{name: name}, nil + } + } + + RegisterExtension("first", newFn("first")) + RegisterExtension("second", newFn("second")) + RegisterExtension("third", newFn("third")) + + var order []string + for def := range Extensions() { + order = append(order, def.name) + } + + want := []string{"first", "second", "third"} + if len(order) != len(want) { + t.Fatalf("order length = %d, want %d", len(order), len(want)) + } + + for i, name := range want { + if order[i] != name { + t.Errorf("order[%d] = %q, want %q", i, order[i], name) + } + } +} + +// TestDefinitionForTest tests test helper +func TestDefinitionForTest(t *testing.T) { + ext := &mockExtension{name: "test-ext"} + def := DefinitionForTest(ext) + + if def.name != "test-ext" { + t.Errorf("name = %q, want %q", def.name, "test-ext") + } + + logf := logger.Discard + sb := &mockSafeBackend{} + + got, err := def.MakeExtension(logf, sb) + if err != nil { + t.Fatalf("MakeExtension() error = %v", err) + } + + if got != ext { + t.Error("MakeExtension() returned wrong extension") + } +} + +// TestDefinitionWithErrForTest tests error test helper +func TestDefinitionWithErrForTest(t *testing.T) { + expectedErr := errors.New("test error") + def := DefinitionWithErrForTest("error-ext", expectedErr) + + if def.name != "error-ext" { + t.Errorf("name = %q, want %q", def.name, "error-ext") + } + + logf := logger.Discard + sb := &mockSafeBackend{} + + _, err := def.MakeExtension(logf, sb) + if !errors.Is(err, expectedErr) { + t.Errorf("MakeExtension() error = %v, want %v", err, expectedErr) + } +} + +// TestSkipExtension_Error tests SkipExtension error +func TestSkipExtension_Error(t *testing.T) { + if SkipExtension == nil { + t.Fatal("SkipExtension should not be nil") + } + + want := "skipping extension" + if SkipExtension.Error() != want { + t.Errorf("SkipExtension.Error() = %q, want %q", SkipExtension.Error(), want) + } +} + +// TestSkipExtension_Wrapped tests wrapped SkipExtension +func TestSkipExtension_Wrapped(t *testing.T) { + wrapped := fmt.Errorf("platform not supported: %w", SkipExtension) + + if !errors.Is(wrapped, SkipExtension) { + t.Error("wrapped error should be SkipExtension") + } +} + +// TestMockExtension_Interface tests mock implements Extension +func TestMockExtension_Interface(t *testing.T) { + var _ Extension = (*mockExtension)(nil) +} + +// TestMockExtension_Init tests Init tracking +func TestMockExtension_Init(t *testing.T) { + ext := &mockExtension{name: "test"} + + if ext.initCalled { + t.Error("initCalled should be false initially") + } + + err := ext.Init(nil) + if err != nil { + t.Errorf("Init() error = %v", err) + } + + if !ext.initCalled { + t.Error("initCalled should be true after Init()") + } +} + +// TestMockExtension_InitError tests Init error +func TestMockExtension_InitError(t *testing.T) { + expectedErr := errors.New("init failed") + ext := &mockExtension{ + name: "test", + initErr: expectedErr, + } + + err := ext.Init(nil) + if !errors.Is(err, expectedErr) { + t.Errorf("Init() error = %v, want %v", err, expectedErr) + } + + if !ext.initCalled { + t.Error("initCalled should be true even on error") + } +} + +// TestMockExtension_Shutdown tests Shutdown tracking +func TestMockExtension_Shutdown(t *testing.T) { + ext := &mockExtension{name: "test"} + + if ext.shutdownCalled { + t.Error("shutdownCalled should be false initially") + } + + err := ext.Shutdown() + if err != nil { + t.Errorf("Shutdown() error = %v", err) + } + + if !ext.shutdownCalled { + t.Error("shutdownCalled should be true after Shutdown()") + } +} + +// TestMockExtension_ShutdownError tests Shutdown error +func TestMockExtension_ShutdownError(t *testing.T) { + expectedErr := errors.New("shutdown failed") + ext := &mockExtension{ + name: "test", + shutdownErr: expectedErr, + } + + err := ext.Shutdown() + if !errors.Is(err, expectedErr) { + t.Errorf("Shutdown() error = %v, want %v", err, expectedErr) + } + + if !ext.shutdownCalled { + t.Error("shutdownCalled should be true even on error") + } +} + +// TestMockSafeBackend_Interface tests mock implements SafeBackend +func TestMockSafeBackend_Interface(t *testing.T) { + var _ SafeBackend = (*mockSafeBackend)(nil) +} + +// TestMockSafeBackend_Methods tests SafeBackend methods +func TestMockSafeBackend_Methods(t *testing.T) { + sb := &mockSafeBackend{} + + if sb.Sys() != nil { + t.Error("Sys() should return nil") + } + + if sb.Clock() != nil { + t.Error("Clock() should return nil") + } + + if sb.TailscaleVarRoot() != "/tmp" { + t.Errorf("TailscaleVarRoot() = %q, want /tmp", sb.TailscaleVarRoot()) + } +} + +// TestHooks_ZeroValue tests Hooks zero value +func TestHooks_ZeroValue(t *testing.T) { + var h Hooks + + // Verify all hooks are zero-valued and usable + _ = h.BackendStateChange + _ = h.ProfileStateChange + _ = h.BackgroundProfileResolvers + _ = h.AuditLoggers + _ = h.NewControlClient + _ = h.OnSelfChange + _ = h.MutateNotifyLocked + _ = h.SetPeerStatus + _ = h.ShouldUploadServices +} + +// TestProfileStateChangeCallback_Type tests callback signature +func TestProfileStateChangeCallback_Type(t *testing.T) { + var callback ProfileStateChangeCallback = func(p ipn.LoginProfileView, pr ipn.PrefsView, sameNode bool) { + // Callback implementation + _ = p + _ = pr + _ = sameNode + } + + if callback == nil { + t.Error("callback should not be nil") + } + + // Test calling the callback + callback(ipn.LoginProfileView{}, ipn.PrefsView{}, true) +} + +// TestNewExtensionFn_Type tests function type +func TestNewExtensionFn_Type(t *testing.T) { + var fn NewExtensionFn = func(logger.Logf, SafeBackend) (Extension, error) { + return &mockExtension{name: "test"}, nil + } + + if fn == nil { + t.Error("fn should not be nil") + } + + ext, err := fn(logger.Discard, &mockSafeBackend{}) + if err != nil { + t.Fatalf("fn() error = %v", err) + } + + if ext.Name() != "test" { + t.Errorf("extension name = %q, want %q", ext.Name(), "test") + } +} + +// TestAuditLogProvider_Type tests provider type +func TestAuditLogProvider_Type(t *testing.T) { + var provider AuditLogProvider = func() ipnauth.AuditLogFunc { + return func(*ipnauth.AuditLogEntry) error { + return nil + } + } + + if provider == nil { + t.Error("provider should not be nil") + } + + fn := provider() + if fn == nil { + t.Error("audit log func should not be nil") + } + + err := fn(&ipnauth.AuditLogEntry{}) + if err != nil { + t.Errorf("audit log func error = %v", err) + } +} + +// TestProfileResolver_Type tests resolver type +func TestProfileResolver_Type(t *testing.T) { + var resolver ProfileResolver = func(ps ProfileStore) ipn.LoginProfileView { + return ps.CurrentProfile() + } + + if resolver == nil { + t.Error("resolver should not be nil") + } +} + +// TestExtensions_EmptyMap tests empty extensions map +func TestExtensions_EmptyMap(t *testing.T) { + defer func() { + extensions = extensions[:0] + }() + + // Reset to empty + extensions = extensions[:0] + + count := 0 + for range Extensions() { + count++ + } + + if count != 0 { + t.Errorf("empty Extensions() should yield 0 items, got %d", count) + } +} + +// TestDefinition_NilNewFn tests nil newFn handling +func TestDefinition_NilNewFn(t *testing.T) { + defer func() { + if r := recover(); r != nil { + // MakeExtension might panic on nil newFn + t.Logf("panic (expected): %v", r) + } + }() + + d := &Definition{ + name: "test", + newFn: nil, + } + + // This should panic or error + _, err := d.MakeExtension(logger.Discard, &mockSafeBackend{}) + if err == nil { + t.Error("MakeExtension() with nil newFn should fail") + } +} + +// TestMultipleExtensions_Registration tests multiple extensions +func TestMultipleExtensions_Registration(t *testing.T) { + defer func() { + extensions = extensions[:0] + }() + + names := []string{"ext-a", "ext-b", "ext-c", "ext-d", "ext-e"} + + for _, name := range names { + n := name // capture + newFn := func(logger.Logf, SafeBackend) (Extension, error) { + return &mockExtension{name: n}, nil + } + RegisterExtension(n, newFn) + } + + if extensions.Len() != 5 { + t.Errorf("extensions count = %d, want 5", extensions.Len()) + } + + for _, name := range names { + if !extensions.Contains(name) { + t.Errorf("extension %q not registered", name) + } + } +}