diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 6bb9b5182..3cf55ccbf 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -428,3 +428,360 @@ func TestKeepItSorted(t *testing.T) { } } } + +// ===== defBool Tests ===== + +func TestDefBool(t *testing.T) { + tests := []struct { + name string + input string + def bool + expected bool + }{ + {"empty_default_true", "", true, true}, + {"empty_default_false", "", false, false}, + {"true_string", "true", false, true}, + {"false_string", "false", true, false}, + {"1_string", "1", false, true}, + {"0_string", "0", true, false}, + {"t_string", "t", false, true}, + {"f_string", "f", true, false}, + {"invalid_uses_default_true", "invalid", true, true}, + {"invalid_uses_default_false", "invalid", false, false}, + {"True_uppercase", "True", false, true}, + {"FALSE_uppercase", "FALSE", true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := defBool(tt.input, tt.def) + if got != tt.expected { + t.Errorf("defBool(%q, %v) = %v, want %v", tt.input, tt.def, got, tt.expected) + } + }) + } +} + +// ===== dnsMessageTypeForString Tests ===== + +func TestDNSMessageTypeForString(t *testing.T) { + tests := []struct { + input string + expected string // type name for comparison + wantErr bool + }{ + {"A", "TypeA", false}, + {"AAAA", "TypeAAAA", false}, + {"CNAME", "TypeCNAME", false}, + {"MX", "TypeMX", false}, + {"NS", "TypeNS", false}, + {"PTR", "TypePTR", false}, + {"SOA", "TypeSOA", false}, + {"SRV", "TypeSRV", false}, + {"TXT", "TypeTXT", false}, + {"ALL", "TypeALL", false}, + {"HINFO", "TypeHINFO", false}, + {"MINFO", "TypeMINFO", false}, + {"OPT", "TypeOPT", false}, + {"WKS", "TypeWKS", false}, + // Lowercase should work (gets uppercased) + {"a", "TypeA", false}, + {"aaaa", "TypeAAAA", false}, + {"txt", "TypeTXT", false}, + // With whitespace (gets trimmed) + {" A ", "TypeA", false}, + {" AAAA ", "TypeAAAA", false}, + // Invalid types + {"INVALID", "", true}, + {"", "", true}, + {"UNKNOWN", "", true}, + {"B", "", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := dnsMessageTypeForString(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("dnsMessageTypeForString(%q) succeeded, want error", tt.input) + } + return + } + if err != nil { + t.Errorf("dnsMessageTypeForString(%q) failed: %v", tt.input, err) + return + } + // We can't directly compare dnsmessage.Type values easily, + // but we can check that we got a non-zero value for valid types + if got == 0 { + t.Errorf("dnsMessageTypeForString(%q) = 0, want non-zero type", tt.input) + } + }) + } +} + +// ===== handlerForPath Tests ===== + +func TestHandlerForPath(t *testing.T) { + tests := []struct { + path string + wantRoute string + wantOK bool + wantPrefix bool // whether it's a prefix match + }{ + {"/", "/", true, false}, + {"/localapi/v0/status", "/localapi/v0/status", true, false}, + {"/localapi/v0/prefs", "/localapi/v0/prefs", true, false}, + {"/localapi/v0/profiles/", "/localapi/v0/profiles/", true, true}, + {"/localapi/v0/profiles/123", "/localapi/v0/profiles/", true, true}, + {"/localapi/v0/start", "/localapi/v0/start", true, false}, + {"/localapi/v0/shutdown", "/localapi/v0/shutdown", true, false}, + {"/localapi/v0/ping", "/localapi/v0/ping", true, false}, + {"/localapi/v0/whois", "/localapi/v0/whois", true, false}, + {"/localapi/v0/goroutines", "/localapi/v0/goroutines", true, false}, + {"/localapi/v0/derpmap", "/localapi/v0/derpmap", true, false}, + // Invalid paths + {"/invalid", "", false, false}, + {"/localapi/invalid", "", false, false}, + {"/api/v0/status", "", false, false}, + {"/localapi/v1/status", "", false, false}, + {"/localapi/v0/nonexistent", "", false, false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + fn, route, ok := handlerForPath(tt.path) + if ok != tt.wantOK { + t.Errorf("handlerForPath(%q) ok = %v, want %v", tt.path, ok, tt.wantOK) + } + if route != tt.wantRoute { + t.Errorf("handlerForPath(%q) route = %q, want %q", tt.path, route, tt.wantRoute) + } + if tt.wantOK && fn == nil { + t.Errorf("handlerForPath(%q) returned nil handler", tt.path) + } + if !tt.wantOK && fn != nil { + t.Errorf("handlerForPath(%q) returned non-nil handler for invalid path", tt.path) + } + }) + } +} + +func TestHandlerForPath_PrefixMatching(t *testing.T) { + // Test that prefix matches work correctly + _, route1, ok1 := handlerForPath("/localapi/v0/profiles/") + _, route2, ok2 := handlerForPath("/localapi/v0/profiles/current") + _, route3, ok3 := handlerForPath("/localapi/v0/profiles/123/switch") + + if !ok1 || !ok2 || !ok3 { + t.Error("prefix matching should work for all profiles/ paths") + } + + // All should return the same route (the prefix) + if route1 != "/localapi/v0/profiles/" { + t.Errorf("route1 = %q, want /localapi/v0/profiles/", route1) + } + if route2 != "/localapi/v0/profiles/" { + t.Errorf("route2 = %q, want /localapi/v0/profiles/", route2) + } + if route3 != "/localapi/v0/profiles/" { + t.Errorf("route3 = %q, want /localapi/v0/profiles/", route3) + } +} + +// ===== WriteErrorJSON Tests ===== + +func TestWriteErrorJSON(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + wantBodySubstr string + }{ + { + name: "simple_error", + err: errors.New("test error"), + wantStatus: http.StatusInternalServerError, + wantBodySubstr: "test error", + }, + { + name: "nil_error", + err: nil, + wantStatus: http.StatusInternalServerError, + wantBodySubstr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + WriteErrorJSON(rec, tt.err) + + if rec.Code != tt.wantStatus { + t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus) + } + + if tt.wantBodySubstr != "" && !strings.Contains(rec.Body.String(), tt.wantBodySubstr) { + t.Errorf("body = %q, want to contain %q", rec.Body.String(), tt.wantBodySubstr) + } + + // Check Content-Type + ct := rec.Header().Get("Content-Type") + if ct != "application/json" { + t.Errorf("Content-Type = %q, want application/json", ct) + } + }) + } +} + +// ===== Register Tests ===== + +func TestRegister(t *testing.T) { + // Save the original handler map + originalHandler := handler + + // Create a test handler function + testHandler := func(h *Handler, w http.ResponseWriter, r *http.Request) { + w.Write([]byte("test")) + } + + // Register a new handler + testRoute := "test-route-12345" + Register(testRoute, testHandler) + + // Verify it was registered + fn, route, ok := handlerForPath("/localapi/v0/" + testRoute) + if !ok { + t.Error("registered route not found") + } + if route != "/localapi/v0/"+testRoute { + t.Errorf("route = %q, want %q", route, "/localapi/v0/"+testRoute) + } + if fn == nil { + t.Error("registered handler is nil") + } + + // Restore original handler map + handler = originalHandler +} + +// ===== InUseOtherUserIPNStream Tests ===== + +func TestInUseOtherUserIPNStream(t *testing.T) { + tests := []struct { + name string + err error + wantHandled bool + }{ + { + name: "in_use_error", + err: ipn.ErrStateNotExist, + wantHandled: true, + }, + { + name: "other_error", + err: errors.New("some other error"), + wantHandled: false, + }, + { + name: "nil_error", + err: nil, + wantHandled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + + handled := InUseOtherUserIPNStream(rec, req, tt.err) + + if handled != tt.wantHandled { + t.Errorf("InUseOtherUserIPNStream() handled = %v, want %v", handled, tt.wantHandled) + } + + if tt.wantHandled && rec.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d for handled error", rec.Code, http.StatusForbidden) + } + }) + } +} + +// ===== Handler Permission Tests ===== + +func TestHandler_PermitRead(t *testing.T) { + h := &Handler{ + PermitRead: true, + b: &ipnlocal.LocalBackend{}, + } + + if !h.PermitRead { + t.Error("PermitRead should be true") + } +} + +func TestHandler_PermitWrite(t *testing.T) { + h := &Handler{ + PermitWrite: true, + b: &ipnlocal.LocalBackend{}, + } + + if !h.PermitWrite { + t.Error("PermitWrite should be true") + } +} + +func TestHandler_PermitCert(t *testing.T) { + h := &Handler{ + PermitCert: true, + b: &ipnlocal.LocalBackend{}, + } + + if !h.PermitCert { + t.Error("PermitCert should be true") + } +} + +func TestHandler_RequiredPassword(t *testing.T) { + h := &Handler{ + RequiredPassword: "test-password", + b: &ipnlocal.LocalBackend{}, + } + + if h.RequiredPassword != "test-password" { + t.Errorf("RequiredPassword = %q, want %q", h.RequiredPassword, "test-password") + } +} + +// ===== Handler Methods Tests ===== + +func TestHandler_Logf(t *testing.T) { + var logged bool + logf := func(format string, args ...any) { + logged = true + } + + h := &Handler{ + logf: logf, + b: &ipnlocal.LocalBackend{}, + } + + h.Logf("test message") + + if !logged { + t.Error("Logf did not call the logger function") + } +} + +func TestHandler_LocalBackend(t *testing.T) { + lb := &ipnlocal.LocalBackend{} + h := &Handler{ + b: lb, + } + + got := h.LocalBackend() + if got != lb { + t.Error("LocalBackend() returned wrong backend") + } +}