From ac536767ce7f934d5469adc709ad218b6c1f39ee Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Thu, 8 Feb 2024 17:35:36 -0600 Subject: [PATCH] Code review suggestions --- cmd/localapiclient/localapiclient.go | 101 +++++++++++++-------------- cmd/tailscale/main.go | 24 +++++-- 2 files changed, 66 insertions(+), 59 deletions(-) diff --git a/cmd/localapiclient/localapiclient.go b/cmd/localapiclient/localapiclient.go index 6c695c5..a3c6b53 100644 --- a/cmd/localapiclient/localapiclient.go +++ b/cmd/localapiclient/localapiclient.go @@ -1,87 +1,82 @@ package localapiclient import ( - "bytes" - "errors" + "context" "fmt" - "log" + "io" "net/http" "time" "tailscale.com/ipn/localapi" ) -// LocalAPIResponseWriter substitutes for http.ResponseWriter in order to write byte streams directly -// to a receiver function in the application. -type LocalApiResponseWriter struct { - headers http.Header - body bytes.Buffer - status int +// Response represents the result of processing an http.Request. +type Response struct { + headers http.Header + status int + bodyWriter io.WriteCloser + bodyReader io.ReadCloser } -func newLocalApiResponseWriter() *LocalApiResponseWriter { - return &LocalApiResponseWriter{headers: http.Header{}, status: http.StatusOK} -} - -func (w *LocalApiResponseWriter) Header() http.Header { - return w.headers +func (r *Response) Header() http.Header { + return r.headers } // Write writes the data to the response body, which will be sent to Java. If WriteHeader is not called // explicitly, the first call to Write will trigger an implicit WriteHeader(http.StatusOK). -func (w *LocalApiResponseWriter) Write(data []byte) (int, error) { - if w.status == 0 { - w.WriteHeader(http.StatusOK) +func (r *Response) Write(data []byte) (int, error) { + if r.status == 0 { + r.WriteHeader(http.StatusOK) } - return w.body.Write(data) + return r.bodyWriter.Write(data) } -func (w *LocalApiResponseWriter) WriteHeader(statusCode int) { - w.status = statusCode +func (r *Response) WriteHeader(statusCode int) { + r.status = statusCode } -func (w *LocalApiResponseWriter) Body() []byte { - return w.body.Bytes() +func (r *Response) Body() io.ReadCloser { + return r.bodyReader } -func (w *LocalApiResponseWriter) StatusCode() int { - return w.status +func (r *Response) StatusCode() int { + return r.status } -type LocalApiClient struct { +type LocalAPIClient struct { h *localapi.Handler } -func NewLocalApiClient(h *localapi.Handler) LocalApiClient { - return LocalApiClient{h: h} +func New(h *localapi.Handler) *LocalAPIClient { + return &LocalAPIClient{h: h} } -var ErrBadHttpStatus = errors.New("bad http status for localapi response") +// Call calls the given endpoint on the local API using the given HTTP method +// optionally sending the given body. It returns a Response representing the +// result of the call and an error if the call could not be completed or the +// local API returned a status code in the 400 series or greater. +// Note - Response includes a response body available from the Body method, it +// is the caller's responsibility to close this. +func (cl *LocalAPIClient) Call(ctx context.Context, method, endpoint string, body io.Reader) (*Response, error) { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() -func CallLocalApi(h *localapi.Handler, method string, endpoint string) (*LocalApiResponseWriter, error) { - done := make(chan *LocalApiResponseWriter, 1) - var responseError error - go func() { - req, err := http.NewRequest(method, "/localapi/v0/"+endpoint, nil) - if err != nil { - log.Printf("error creating new request for %s: %v", endpoint, err) - responseError = err - close(done) - return - } - w := newLocalApiResponseWriter() - h.ServeHTTP(w, req) - if w.StatusCode() > 300 { - log.Printf("%s bad http status: %v", endpoint, w.StatusCode()) - responseError = ErrBadHttpStatus - } - done <- w - }() + req, err := http.NewRequestWithContext(ctx, method, "/localapi/v0/"+endpoint, body) + if err != nil { + return nil, fmt.Errorf("error creating new request for %s: %w", endpoint, err) + } + pipeReader, pipeWriter := io.Pipe() + defer pipeWriter.Close() - select { - case w := <-done: - return w, responseError - case <-time.After(2 * time.Second): - return nil, fmt.Errorf("request to %s timed out", endpoint) + resp := &Response{ + headers: http.Header{}, + status: http.StatusOK, + bodyReader: pipeReader, + bodyWriter: pipeWriter, + } + cl.h.ServeHTTP(resp, req) + if resp.StatusCode() >= 400 { + return resp, fmt.Errorf("request failed with status code %d", resp.StatusCode()) } + return resp, nil } diff --git a/cmd/tailscale/main.go b/cmd/tailscale/main.go index 2e3d8a1..e3f5727 100644 --- a/cmd/tailscale/main.go +++ b/cmd/tailscale/main.go @@ -58,6 +58,8 @@ type App struct { store *stateStore logIDPublicAtomic atomic.Pointer[logid.PublicID] + localAPIClient *localapiclient.LocalAPIClient + // netStates receives the most recent network state. netStates chan BackendState // prefs receives new preferences from the backend. @@ -292,6 +294,7 @@ func (a *App) runBackend() error { h := localapi.NewHandler(b.backend, log.Printf, b.sys.NetMon.Get(), *a.logIDPublicAtomic.Load()) h.PermitRead = true h.PermitWrite = true + a.localAPIClient = localapiclient.New(h) // Contrary to the documentation for VpnService.Builder.addDnsServer, // ChromeOS doesn't fall back to the underlying network nameservers if @@ -438,7 +441,7 @@ func (a *App) runBackend() error { case BugEvent: backendLogIDStr := a.logIDPublicAtomic.Load().String() fallbackLog := fmt.Sprintf("BUG-%v-%v-%v", backendLogIDStr, time.Now().UTC().Format("20060102150405Z"), randHex(8)) - getBugReportID(h, a.bugReport, fallbackLog) + a.getBugReportID(a.bugReport, fallbackLog) case OAuth2Event: go b.backend.Login(e.Token) case ToggleEvent: @@ -563,13 +566,22 @@ func (a *App) runBackend() error { } } -func getBugReportID(h *localapi.Handler, bugReportChan chan<- string, fallbackLog string) { - w, err := localapiclient.CallLocalApi(h, "POST", "bugreport") - if w == nil || err != nil { +func (a *App) getBugReportID(bugReportChan chan<- string, fallbackLog string) { + r, err := a.localAPIClient.Call(a.appCtx, "POST", "bugreport", nil) + defer r.Body().Close() + + if err != nil { + log.Printf("get bug report: %s", err) + bugReportChan <- fallbackLog + return + } + logBytes, err := io.ReadAll(r.Body()) + if err != nil { + log.Printf("read bug report: %s", err) bugReportChan <- fallbackLog - } else { - bugReportChan <- string(w.Body()) + return } + bugReportChan <- string(logBytes) } func (a *App) processWaitingFiles(b *ipnlocal.LocalBackend) error {