Show example of setting deadlines on pipes

oxtocart/bugreportvialocalapi_codereview_suggestion
Percy Wegmann 2 years ago
parent ac536767ce
commit 6cc4ad7569
No known key found for this signature in database
GPG Key ID: 29D8CDEB4C13D48B

@ -2,9 +2,12 @@ package localapiclient
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"sync"
"time" "time"
"tailscale.com/ipn/localapi" "tailscale.com/ipn/localapi"
@ -12,10 +15,12 @@ import (
// Response represents the result of processing an http.Request. // Response represents the result of processing an http.Request.
type Response struct { type Response struct {
headers http.Header headers http.Header
status int status int
bodyWriter io.WriteCloser bodyWriter net.Conn
bodyReader io.ReadCloser bodyReader net.Conn
startWritingBody chan interface{}
startWritingBodyOnce sync.Once
} }
func (r *Response) Header() http.Header { func (r *Response) Header() http.Header {
@ -25,6 +30,7 @@ func (r *Response) Header() http.Header {
// Write writes the data to the response body, which will be sent to Java. If WriteHeader is not called // Write writes the data to the response body, which will be sent to Java. If WriteHeader is not called
// explicitly, the first call to Write will trigger an implicit WriteHeader(http.StatusOK). // explicitly, the first call to Write will trigger an implicit WriteHeader(http.StatusOK).
func (r *Response) Write(data []byte) (int, error) { func (r *Response) Write(data []byte) (int, error) {
r.Flush()
if r.status == 0 { if r.status == 0 {
r.WriteHeader(http.StatusOK) r.WriteHeader(http.StatusOK)
} }
@ -35,7 +41,7 @@ func (r *Response) WriteHeader(statusCode int) {
r.status = statusCode r.status = statusCode
} }
func (r *Response) Body() io.ReadCloser { func (r *Response) Body() net.Conn {
return r.bodyReader return r.bodyReader
} }
@ -43,6 +49,12 @@ func (r *Response) StatusCode() int {
return r.status return r.status
} }
func (r *Response) Flush() {
r.startWritingBodyOnce.Do(func() {
close(r.startWritingBody)
})
}
type LocalAPIClient struct { type LocalAPIClient struct {
h *localapi.Handler h *localapi.Handler
} }
@ -57,26 +69,41 @@ func New(h *localapi.Handler) *LocalAPIClient {
// local API returned a status code in the 400 series or greater. // local API returned a status code in the 400 series or greater.
// Note - Response includes a response body available from the Body method, it // Note - Response includes a response body available from the Body method, it
// is the caller's responsibility to close this. // is the caller's responsibility to close this.
func (cl *LocalAPIClient) Call(ctx context.Context, method, endpoint string, body io.Reader) (*Response, error) { func (cl *LocalAPIClient) Call(method, endpoint string, body io.Reader) (*Response, error) {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second) deadline := time.Now().Add(2 * time.Second)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, method, "/localapi/v0/"+endpoint, body) req, err := http.NewRequestWithContext(ctx, method, "/localapi/v0/"+endpoint, body)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating new request for %s: %w", endpoint, err) return nil, fmt.Errorf("error creating new request for %s: %w", endpoint, err)
} }
pipeReader, pipeWriter := io.Pipe() pipeReader, pipeWriter := net.Pipe()
defer pipeWriter.Close() pipeReader.SetDeadline(deadline)
pipeWriter.SetDeadline(deadline)
resp := &Response{ resp := &Response{
headers: http.Header{}, headers: http.Header{},
status: http.StatusOK, status: http.StatusOK,
bodyReader: pipeReader, bodyReader: pipeReader,
bodyWriter: pipeWriter, bodyWriter: pipeWriter,
startWritingBody: make(chan interface{}),
} }
cl.h.ServeHTTP(resp, req) go func() {
if resp.StatusCode() >= 400 { cl.h.ServeHTTP(resp, req)
return resp, fmt.Errorf("request failed with status code %d", resp.StatusCode()) resp.Flush()
}()
select {
case <-resp.startWritingBody:
return resp, nil
case <-ctx.Done():
// handle timeout
return nil, errors.New("timeout")
} }
return resp, nil // }
// if resp.StatusCode() >= 400 {
// return resp, fmt.Errorf("request failed with status code %d", resp.StatusCode())
// }
// return resp, nil
} }

Loading…
Cancel
Save