diff --git a/client/tailscale/client_test.go b/client/tailscale/client_test.go new file mode 100644 index 000000000..4de1138dc --- /dev/null +++ b/client/tailscale/client_test.go @@ -0,0 +1,507 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestLocalClient_Socket(t *testing.T) { + tests := []struct { + name string + lc LocalClient + want string + isPath bool + }{ + { + name: "custom_socket", + lc: LocalClient{Socket: "/custom/path/tailscaled.sock"}, + want: "/custom/path/tailscaled.sock", + }, + { + name: "default_socket", + lc: LocalClient{}, + isPath: true, // Will use platform default + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.lc.socket() + if !tt.isPath && got != tt.want { + t.Errorf("socket() = %q, want %q", got, tt.want) + } + if tt.isPath && got == "" { + t.Error("socket() returned empty for default") + } + }) + } +} + +func TestLocalClient_Dialer(t *testing.T) { + customDialerCalled := false + customDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + customDialerCalled = true + return nil, errors.New("custom dialer called") + } + + lc := &LocalClient{Dial: customDialer} + dialer := lc.dialer() + + _, err := dialer(context.Background(), "tcp", "test:80") + if err == nil { + t.Error("expected error from custom dialer") + } + if !customDialerCalled { + t.Error("custom dialer was not called") + } +} + +func TestLocalClient_DefaultDialer(t *testing.T) { + lc := &LocalClient{} + + // Test with invalid address + _, err := lc.defaultDialer(context.Background(), "tcp", "invalid:80") + if err == nil { + t.Error("defaultDialer should reject invalid address") + } + if !strings.Contains(err.Error(), "unexpected URL address") { + t.Errorf("wrong error: %v", err) + } +} + +func TestAccessDeniedError(t *testing.T) { + baseErr := errors.New("permission denied") + err := &AccessDeniedError{err: baseErr} + + // Test Error() + if !strings.Contains(err.Error(), "Access denied") { + t.Errorf("Error() = %q, want to contain 'Access denied'", err.Error()) + } + + // Test Unwrap() + if err.Unwrap() != baseErr { + t.Errorf("Unwrap() = %v, want %v", err.Unwrap(), baseErr) + } + + // Test IsAccessDeniedError + if !IsAccessDeniedError(err) { + t.Error("IsAccessDeniedError should return true") + } + + // Test with wrapped error + wrappedErr := errors.New("outer error") + if IsAccessDeniedError(wrappedErr) { + t.Error("IsAccessDeniedError should return false for non-AccessDeniedError") + } +} + +func TestPreconditionsFailedError(t *testing.T) { + baseErr := errors.New("precondition not met") + err := &PreconditionsFailedError{err: baseErr} + + // Test Error() + if !strings.Contains(err.Error(), "Preconditions failed") { + t.Errorf("Error() = %q, want to contain 'Preconditions failed'", err.Error()) + } + + // Test Unwrap() + if err.Unwrap() != baseErr { + t.Errorf("Unwrap() = %v, want %v", err.Unwrap(), baseErr) + } + + // Test IsPreconditionsFailedError + if !IsPreconditionsFailedError(err) { + t.Error("IsPreconditionsFailedError should return true") + } +} + +func TestLocalClient_DoLocalRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that Tailscale-Cap header is set + if r.Header.Get("Tailscale-Cap") == "" { + t.Error("Tailscale-Cap header not set") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + req, err := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil) + if err != nil { + t.Fatalf("NewRequest failed: %v", err) + } + + resp, err := lc.DoLocalRequest(req) + if err != nil { + t.Fatalf("DoLocalRequest failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want %d", resp.StatusCode, http.StatusOK) + } +} + +func TestLocalClient_DoLocalRequest_AccessDenied(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(map[string]string{"error": "access denied"}) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + req, _ := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil) + _, err := lc.doLocalRequestNiceError(req) + + if err == nil { + t.Fatal("expected error for 403 response") + } + if !IsAccessDeniedError(err) { + t.Errorf("expected AccessDeniedError, got: %T", err) + } +} + +func TestLocalClient_DoLocalRequest_PreconditionsFailed(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusPreconditionFailed) + json.NewEncoder(w).Encode(map[string]string{"error": "preconditions failed"}) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + req, _ := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil) + _, err := lc.doLocalRequestNiceError(req) + + if err == nil { + t.Fatal("expected error for 412 response") + } + if !IsPreconditionsFailedError(err) { + t.Errorf("expected PreconditionsFailedError, got: %T", err) + } +} + +func TestLocalClient_Send(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Method = %s, want POST", r.Method) + } + if r.URL.Path != "/test/path" { + t.Errorf("Path = %s, want /test/path", r.URL.Path) + } + + body, _ := io.ReadAll(r.Body) + if string(body) != "test body" { + t.Errorf("Body = %q, want %q", body, "test body") + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte("response")) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + body := strings.NewReader("test body") + resp, err := lc.send(context.Background(), "POST", "/test/path", http.StatusOK, body) + if err != nil { + t.Fatalf("send failed: %v", err) + } + + if string(resp) != "response" { + t.Errorf("response = %q, want %q", resp, "response") + } +} + +func TestLocalClient_Get200(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + resp, err := lc.get200(context.Background(), "/test") + if err != nil { + t.Fatalf("get200 failed: %v", err) + } + + if string(resp) != "success" { + t.Errorf("response = %q, want %q", resp, "success") + } +} + +func TestLocalClient_IncrementCounter(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/localapi/v0/upload-client-metrics") { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + err := lc.IncrementCounter(context.Background(), "test_counter", 5) + if err != nil { + t.Errorf("IncrementCounter failed: %v", err) + } +} + +func TestLocalClient_Goroutines(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("goroutine 1 [running]:\nmain.main()\n")) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + data, err := lc.Goroutines(context.Background()) + if err != nil { + t.Fatalf("Goroutines failed: %v", err) + } + + if !strings.Contains(string(data), "goroutine") { + t.Error("response doesn't contain goroutine info") + } +} + +func TestLocalClient_Metrics(t *testing.T) { + tests := []struct { + name string + method func(*LocalClient, context.Context) ([]byte, error) + path string + }{ + { + name: "DaemonMetrics", + method: (*LocalClient).DaemonMetrics, + path: "/localapi/v0/metrics", + }, + { + name: "UserMetrics", + method: (*LocalClient).UserMetrics, + path: "/localapi/v0/usermetrics", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != tt.path { + t.Errorf("Path = %s, want %s", r.URL.Path, tt.path) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("# HELP metric_name Help text\n")) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + data, err := tt.method(lc, context.Background()) + if err != nil { + t.Fatalf("%s failed: %v", tt.name, err) + } + + if !strings.Contains(string(data), "HELP") { + t.Error("response doesn't contain metrics format") + } + }) + } +} + +func TestLocalClient_ContextCancellation(t *testing.T) { + // Server that delays response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err := lc.get200(ctx, "/test") + if err == nil { + t.Error("expected timeout error") + } + if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context") { + t.Errorf("expected context error, got: %v", err) + } +} + +func TestLocalClient_UseSocketOnly(t *testing.T) { + lc := &LocalClient{ + Socket: "/tmp/test.sock", + UseSocketOnly: true, + } + + // With UseSocketOnly, it should not try TCP port lookup + _, err := lc.defaultDialer(context.Background(), "tcp", "local-tailscaled.sock:80") + // We expect an error since /tmp/test.sock doesn't exist + if err == nil { + t.Error("expected error when socket doesn't exist") + } +} + +func TestLocalClient_OmitAuth(t *testing.T) { + authHeaderSet := false + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + authHeaderSet = true + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + req, _ := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil) + _, err := lc.DoLocalRequest(req) + if err != nil { + t.Fatalf("DoLocalRequest failed: %v", err) + } + + if authHeaderSet { + t.Error("Authorization header should not be set when OmitAuth=true") + } +} + +// Test the error message extraction +func TestErrorMessageFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + want string + }{ + { + name: "json_error", + body: []byte(`{"error":"test error message"}`), + want: "test error message", + }, + { + name: "plain_text", + body: []byte("plain error"), + want: "plain error", + }, + { + name: "empty", + body: []byte{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := errorMessageFromBody(tt.body) + if got != tt.want { + t.Errorf("errorMessageFromBody() = %q, want %q", got, tt.want) + } + }) + } +} + +// Benchmark key operations +func BenchmarkLocalClient_Send(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + lc := &LocalClient{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", server.Listener.Addr().String()) + }, + OmitAuth: true, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := lc.get200(context.Background(), "/test") + if err != nil { + b.Fatalf("get200 failed: %v", err) + } + } +} diff --git a/ipn/store/mem/store_mem_test.go b/ipn/store/mem/store_mem_test.go new file mode 100644 index 000000000..b78311654 --- /dev/null +++ b/ipn/store/mem/store_mem_test.go @@ -0,0 +1,380 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package mem + +import ( + "bytes" + "encoding/json" + "errors" + "sync" + "testing" + + "tailscale.com/ipn" +) + +func TestNew(t *testing.T) { + store, err := New(t.Logf, "test-id") + if err != nil { + t.Fatalf("New() failed: %v", err) + } + if store == nil { + t.Fatal("New() returned nil store") + } + + // Verify it implements ipn.StateStore + var _ ipn.StateStore = store +} + +func TestStore_String(t *testing.T) { + s := &Store{} + if got := s.String(); got != "mem.Store" { + t.Errorf("String() = %q, want %q", got, "mem.Store") + } +} + +func TestStore_ReadWriteState(t *testing.T) { + s := &Store{} + + key := ipn.StateKey("test-key") + data := []byte("test data") + + // Write state + err := s.WriteState(key, data) + if err != nil { + t.Fatalf("WriteState() failed: %v", err) + } + + // Read state + got, err := s.ReadState(key) + if err != nil { + t.Fatalf("ReadState() failed: %v", err) + } + + if !bytes.Equal(got, data) { + t.Errorf("ReadState() = %q, want %q", got, data) + } +} + +func TestStore_ReadState_NotExist(t *testing.T) { + s := &Store{} + + key := ipn.StateKey("nonexistent") + _, err := s.ReadState(key) + + if !errors.Is(err, ipn.ErrStateNotExist) { + t.Errorf("ReadState() error = %v, want ErrStateNotExist", err) + } +} + +func TestStore_WriteState_Clone(t *testing.T) { + s := &Store{} + + key := ipn.StateKey("test-key") + data := []byte("original data") + + err := s.WriteState(key, data) + if err != nil { + t.Fatalf("WriteState() failed: %v", err) + } + + // Modify original data + data[0] = 'X' + + // Read should return unmodified data + got, err := s.ReadState(key) + if err != nil { + t.Fatalf("ReadState() failed: %v", err) + } + + if bytes.Equal(got, data) { + t.Error("ReadState() returned data that was modified after write (not cloned)") + } + + if got[0] != 'o' { + t.Errorf("ReadState() data was modified, got[0] = %c, want 'o'", got[0]) + } +} + +func TestStore_MultipleKeys(t *testing.T) { + s := &Store{} + + keys := []ipn.StateKey{"key1", "key2", "key3"} + values := [][]byte{ + []byte("value1"), + []byte("value2"), + []byte("value3"), + } + + // Write all keys + for i, key := range keys { + if err := s.WriteState(key, values[i]); err != nil { + t.Fatalf("WriteState(%s) failed: %v", key, err) + } + } + + // Read and verify all keys + for i, key := range keys { + got, err := s.ReadState(key) + if err != nil { + t.Fatalf("ReadState(%s) failed: %v", key, err) + } + if !bytes.Equal(got, values[i]) { + t.Errorf("ReadState(%s) = %q, want %q", key, got, values[i]) + } + } +} + +func TestStore_Overwrite(t *testing.T) { + s := &Store{} + + key := ipn.StateKey("test-key") + + // Write initial value + if err := s.WriteState(key, []byte("first")); err != nil { + t.Fatalf("WriteState() failed: %v", err) + } + + // Overwrite with new value + if err := s.WriteState(key, []byte("second")); err != nil { + t.Fatalf("WriteState() failed: %v", err) + } + + // Read should return latest value + got, err := s.ReadState(key) + if err != nil { + t.Fatalf("ReadState() failed: %v", err) + } + + if string(got) != "second" { + t.Errorf("ReadState() = %q, want %q", got, "second") + } +} + +func TestStore_ExportToJSON_Empty(t *testing.T) { + s := &Store{} + + data, err := s.ExportToJSON() + if err != nil { + t.Fatalf("ExportToJSON() failed: %v", err) + } + + // Empty store should export as {} + if string(data) != "{}" { + t.Errorf("ExportToJSON() = %q, want %q", data, "{}") + } +} + +func TestStore_ExportToJSON_WithData(t *testing.T) { + s := &Store{} + + s.WriteState("key1", []byte("value1")) + s.WriteState("key2", []byte("value2")) + + data, err := s.ExportToJSON() + if err != nil { + t.Fatalf("ExportToJSON() failed: %v", err) + } + + // Parse JSON to verify structure + var result map[string][]byte + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("ExportToJSON() produced invalid JSON: %v", err) + } + + if len(result) != 2 { + t.Errorf("ExportToJSON() exported %d keys, want 2", len(result)) + } + + if !bytes.Equal(result["key1"], []byte("value1")) { + t.Errorf("ExportToJSON() key1 = %q, want %q", result["key1"], "value1") + } + if !bytes.Equal(result["key2"], []byte("value2")) { + t.Errorf("ExportToJSON() key2 = %q, want %q", result["key2"], "value2") + } +} + +func TestStore_LoadFromJSON(t *testing.T) { + s := &Store{} + + jsonData := `{ + "key1": "dmFsdWUx", + "key2": "dmFsdWUy" + }` + + err := s.LoadFromJSON([]byte(jsonData)) + if err != nil { + t.Fatalf("LoadFromJSON() failed: %v", err) + } + + // Verify loaded data + got1, err := s.ReadState("key1") + if err != nil { + t.Fatalf("ReadState(key1) failed: %v", err) + } + + got2, err := s.ReadState("key2") + if err != nil { + t.Fatalf("ReadState(key2) failed: %v", err) + } + + if string(got1) != "value1" { + t.Errorf("ReadState(key1) = %q, want %q", got1, "value1") + } + if string(got2) != "value2" { + t.Errorf("ReadState(key2) = %q, want %q", got2, "value2") + } +} + +func TestStore_LoadFromJSON_Invalid(t *testing.T) { + s := &Store{} + + err := s.LoadFromJSON([]byte("invalid json")) + if err == nil { + t.Error("LoadFromJSON() with invalid JSON succeeded, want error") + } +} + +func TestStore_ExportImportRoundTrip(t *testing.T) { + s1 := &Store{} + + // Write some data + s1.WriteState("key1", []byte("value1")) + s1.WriteState("key2", []byte("value2")) + s1.WriteState("key3", []byte("value3")) + + // Export + exported, err := s1.ExportToJSON() + if err != nil { + t.Fatalf("ExportToJSON() failed: %v", err) + } + + // Import into new store + s2 := &Store{} + if err := s2.LoadFromJSON(exported); err != nil { + t.Fatalf("LoadFromJSON() failed: %v", err) + } + + // Verify all data matches + keys := []ipn.StateKey{"key1", "key2", "key3"} + for _, key := range keys { + val1, err1 := s1.ReadState(key) + val2, err2 := s2.ReadState(key) + + if err1 != nil || err2 != nil { + t.Fatalf("ReadState(%s) failed: err1=%v, err2=%v", key, err1, err2) + } + + if !bytes.Equal(val1, val2) { + t.Errorf("Round-trip failed for %s: %q != %q", key, val1, val2) + } + } +} + +func TestStore_ConcurrentAccess(t *testing.T) { + s := &Store{} + + var wg sync.WaitGroup + numGoroutines := 100 + + // Concurrent writes + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + key := ipn.StateKey(string(rune('a' + n%26))) + s.WriteState(key, []byte{byte(n)}) + }(i) + } + + wg.Wait() + + // Concurrent reads + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + key := ipn.StateKey(string(rune('a' + n%26))) + _, _ = s.ReadState(key) + }(i) + } + + wg.Wait() +} + +func TestStore_EmptyKey(t *testing.T) { + s := &Store{} + + key := ipn.StateKey("") + data := []byte("empty key data") + + // Should be able to use empty key + if err := s.WriteState(key, data); err != nil { + t.Fatalf("WriteState() with empty key failed: %v", err) + } + + got, err := s.ReadState(key) + if err != nil { + t.Fatalf("ReadState() with empty key failed: %v", err) + } + + if !bytes.Equal(got, data) { + t.Errorf("ReadState() = %q, want %q", got, data) + } +} + +func TestStore_NilData(t *testing.T) { + s := &Store{} + + key := ipn.StateKey("test") + + // Write nil data + if err := s.WriteState(key, nil); err != nil { + t.Fatalf("WriteState() with nil data failed: %v", err) + } + + got, err := s.ReadState(key) + if err != nil { + t.Fatalf("ReadState() failed: %v", err) + } + + if got != nil && len(got) != 0 { + t.Errorf("ReadState() = %v, want nil or empty", got) + } +} + +// Benchmark operations +func BenchmarkStore_WriteState(b *testing.B) { + s := &Store{} + key := ipn.StateKey("bench-key") + data := []byte("benchmark data") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.WriteState(key, data) + } +} + +func BenchmarkStore_ReadState(b *testing.B) { + s := &Store{} + key := ipn.StateKey("bench-key") + s.WriteState(key, []byte("benchmark data")) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.ReadState(key) + } +} + +func BenchmarkStore_ExportToJSON(b *testing.B) { + s := &Store{} + for i := 0; i < 100; i++ { + key := ipn.StateKey(string(rune('a' + i%26))) + s.WriteState(key, []byte("data")) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.ExportToJSON() + } +}