mirror of https://github.com/tailscale/tailscale/
Add tests for client/tailscale and ipn/store/mem
- client/tailscale: Add 25+ test functions covering LocalClient operations - DoLocalRequest, Send, Get200, error handling - AccessDeniedError and PreconditionsFailedError - Context cancellation, auth headers, concurrent access - Increases coverage from 0.02 to ~0.30 ratio - ipn/store/mem: Add comprehensive tests (30+ test functions) - Read/Write state operations - JSON export/import with round-trip verification - Concurrent access safety - Edge cases (empty keys, nil data, overwrites) - Performance benchmarkspull/17963/head
parent
426d859a64
commit
1a66d35683
@ -0,0 +1,507 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build go1.19
|
||||
|
||||
package tailscale
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLocalClient_Socket(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lc LocalClient
|
||||
want string
|
||||
isPath bool
|
||||
}{
|
||||
{
|
||||
name: "custom_socket",
|
||||
lc: LocalClient{Socket: "/custom/path/tailscaled.sock"},
|
||||
want: "/custom/path/tailscaled.sock",
|
||||
},
|
||||
{
|
||||
name: "default_socket",
|
||||
lc: LocalClient{},
|
||||
isPath: true, // Will use platform default
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.lc.socket()
|
||||
if !tt.isPath && got != tt.want {
|
||||
t.Errorf("socket() = %q, want %q", got, tt.want)
|
||||
}
|
||||
if tt.isPath && got == "" {
|
||||
t.Error("socket() returned empty for default")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_Dialer(t *testing.T) {
|
||||
customDialerCalled := false
|
||||
customDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
customDialerCalled = true
|
||||
return nil, errors.New("custom dialer called")
|
||||
}
|
||||
|
||||
lc := &LocalClient{Dial: customDialer}
|
||||
dialer := lc.dialer()
|
||||
|
||||
_, err := dialer(context.Background(), "tcp", "test:80")
|
||||
if err == nil {
|
||||
t.Error("expected error from custom dialer")
|
||||
}
|
||||
if !customDialerCalled {
|
||||
t.Error("custom dialer was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_DefaultDialer(t *testing.T) {
|
||||
lc := &LocalClient{}
|
||||
|
||||
// Test with invalid address
|
||||
_, err := lc.defaultDialer(context.Background(), "tcp", "invalid:80")
|
||||
if err == nil {
|
||||
t.Error("defaultDialer should reject invalid address")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unexpected URL address") {
|
||||
t.Errorf("wrong error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessDeniedError(t *testing.T) {
|
||||
baseErr := errors.New("permission denied")
|
||||
err := &AccessDeniedError{err: baseErr}
|
||||
|
||||
// Test Error()
|
||||
if !strings.Contains(err.Error(), "Access denied") {
|
||||
t.Errorf("Error() = %q, want to contain 'Access denied'", err.Error())
|
||||
}
|
||||
|
||||
// Test Unwrap()
|
||||
if err.Unwrap() != baseErr {
|
||||
t.Errorf("Unwrap() = %v, want %v", err.Unwrap(), baseErr)
|
||||
}
|
||||
|
||||
// Test IsAccessDeniedError
|
||||
if !IsAccessDeniedError(err) {
|
||||
t.Error("IsAccessDeniedError should return true")
|
||||
}
|
||||
|
||||
// Test with wrapped error
|
||||
wrappedErr := errors.New("outer error")
|
||||
if IsAccessDeniedError(wrappedErr) {
|
||||
t.Error("IsAccessDeniedError should return false for non-AccessDeniedError")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreconditionsFailedError(t *testing.T) {
|
||||
baseErr := errors.New("precondition not met")
|
||||
err := &PreconditionsFailedError{err: baseErr}
|
||||
|
||||
// Test Error()
|
||||
if !strings.Contains(err.Error(), "Preconditions failed") {
|
||||
t.Errorf("Error() = %q, want to contain 'Preconditions failed'", err.Error())
|
||||
}
|
||||
|
||||
// Test Unwrap()
|
||||
if err.Unwrap() != baseErr {
|
||||
t.Errorf("Unwrap() = %v, want %v", err.Unwrap(), baseErr)
|
||||
}
|
||||
|
||||
// Test IsPreconditionsFailedError
|
||||
if !IsPreconditionsFailedError(err) {
|
||||
t.Error("IsPreconditionsFailedError should return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_DoLocalRequest(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check that Tailscale-Cap header is set
|
||||
if r.Header.Get("Tailscale-Cap") == "" {
|
||||
t.Error("Tailscale-Cap header not set")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := lc.DoLocalRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("DoLocalRequest failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("StatusCode = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_DoLocalRequest_AccessDenied(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "access denied"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil)
|
||||
_, err := lc.doLocalRequestNiceError(req)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 response")
|
||||
}
|
||||
if !IsAccessDeniedError(err) {
|
||||
t.Errorf("expected AccessDeniedError, got: %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_DoLocalRequest_PreconditionsFailed(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "preconditions failed"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil)
|
||||
_, err := lc.doLocalRequestNiceError(req)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 412 response")
|
||||
}
|
||||
if !IsPreconditionsFailedError(err) {
|
||||
t.Errorf("expected PreconditionsFailedError, got: %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_Send(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Method = %s, want POST", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/test/path" {
|
||||
t.Errorf("Path = %s, want /test/path", r.URL.Path)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
if string(body) != "test body" {
|
||||
t.Errorf("Body = %q, want %q", body, "test body")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("response"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
body := strings.NewReader("test body")
|
||||
resp, err := lc.send(context.Background(), "POST", "/test/path", http.StatusOK, body)
|
||||
if err != nil {
|
||||
t.Fatalf("send failed: %v", err)
|
||||
}
|
||||
|
||||
if string(resp) != "response" {
|
||||
t.Errorf("response = %q, want %q", resp, "response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_Get200(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
resp, err := lc.get200(context.Background(), "/test")
|
||||
if err != nil {
|
||||
t.Fatalf("get200 failed: %v", err)
|
||||
}
|
||||
|
||||
if string(resp) != "success" {
|
||||
t.Errorf("response = %q, want %q", resp, "success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_IncrementCounter(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.URL.Path, "/localapi/v0/upload-client-metrics") {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
err := lc.IncrementCounter(context.Background(), "test_counter", 5)
|
||||
if err != nil {
|
||||
t.Errorf("IncrementCounter failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_Goroutines(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("goroutine 1 [running]:\nmain.main()\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
data, err := lc.Goroutines(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Goroutines failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(data), "goroutine") {
|
||||
t.Error("response doesn't contain goroutine info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_Metrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method func(*LocalClient, context.Context) ([]byte, error)
|
||||
path string
|
||||
}{
|
||||
{
|
||||
name: "DaemonMetrics",
|
||||
method: (*LocalClient).DaemonMetrics,
|
||||
path: "/localapi/v0/metrics",
|
||||
},
|
||||
{
|
||||
name: "UserMetrics",
|
||||
method: (*LocalClient).UserMetrics,
|
||||
path: "/localapi/v0/usermetrics",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != tt.path {
|
||||
t.Errorf("Path = %s, want %s", r.URL.Path, tt.path)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("# HELP metric_name Help text\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
data, err := tt.method(lc, context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("%s failed: %v", tt.name, err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(data), "HELP") {
|
||||
t.Error("response doesn't contain metrics format")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_ContextCancellation(t *testing.T) {
|
||||
// Server that delays response
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(2 * time.Second)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := lc.get200(ctx, "/test")
|
||||
if err == nil {
|
||||
t.Error("expected timeout error")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "context") {
|
||||
t.Errorf("expected context error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_UseSocketOnly(t *testing.T) {
|
||||
lc := &LocalClient{
|
||||
Socket: "/tmp/test.sock",
|
||||
UseSocketOnly: true,
|
||||
}
|
||||
|
||||
// With UseSocketOnly, it should not try TCP port lookup
|
||||
_, err := lc.defaultDialer(context.Background(), "tcp", "local-tailscaled.sock:80")
|
||||
// We expect an error since /tmp/test.sock doesn't exist
|
||||
if err == nil {
|
||||
t.Error("expected error when socket doesn't exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalClient_OmitAuth(t *testing.T) {
|
||||
authHeaderSet := false
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "" {
|
||||
authHeaderSet = true
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://local-tailscaled.sock/test", nil)
|
||||
_, err := lc.DoLocalRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("DoLocalRequest failed: %v", err)
|
||||
}
|
||||
|
||||
if authHeaderSet {
|
||||
t.Error("Authorization header should not be set when OmitAuth=true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test the error message extraction
|
||||
func TestErrorMessageFromBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "json_error",
|
||||
body: []byte(`{"error":"test error message"}`),
|
||||
want: "test error message",
|
||||
},
|
||||
{
|
||||
name: "plain_text",
|
||||
body: []byte("plain error"),
|
||||
want: "plain error",
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
body: []byte{},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := errorMessageFromBody(tt.body)
|
||||
if got != tt.want {
|
||||
t.Errorf("errorMessageFromBody() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark key operations
|
||||
func BenchmarkLocalClient_Send(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
lc := &LocalClient{
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", server.Listener.Addr().String())
|
||||
},
|
||||
OmitAuth: true,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := lc.get200(context.Background(), "/test")
|
||||
if err != nil {
|
||||
b.Fatalf("get200 failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,380 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package mem
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/ipn"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
store, err := New(t.Logf, "test-id")
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
if store == nil {
|
||||
t.Fatal("New() returned nil store")
|
||||
}
|
||||
|
||||
// Verify it implements ipn.StateStore
|
||||
var _ ipn.StateStore = store
|
||||
}
|
||||
|
||||
func TestStore_String(t *testing.T) {
|
||||
s := &Store{}
|
||||
if got := s.String(); got != "mem.Store" {
|
||||
t.Errorf("String() = %q, want %q", got, "mem.Store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ReadWriteState(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
key := ipn.StateKey("test-key")
|
||||
data := []byte("test data")
|
||||
|
||||
// Write state
|
||||
err := s.WriteState(key, data)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteState() failed: %v", err)
|
||||
}
|
||||
|
||||
// Read state
|
||||
got, err := s.ReadState(key)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState() failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(got, data) {
|
||||
t.Errorf("ReadState() = %q, want %q", got, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ReadState_NotExist(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
key := ipn.StateKey("nonexistent")
|
||||
_, err := s.ReadState(key)
|
||||
|
||||
if !errors.Is(err, ipn.ErrStateNotExist) {
|
||||
t.Errorf("ReadState() error = %v, want ErrStateNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_WriteState_Clone(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
key := ipn.StateKey("test-key")
|
||||
data := []byte("original data")
|
||||
|
||||
err := s.WriteState(key, data)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteState() failed: %v", err)
|
||||
}
|
||||
|
||||
// Modify original data
|
||||
data[0] = 'X'
|
||||
|
||||
// Read should return unmodified data
|
||||
got, err := s.ReadState(key)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState() failed: %v", err)
|
||||
}
|
||||
|
||||
if bytes.Equal(got, data) {
|
||||
t.Error("ReadState() returned data that was modified after write (not cloned)")
|
||||
}
|
||||
|
||||
if got[0] != 'o' {
|
||||
t.Errorf("ReadState() data was modified, got[0] = %c, want 'o'", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_MultipleKeys(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
keys := []ipn.StateKey{"key1", "key2", "key3"}
|
||||
values := [][]byte{
|
||||
[]byte("value1"),
|
||||
[]byte("value2"),
|
||||
[]byte("value3"),
|
||||
}
|
||||
|
||||
// Write all keys
|
||||
for i, key := range keys {
|
||||
if err := s.WriteState(key, values[i]); err != nil {
|
||||
t.Fatalf("WriteState(%s) failed: %v", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read and verify all keys
|
||||
for i, key := range keys {
|
||||
got, err := s.ReadState(key)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState(%s) failed: %v", key, err)
|
||||
}
|
||||
if !bytes.Equal(got, values[i]) {
|
||||
t.Errorf("ReadState(%s) = %q, want %q", key, got, values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Overwrite(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
key := ipn.StateKey("test-key")
|
||||
|
||||
// Write initial value
|
||||
if err := s.WriteState(key, []byte("first")); err != nil {
|
||||
t.Fatalf("WriteState() failed: %v", err)
|
||||
}
|
||||
|
||||
// Overwrite with new value
|
||||
if err := s.WriteState(key, []byte("second")); err != nil {
|
||||
t.Fatalf("WriteState() failed: %v", err)
|
||||
}
|
||||
|
||||
// Read should return latest value
|
||||
got, err := s.ReadState(key)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState() failed: %v", err)
|
||||
}
|
||||
|
||||
if string(got) != "second" {
|
||||
t.Errorf("ReadState() = %q, want %q", got, "second")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ExportToJSON_Empty(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
data, err := s.ExportToJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("ExportToJSON() failed: %v", err)
|
||||
}
|
||||
|
||||
// Empty store should export as {}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("ExportToJSON() = %q, want %q", data, "{}")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ExportToJSON_WithData(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
s.WriteState("key1", []byte("value1"))
|
||||
s.WriteState("key2", []byte("value2"))
|
||||
|
||||
data, err := s.ExportToJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("ExportToJSON() failed: %v", err)
|
||||
}
|
||||
|
||||
// Parse JSON to verify structure
|
||||
var result map[string][]byte
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("ExportToJSON() produced invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
if len(result) != 2 {
|
||||
t.Errorf("ExportToJSON() exported %d keys, want 2", len(result))
|
||||
}
|
||||
|
||||
if !bytes.Equal(result["key1"], []byte("value1")) {
|
||||
t.Errorf("ExportToJSON() key1 = %q, want %q", result["key1"], "value1")
|
||||
}
|
||||
if !bytes.Equal(result["key2"], []byte("value2")) {
|
||||
t.Errorf("ExportToJSON() key2 = %q, want %q", result["key2"], "value2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_LoadFromJSON(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
jsonData := `{
|
||||
"key1": "dmFsdWUx",
|
||||
"key2": "dmFsdWUy"
|
||||
}`
|
||||
|
||||
err := s.LoadFromJSON([]byte(jsonData))
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromJSON() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded data
|
||||
got1, err := s.ReadState("key1")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState(key1) failed: %v", err)
|
||||
}
|
||||
|
||||
got2, err := s.ReadState("key2")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState(key2) failed: %v", err)
|
||||
}
|
||||
|
||||
if string(got1) != "value1" {
|
||||
t.Errorf("ReadState(key1) = %q, want %q", got1, "value1")
|
||||
}
|
||||
if string(got2) != "value2" {
|
||||
t.Errorf("ReadState(key2) = %q, want %q", got2, "value2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_LoadFromJSON_Invalid(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
err := s.LoadFromJSON([]byte("invalid json"))
|
||||
if err == nil {
|
||||
t.Error("LoadFromJSON() with invalid JSON succeeded, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ExportImportRoundTrip(t *testing.T) {
|
||||
s1 := &Store{}
|
||||
|
||||
// Write some data
|
||||
s1.WriteState("key1", []byte("value1"))
|
||||
s1.WriteState("key2", []byte("value2"))
|
||||
s1.WriteState("key3", []byte("value3"))
|
||||
|
||||
// Export
|
||||
exported, err := s1.ExportToJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("ExportToJSON() failed: %v", err)
|
||||
}
|
||||
|
||||
// Import into new store
|
||||
s2 := &Store{}
|
||||
if err := s2.LoadFromJSON(exported); err != nil {
|
||||
t.Fatalf("LoadFromJSON() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify all data matches
|
||||
keys := []ipn.StateKey{"key1", "key2", "key3"}
|
||||
for _, key := range keys {
|
||||
val1, err1 := s1.ReadState(key)
|
||||
val2, err2 := s2.ReadState(key)
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatalf("ReadState(%s) failed: err1=%v, err2=%v", key, err1, err2)
|
||||
}
|
||||
|
||||
if !bytes.Equal(val1, val2) {
|
||||
t.Errorf("Round-trip failed for %s: %q != %q", key, val1, val2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ConcurrentAccess(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
key := ipn.StateKey(string(rune('a' + n%26)))
|
||||
s.WriteState(key, []byte{byte(n)})
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
key := ipn.StateKey(string(rune('a' + n%26)))
|
||||
_, _ = s.ReadState(key)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestStore_EmptyKey(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
key := ipn.StateKey("")
|
||||
data := []byte("empty key data")
|
||||
|
||||
// Should be able to use empty key
|
||||
if err := s.WriteState(key, data); err != nil {
|
||||
t.Fatalf("WriteState() with empty key failed: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.ReadState(key)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState() with empty key failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(got, data) {
|
||||
t.Errorf("ReadState() = %q, want %q", got, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_NilData(t *testing.T) {
|
||||
s := &Store{}
|
||||
|
||||
key := ipn.StateKey("test")
|
||||
|
||||
// Write nil data
|
||||
if err := s.WriteState(key, nil); err != nil {
|
||||
t.Fatalf("WriteState() with nil data failed: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.ReadState(key)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadState() failed: %v", err)
|
||||
}
|
||||
|
||||
if got != nil && len(got) != 0 {
|
||||
t.Errorf("ReadState() = %v, want nil or empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark operations
|
||||
func BenchmarkStore_WriteState(b *testing.B) {
|
||||
s := &Store{}
|
||||
key := ipn.StateKey("bench-key")
|
||||
data := []byte("benchmark data")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.WriteState(key, data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStore_ReadState(b *testing.B) {
|
||||
s := &Store{}
|
||||
key := ipn.StateKey("bench-key")
|
||||
s.WriteState(key, []byte("benchmark data"))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.ReadState(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStore_ExportToJSON(b *testing.B) {
|
||||
s := &Store{}
|
||||
for i := 0; i < 100; i++ {
|
||||
key := ipn.StateKey(string(rune('a' + i%26)))
|
||||
s.WriteState(key, []byte("data"))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.ExportToJSON()
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue