diff --git a/cmd/localapiclient/localapiclient.go b/cmd/localapiclient/localapiclient.go index bfd9393..efc4222 100644 --- a/cmd/localapiclient/localapiclient.go +++ b/cmd/localapiclient/localapiclient.go @@ -7,8 +7,6 @@ import ( "net" "net/http" "sync" - - "tailscale.com/ipn/localapi" ) // Response represents the result of processing an http.Request. @@ -53,10 +51,10 @@ func (r *Response) Flush() { } type LocalAPIClient struct { - h *localapi.Handler + h http.Handler } -func New(h *localapi.Handler) *LocalAPIClient { +func New(h http.Handler) *LocalAPIClient { return &LocalAPIClient{h: h} } diff --git a/cmd/localapiclient/localapiclient_test.go b/cmd/localapiclient/localapiclient_test.go new file mode 100644 index 0000000..6c4db54 --- /dev/null +++ b/cmd/localapiclient/localapiclient_test.go @@ -0,0 +1,73 @@ +package localapiclient + +import ( + "context" + "io" + "net/http" + "testing" + "time" +) + +var ctx = context.Background() + +type BadStatusHandler struct{} + +func (b *BadStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) +} + +func TestBadStatus(t *testing.T) { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(2*time.Second)) + client := New(&BadStatusHandler{}) + defer cancel() + + _, err := client.Call(ctx, "POST", "test", nil) + + if err.Error() != "request failed with status code 400" { + t.Error("Expected bad status error, but got", err) + } +} + +type TimeoutHandler struct{} + +var successfulResponse = "successful response!" + +func (b *TimeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + time.Sleep(6 * time.Second) + w.Write([]byte(successfulResponse)) +} + +func TestTimeout(t *testing.T) { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(2*time.Second)) + client := New(&TimeoutHandler{}) + defer cancel() + + _, err := client.Call(ctx, "GET", "test", nil) + + if err.Error() != "timeout for test" { + t.Error("Expected timeout error, but got", err) + } +} + +type SuccessfulHandler struct{} + +func (b *SuccessfulHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(successfulResponse)) +} + +func TestSuccess(t *testing.T) { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(2*time.Second)) + client := New(&SuccessfulHandler{}) + defer cancel() + + w, err := client.Call(ctx, "GET", "test", nil) + + if err != nil { + t.Error("Expected no error, but got", err) + } + + report, err := io.ReadAll(w.Body()) + if string(report) != successfulResponse { + t.Error("Expected successful report, but got", report) + } +} diff --git a/cmd/tailscale/main.go b/cmd/tailscale/main.go index 5773cee..2fc5a74 100644 --- a/cmd/tailscale/main.go +++ b/cmd/tailscale/main.go @@ -61,6 +61,7 @@ type App struct { logIDPublicAtomic atomic.Pointer[logid.PublicID] localAPIClient *localapiclient.LocalAPIClient + backend *ipnlocal.LocalBackend // netStates receives the most recent network state. netStates chan BackendState @@ -294,6 +295,7 @@ func (a *App) runBackend(ctx context.Context) error { return err } a.logIDPublicAtomic.Store(&b.logIDPublic) + a.backend = b.backend defer b.CloseTUNs() h := localapi.NewHandler(b.backend, log.Printf, b.sys.NetMon.Get(), *a.logIDPublicAtomic.Load()) @@ -459,8 +461,9 @@ func (a *App) runBackend(ctx context.Context) error { state.Prefs.ExitNodeAllowLANAccess = bool(e) go b.backend.SetPrefs(state.Prefs) case WebAuthEvent: + log.Printf("KARI WEBAUTHEVENT") if !signingIn { - go b.backend.StartLoginInteractive() + go a.login(ctx) signingIn = true } case SetLoginServerEvent: @@ -591,6 +594,18 @@ func (a *App) getBugReportID(ctx context.Context, bugReportChan chan<- string, f bugReportChan <- string(logBytes) } +func (a *App) login(ctx context.Context) { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(2*time.Second)) + defer cancel() + r, err := a.localAPIClient.Call(ctx, "POST", "login-interactive", nil) + defer r.Body().Close() + + if err != nil { + log.Printf("login: %s", err) + a.backend.StartLoginInteractive() + } +} + func (a *App) processWaitingFiles(b *ipnlocal.LocalBackend) error { files, err := b.WaitingFiles() if err != nil { diff --git a/go.mod b/go.mod index 6f6051b..959f759 100644 --- a/go.mod +++ b/go.mod @@ -98,4 +98,4 @@ require ( gvisor.dev/gvisor v0.0.0-20240119233241-c9c1d4f9b186 // indirect inet.af/peercred v0.0.0-20210906144145-0893ea02156a // indirect nhooyr.io/websocket v1.8.10 // indirect -) +) \ No newline at end of file