diff --git a/cmd/localapiclient/localapiclient.go b/cmd/localapiclient/localapiclient.go index a3c6b53..e76f54f 100644 --- a/cmd/localapiclient/localapiclient.go +++ b/cmd/localapiclient/localapiclient.go @@ -2,9 +2,12 @@ package localapiclient import ( "context" + "errors" "fmt" "io" + "net" "net/http" + "sync" "time" "tailscale.com/ipn/localapi" @@ -12,10 +15,12 @@ import ( // Response represents the result of processing an http.Request. type Response struct { - headers http.Header - status int - bodyWriter io.WriteCloser - bodyReader io.ReadCloser + headers http.Header + status int + bodyWriter net.Conn + bodyReader net.Conn + startWritingBody chan interface{} + startWritingBodyOnce sync.Once } func (r *Response) Header() http.Header { @@ -25,6 +30,7 @@ func (r *Response) Header() http.Header { // Write writes the data to the response body, which will be sent to Java. If WriteHeader is not called // explicitly, the first call to Write will trigger an implicit WriteHeader(http.StatusOK). func (r *Response) Write(data []byte) (int, error) { + r.Flush() if r.status == 0 { r.WriteHeader(http.StatusOK) } @@ -35,7 +41,7 @@ func (r *Response) WriteHeader(statusCode int) { r.status = statusCode } -func (r *Response) Body() io.ReadCloser { +func (r *Response) Body() net.Conn { return r.bodyReader } @@ -43,6 +49,12 @@ func (r *Response) StatusCode() int { return r.status } +func (r *Response) Flush() { + r.startWritingBodyOnce.Do(func() { + close(r.startWritingBody) + }) +} + type LocalAPIClient struct { h *localapi.Handler } @@ -57,26 +69,41 @@ func New(h *localapi.Handler) *LocalAPIClient { // local API returned a status code in the 400 series or greater. // Note - Response includes a response body available from the Body method, it // is the caller's responsibility to close this. -func (cl *LocalAPIClient) Call(ctx context.Context, method, endpoint string, body io.Reader) (*Response, error) { - ctx, cancel := context.WithTimeout(ctx, 2*time.Second) +func (cl *LocalAPIClient) Call(method, endpoint string, body io.Reader) (*Response, error) { + deadline := time.Now().Add(2 * time.Second) + ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() req, err := http.NewRequestWithContext(ctx, method, "/localapi/v0/"+endpoint, body) if err != nil { return nil, fmt.Errorf("error creating new request for %s: %w", endpoint, err) } - pipeReader, pipeWriter := io.Pipe() - defer pipeWriter.Close() + pipeReader, pipeWriter := net.Pipe() + pipeReader.SetDeadline(deadline) + pipeWriter.SetDeadline(deadline) resp := &Response{ - headers: http.Header{}, - status: http.StatusOK, - bodyReader: pipeReader, - bodyWriter: pipeWriter, + headers: http.Header{}, + status: http.StatusOK, + bodyReader: pipeReader, + bodyWriter: pipeWriter, + startWritingBody: make(chan interface{}), } - cl.h.ServeHTTP(resp, req) - if resp.StatusCode() >= 400 { - return resp, fmt.Errorf("request failed with status code %d", resp.StatusCode()) + go func() { + cl.h.ServeHTTP(resp, req) + resp.Flush() + }() + + select { + case <-resp.startWritingBody: + return resp, nil + case <-ctx.Done(): + // handle timeout + return nil, errors.New("timeout") } - return resp, nil + // } + // if resp.StatusCode() >= 400 { + // return resp, fmt.Errorf("request failed with status code %d", resp.StatusCode()) + // } + // return resp, nil }