diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index d8ce4b43a..6c380f369 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -37,6 +37,8 @@ type httpTestParam struct { // makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a // 101 switching protocols. makeHTTPHangAfterUpgrade bool + + doEarlyWrite bool } func TestControlHTTP(t *testing.T) { @@ -111,6 +113,11 @@ func TestControlHTTP(t *testing.T) { allowHTTP: true, }, }, + // Early write + { + name: "early_write", + doEarlyWrite: true, + }, } for _, test := range tests { @@ -125,9 +132,21 @@ func testControlHTTP(t *testing.T, param httpTestParam) { client, server := key.NewMachine(), key.NewMachine() const testProtocolVersion = 1 + const earlyWriteMsg = "Hello, world!" sch := make(chan serverResult, 1) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server, nil) + var earlyWriteFn func(protocolVersion int, w io.Writer) error + if param.doEarlyWrite { + earlyWriteFn = func(protocolVersion int, w io.Writer) error { + if protocolVersion != testProtocolVersion { + t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion) + return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion) + } + _, err := io.WriteString(w, earlyWriteMsg) + return err + } + } + conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) if err != nil { log.Print(err) } @@ -228,6 +247,15 @@ func testControlHTTP(t *testing.T, param httpTestParam) { if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) { t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr) } + if param.doEarlyWrite { + buf := make([]byte, len(earlyWriteMsg)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Fatalf("reading early write: %v", err) + } + if string(buf) != earlyWriteMsg { + t.Errorf("early write = %q; want %q", buf, earlyWriteMsg) + } + } } type serverResult struct { diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 748da2527..6aed6ea39 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -9,7 +9,10 @@ import ( "encoding/base64" "errors" "fmt" + "io" + "net" "net/http" + "time" "nhooyr.io/websocket" "tailscale.com/control/controlbase" @@ -18,16 +21,20 @@ import ( "tailscale.com/types/key" ) -// AcceptHTTP upgrades the HTTP request given by w and r into a -// Tailscale control protocol base transport connection. +// AcceptHTTP upgrades the HTTP request given by w and r into a Tailscale +// control protocol base transport connection. // -// AcceptHTTP always writes an HTTP response to w. The caller must not -// attempt their own response after calling AcceptHTTP. +// AcceptHTTP always writes an HTTP response to w. The caller must not attempt +// their own response after calling AcceptHTTP. // -// extraHeader optionally specifies extra header(s) to send in the -// 101 Switching Protocols Upgrade response. It must not include the "Upgrade" -// or "Connection" headers; they will be replaced. -func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, extraHeader http.Header) (*controlbase.Conn, error) { +// earlyWrite optionally specifies a func to write to the noise connection +// (encrypted). It receives the negotiated version and a writer to write to, if +// desired. +func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (*controlbase.Conn, error) { + return acceptHTTP(ctx, w, r, private, earlyWrite) +} + +func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (_ *controlbase.Conn, retErr error) { next := r.Header.Get("Upgrade") if next == "" { http.Error(w, "missing next protocol", http.StatusBadRequest) @@ -58,9 +65,6 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri return nil, errors.New("can't hijack client connection") } - for k, vv := range extraHeader { - w.Header()[k] = vv - } w.Header().Set("Upgrade", upgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) @@ -69,18 +73,41 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri if err != nil { return nil, fmt.Errorf("hijacking client connection: %w", err) } + + defer func() { + if retErr != nil { + conn.Close() + } + }() + if err := brw.Flush(); err != nil { - conn.Close() return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err) } conn = netutil.NewDrainBufConn(conn, brw.Reader) - nc, err := controlbase.Server(ctx, conn, private, init) + cwc := newWriteCorkingConn(conn) + + nc, err := controlbase.Server(ctx, cwc, private, init) if err != nil { - conn.Close() return nil, fmt.Errorf("noise handshake failed: %w", err) } + if earlyWrite != nil { + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer conn.SetDeadline(time.Time{}) + } + if err := earlyWrite(nc.ProtocolVersion(), nc); err != nil { + return nil, err + } + } + + if err := cwc.uncork(); err != nil { + return nil, err + } + return nc, nil } @@ -128,3 +155,61 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request return nc, nil } + +// corkConn is a net.Conn wrapper that initially buffers all writes until uncork +// is called. If the conn is corked and a Read occurs, the Read will flush any +// buffered (corked) write. +// +// Until uncorked, Read/Write/uncork may be not called concurrently. +// +// Deadlines still work, but a corked write ignores deadlines until a Read or +// uncork goes to do that Write. +// +// Use newWriteCorkingConn to create one. +type corkConn struct { + net.Conn + corked bool + buf []byte // corked data +} + +func newWriteCorkingConn(c net.Conn) *corkConn { + return &corkConn{Conn: c, corked: true} +} + +func (c *corkConn) Write(b []byte) (int, error) { + if c.corked { + c.buf = append(c.buf, b...) + return len(b), nil + } + return c.Conn.Write(b) +} + +func (c *corkConn) Read(b []byte) (int, error) { + if c.corked { + if err := c.flush(); err != nil { + return 0, err + } + } + return c.Conn.Read(b) +} + +// uncork flushes any buffered data and uncorks the connection so future Writes +// don't buffer. It may not be called concurrently with reads or writes and +// may only be called once. +func (c *corkConn) uncork() error { + if !c.corked { + panic("usage error; uncork called twice") // worth panicking to catch misuse + } + err := c.flush() + c.corked = false + return err +} + +func (c *corkConn) flush() error { + if len(c.buf) == 0 { + return nil + } + _, err := c.Conn.Write(c.buf) + c.buf = nil + return err +}