You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
tailscale/ipn/ipnext/ipnext_test.go

581 lines
13 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnext
import (
"errors"
"fmt"
"testing"
"tailscale.com/ipn"
"tailscale.com/tsd"
"tailscale.com/tstime"
"tailscale.com/types/logger"
)
// mockExtension implements Extension for testing
type mockExtension struct {
name string
initErr error
shutdownErr error
initCalled bool
shutdownCalled bool
}
func (m *mockExtension) Name() string { return m.name }
func (m *mockExtension) Init(Host) error {
m.initCalled = true
return m.initErr
}
func (m *mockExtension) Shutdown() error {
m.shutdownCalled = true
return m.shutdownErr
}
// mockSafeBackend implements SafeBackend for testing
type mockSafeBackend struct{}
func (m *mockSafeBackend) Sys() *tsd.System { return nil }
func (m *mockSafeBackend) Clock() tstime.Clock { return nil }
func (m *mockSafeBackend) TailscaleVarRoot() string { return "/tmp" }
// TestDefinition_Name tests Definition.Name()
func TestDefinition_Name(t *testing.T) {
d := &Definition{name: "test-extension"}
if got := d.Name(); got != "test-extension" {
t.Errorf("Name() = %q, want %q", got, "test-extension")
}
}
// TestDefinition_MakeExtension tests successful extension creation
func TestDefinition_MakeExtension(t *testing.T) {
ext := &mockExtension{name: "test"}
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
return ext, nil
}
d := &Definition{
name: "test",
newFn: newFn,
}
logf := logger.Discard
sb := &mockSafeBackend{}
got, err := d.MakeExtension(logf, sb)
if err != nil {
t.Fatalf("MakeExtension() error = %v", err)
}
if got != ext {
t.Error("MakeExtension() returned wrong extension")
}
}
// TestDefinition_MakeExtension_NameMismatch tests name validation
func TestDefinition_MakeExtension_NameMismatch(t *testing.T) {
ext := &mockExtension{name: "wrong-name"}
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
return ext, nil
}
d := &Definition{
name: "expected-name",
newFn: newFn,
}
logf := logger.Discard
sb := &mockSafeBackend{}
_, err := d.MakeExtension(logf, sb)
if err == nil {
t.Fatal("MakeExtension() should error on name mismatch")
}
wantErr := `extension name mismatch: registered "expected-name"; actual "wrong-name"`
if err.Error() != wantErr {
t.Errorf("error = %q, want %q", err.Error(), wantErr)
}
}
// TestDefinition_MakeExtension_NewFnError tests error propagation
func TestDefinition_MakeExtension_NewFnError(t *testing.T) {
expectedErr := errors.New("creation failed")
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
return nil, expectedErr
}
d := &Definition{
name: "test",
newFn: newFn,
}
logf := logger.Discard
sb := &mockSafeBackend{}
_, err := d.MakeExtension(logf, sb)
if !errors.Is(err, expectedErr) {
t.Errorf("MakeExtension() error = %v, want %v", err, expectedErr)
}
}
// TestRegisterExtension_Panic_NilFunc tests nil function panic
func TestRegisterExtension_Panic_NilFunc(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("RegisterExtension() should panic with nil function")
} else {
got := fmt.Sprint(r)
want := `ipnext: newExt is nil: "test"`
if got != want {
t.Errorf("panic message = %q, want %q", got, want)
}
}
// Reset extensions map after test
extensions = extensions[:0]
}()
RegisterExtension("test", nil)
}
// TestRegisterExtension_Panic_Duplicate tests duplicate name panic
func TestRegisterExtension_Panic_Duplicate(t *testing.T) {
defer func() {
// Reset extensions map after test
extensions = extensions[:0]
}()
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
return &mockExtension{name: "test"}, nil
}
// First registration should succeed
RegisterExtension("test", newFn)
// Second registration should panic
defer func() {
if r := recover(); r == nil {
t.Error("RegisterExtension() should panic on duplicate")
} else {
got := fmt.Sprint(r)
want := `ipnext: duplicate extension name "test"`
if got != want {
t.Errorf("panic message = %q, want %q", got, want)
}
}
}()
RegisterExtension("test", newFn)
}
// TestRegisterExtension_Success tests successful registration
func TestRegisterExtension_Success(t *testing.T) {
defer func() {
extensions = extensions[:0]
}()
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
return &mockExtension{name: "test"}, nil
}
RegisterExtension("test", newFn)
if !extensions.Contains("test") {
t.Error("extension not registered")
}
def, ok := extensions.Get("test")
if !ok {
t.Fatal("failed to get registered extension")
}
if def.name != "test" {
t.Errorf("registered name = %q, want %q", def.name, "test")
}
}
// TestExtensions_Iterator tests Extensions() iteration
func TestExtensions_Iterator(t *testing.T) {
defer func() {
extensions = extensions[:0]
}()
newFn := func(name string) NewExtensionFn {
return func(logger.Logf, SafeBackend) (Extension, error) {
return &mockExtension{name: name}, nil
}
}
RegisterExtension("ext1", newFn("ext1"))
RegisterExtension("ext2", newFn("ext2"))
RegisterExtension("ext3", newFn("ext3"))
count := 0
seen := make(map[string]bool)
for def := range Extensions() {
count++
seen[def.name] = true
}
if count != 3 {
t.Errorf("Extensions() count = %d, want 3", count)
}
for _, name := range []string{"ext1", "ext2", "ext3"} {
if !seen[name] {
t.Errorf("extension %q not seen in iteration", name)
}
}
}
// TestExtensions_Order tests iteration order preservation
func TestExtensions_Order(t *testing.T) {
defer func() {
extensions = extensions[:0]
}()
newFn := func(name string) NewExtensionFn {
return func(logger.Logf, SafeBackend) (Extension, error) {
return &mockExtension{name: name}, nil
}
}
RegisterExtension("first", newFn("first"))
RegisterExtension("second", newFn("second"))
RegisterExtension("third", newFn("third"))
var order []string
for def := range Extensions() {
order = append(order, def.name)
}
want := []string{"first", "second", "third"}
if len(order) != len(want) {
t.Fatalf("order length = %d, want %d", len(order), len(want))
}
for i, name := range want {
if order[i] != name {
t.Errorf("order[%d] = %q, want %q", i, order[i], name)
}
}
}
// TestDefinitionForTest tests test helper
func TestDefinitionForTest(t *testing.T) {
ext := &mockExtension{name: "test-ext"}
def := DefinitionForTest(ext)
if def.name != "test-ext" {
t.Errorf("name = %q, want %q", def.name, "test-ext")
}
logf := logger.Discard
sb := &mockSafeBackend{}
got, err := def.MakeExtension(logf, sb)
if err != nil {
t.Fatalf("MakeExtension() error = %v", err)
}
if got != ext {
t.Error("MakeExtension() returned wrong extension")
}
}
// TestDefinitionWithErrForTest tests error test helper
func TestDefinitionWithErrForTest(t *testing.T) {
expectedErr := errors.New("test error")
def := DefinitionWithErrForTest("error-ext", expectedErr)
if def.name != "error-ext" {
t.Errorf("name = %q, want %q", def.name, "error-ext")
}
logf := logger.Discard
sb := &mockSafeBackend{}
_, err := def.MakeExtension(logf, sb)
if !errors.Is(err, expectedErr) {
t.Errorf("MakeExtension() error = %v, want %v", err, expectedErr)
}
}
// TestSkipExtension_Error tests SkipExtension error
func TestSkipExtension_Error(t *testing.T) {
if SkipExtension == nil {
t.Fatal("SkipExtension should not be nil")
}
want := "skipping extension"
if SkipExtension.Error() != want {
t.Errorf("SkipExtension.Error() = %q, want %q", SkipExtension.Error(), want)
}
}
// TestSkipExtension_Wrapped tests wrapped SkipExtension
func TestSkipExtension_Wrapped(t *testing.T) {
wrapped := fmt.Errorf("platform not supported: %w", SkipExtension)
if !errors.Is(wrapped, SkipExtension) {
t.Error("wrapped error should be SkipExtension")
}
}
// TestMockExtension_Interface tests mock implements Extension
func TestMockExtension_Interface(t *testing.T) {
var _ Extension = (*mockExtension)(nil)
}
// TestMockExtension_Init tests Init tracking
func TestMockExtension_Init(t *testing.T) {
ext := &mockExtension{name: "test"}
if ext.initCalled {
t.Error("initCalled should be false initially")
}
err := ext.Init(nil)
if err != nil {
t.Errorf("Init() error = %v", err)
}
if !ext.initCalled {
t.Error("initCalled should be true after Init()")
}
}
// TestMockExtension_InitError tests Init error
func TestMockExtension_InitError(t *testing.T) {
expectedErr := errors.New("init failed")
ext := &mockExtension{
name: "test",
initErr: expectedErr,
}
err := ext.Init(nil)
if !errors.Is(err, expectedErr) {
t.Errorf("Init() error = %v, want %v", err, expectedErr)
}
if !ext.initCalled {
t.Error("initCalled should be true even on error")
}
}
// TestMockExtension_Shutdown tests Shutdown tracking
func TestMockExtension_Shutdown(t *testing.T) {
ext := &mockExtension{name: "test"}
if ext.shutdownCalled {
t.Error("shutdownCalled should be false initially")
}
err := ext.Shutdown()
if err != nil {
t.Errorf("Shutdown() error = %v", err)
}
if !ext.shutdownCalled {
t.Error("shutdownCalled should be true after Shutdown()")
}
}
// TestMockExtension_ShutdownError tests Shutdown error
func TestMockExtension_ShutdownError(t *testing.T) {
expectedErr := errors.New("shutdown failed")
ext := &mockExtension{
name: "test",
shutdownErr: expectedErr,
}
err := ext.Shutdown()
if !errors.Is(err, expectedErr) {
t.Errorf("Shutdown() error = %v, want %v", err, expectedErr)
}
if !ext.shutdownCalled {
t.Error("shutdownCalled should be true even on error")
}
}
// TestMockSafeBackend_Interface tests mock implements SafeBackend
func TestMockSafeBackend_Interface(t *testing.T) {
var _ SafeBackend = (*mockSafeBackend)(nil)
}
// TestMockSafeBackend_Methods tests SafeBackend methods
func TestMockSafeBackend_Methods(t *testing.T) {
sb := &mockSafeBackend{}
if sb.Sys() != nil {
t.Error("Sys() should return nil")
}
if sb.Clock() != nil {
t.Error("Clock() should return nil")
}
if sb.TailscaleVarRoot() != "/tmp" {
t.Errorf("TailscaleVarRoot() = %q, want /tmp", sb.TailscaleVarRoot())
}
}
// TestHooks_ZeroValue tests Hooks zero value
func TestHooks_ZeroValue(t *testing.T) {
var h Hooks
// Verify all hooks are zero-valued and usable
_ = h.BackendStateChange
_ = h.ProfileStateChange
_ = h.BackgroundProfileResolvers
_ = h.AuditLoggers
_ = h.NewControlClient
_ = h.OnSelfChange
_ = h.MutateNotifyLocked
_ = h.SetPeerStatus
_ = h.ShouldUploadServices
}
// TestProfileStateChangeCallback_Type tests callback signature
func TestProfileStateChangeCallback_Type(t *testing.T) {
var callback ProfileStateChangeCallback = func(p ipn.LoginProfileView, pr ipn.PrefsView, sameNode bool) {
// Callback implementation
_ = p
_ = pr
_ = sameNode
}
if callback == nil {
t.Error("callback should not be nil")
}
// Test calling the callback
callback(ipn.LoginProfileView{}, ipn.PrefsView{}, true)
}
// TestNewExtensionFn_Type tests function type
func TestNewExtensionFn_Type(t *testing.T) {
var fn NewExtensionFn = func(logger.Logf, SafeBackend) (Extension, error) {
return &mockExtension{name: "test"}, nil
}
if fn == nil {
t.Error("fn should not be nil")
}
ext, err := fn(logger.Discard, &mockSafeBackend{})
if err != nil {
t.Fatalf("fn() error = %v", err)
}
if ext.Name() != "test" {
t.Errorf("extension name = %q, want %q", ext.Name(), "test")
}
}
// TestAuditLogProvider_Type tests provider type
func TestAuditLogProvider_Type(t *testing.T) {
var provider AuditLogProvider = func() ipnauth.AuditLogFunc {
return func(*ipnauth.AuditLogEntry) error {
return nil
}
}
if provider == nil {
t.Error("provider should not be nil")
}
fn := provider()
if fn == nil {
t.Error("audit log func should not be nil")
}
err := fn(&ipnauth.AuditLogEntry{})
if err != nil {
t.Errorf("audit log func error = %v", err)
}
}
// TestProfileResolver_Type tests resolver type
func TestProfileResolver_Type(t *testing.T) {
var resolver ProfileResolver = func(ps ProfileStore) ipn.LoginProfileView {
return ps.CurrentProfile()
}
if resolver == nil {
t.Error("resolver should not be nil")
}
}
// TestExtensions_EmptyMap tests empty extensions map
func TestExtensions_EmptyMap(t *testing.T) {
defer func() {
extensions = extensions[:0]
}()
// Reset to empty
extensions = extensions[:0]
count := 0
for range Extensions() {
count++
}
if count != 0 {
t.Errorf("empty Extensions() should yield 0 items, got %d", count)
}
}
// TestDefinition_NilNewFn tests nil newFn handling
func TestDefinition_NilNewFn(t *testing.T) {
defer func() {
if r := recover(); r != nil {
// MakeExtension might panic on nil newFn
t.Logf("panic (expected): %v", r)
}
}()
d := &Definition{
name: "test",
newFn: nil,
}
// This should panic or error
_, err := d.MakeExtension(logger.Discard, &mockSafeBackend{})
if err == nil {
t.Error("MakeExtension() with nil newFn should fail")
}
}
// TestMultipleExtensions_Registration tests multiple extensions
func TestMultipleExtensions_Registration(t *testing.T) {
defer func() {
extensions = extensions[:0]
}()
names := []string{"ext-a", "ext-b", "ext-c", "ext-d", "ext-e"}
for _, name := range names {
n := name // capture
newFn := func(logger.Logf, SafeBackend) (Extension, error) {
return &mockExtension{name: n}, nil
}
RegisterExtension(n, newFn)
}
if extensions.Len() != 5 {
t.Errorf("extensions count = %d, want 5", extensions.Len())
}
for _, name := range names {
if !extensions.Contains(name) {
t.Errorf("extension %q not registered", name)
}
}
}