From 5e9e57ecf531f26692413ecddebfa6172550dd44 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 27 Oct 2022 13:58:35 -0700 Subject: [PATCH] control/controlhttp: add AcceptHTTP hook to add coalesced Server->Client write New plan for #5972. Instead of sending the public key in the clear (from earlier unreleased 246274b8e91) where the client might have to worry about it being dropped or tampered with and retrying, we'll instead send it post-Noise handshake but before the HTTP/2 connection begins. This replaces the earlier extraHeaders hook with a different sort of hook that allows us to combine two writes on the wire in one packet. Updates #5972 Change-Id: I42cdf7c1859b53ca4dfa5610bd1b840c6986e09c Signed-off-by: Brad Fitzpatrick --- control/controlhttp/http_test.go | 30 +++++++- control/controlhttp/server.go | 113 +++++++++++++++++++++++++++---- 2 files changed, 128 insertions(+), 15 deletions(-) diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index d8ce4b43a..6c380f369 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -37,6 +37,8 @@ type httpTestParam struct { // makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a // 101 switching protocols. makeHTTPHangAfterUpgrade bool + + doEarlyWrite bool } func TestControlHTTP(t *testing.T) { @@ -111,6 +113,11 @@ func TestControlHTTP(t *testing.T) { allowHTTP: true, }, }, + // Early write + { + name: "early_write", + doEarlyWrite: true, + }, } for _, test := range tests { @@ -125,9 +132,21 @@ func testControlHTTP(t *testing.T, param httpTestParam) { client, server := key.NewMachine(), key.NewMachine() const testProtocolVersion = 1 + const earlyWriteMsg = "Hello, world!" sch := make(chan serverResult, 1) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server, nil) + var earlyWriteFn func(protocolVersion int, w io.Writer) error + if param.doEarlyWrite { + earlyWriteFn = func(protocolVersion int, w io.Writer) error { + if protocolVersion != testProtocolVersion { + t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion) + return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion) + } + _, err := io.WriteString(w, earlyWriteMsg) + return err + } + } + conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) if err != nil { log.Print(err) } @@ -228,6 +247,15 @@ func testControlHTTP(t *testing.T, param httpTestParam) { if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) { t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr) } + if param.doEarlyWrite { + buf := make([]byte, len(earlyWriteMsg)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Fatalf("reading early write: %v", err) + } + if string(buf) != earlyWriteMsg { + t.Errorf("early write = %q; want %q", buf, earlyWriteMsg) + } + } } type serverResult struct { diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 748da2527..6aed6ea39 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -9,7 +9,10 @@ import ( "encoding/base64" "errors" "fmt" + "io" + "net" "net/http" + "time" "nhooyr.io/websocket" "tailscale.com/control/controlbase" @@ -18,16 +21,20 @@ import ( "tailscale.com/types/key" ) -// AcceptHTTP upgrades the HTTP request given by w and r into a -// Tailscale control protocol base transport connection. +// AcceptHTTP upgrades the HTTP request given by w and r into a Tailscale +// control protocol base transport connection. // -// AcceptHTTP always writes an HTTP response to w. The caller must not -// attempt their own response after calling AcceptHTTP. +// AcceptHTTP always writes an HTTP response to w. The caller must not attempt +// their own response after calling AcceptHTTP. // -// extraHeader optionally specifies extra header(s) to send in the -// 101 Switching Protocols Upgrade response. It must not include the "Upgrade" -// or "Connection" headers; they will be replaced. -func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, extraHeader http.Header) (*controlbase.Conn, error) { +// earlyWrite optionally specifies a func to write to the noise connection +// (encrypted). It receives the negotiated version and a writer to write to, if +// desired. +func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (*controlbase.Conn, error) { + return acceptHTTP(ctx, w, r, private, earlyWrite) +} + +func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (_ *controlbase.Conn, retErr error) { next := r.Header.Get("Upgrade") if next == "" { http.Error(w, "missing next protocol", http.StatusBadRequest) @@ -58,9 +65,6 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri return nil, errors.New("can't hijack client connection") } - for k, vv := range extraHeader { - w.Header()[k] = vv - } w.Header().Set("Upgrade", upgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) @@ -69,18 +73,41 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri if err != nil { return nil, fmt.Errorf("hijacking client connection: %w", err) } + + defer func() { + if retErr != nil { + conn.Close() + } + }() + if err := brw.Flush(); err != nil { - conn.Close() return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err) } conn = netutil.NewDrainBufConn(conn, brw.Reader) - nc, err := controlbase.Server(ctx, conn, private, init) + cwc := newWriteCorkingConn(conn) + + nc, err := controlbase.Server(ctx, cwc, private, init) if err != nil { - conn.Close() return nil, fmt.Errorf("noise handshake failed: %w", err) } + if earlyWrite != nil { + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer conn.SetDeadline(time.Time{}) + } + if err := earlyWrite(nc.ProtocolVersion(), nc); err != nil { + return nil, err + } + } + + if err := cwc.uncork(); err != nil { + return nil, err + } + return nc, nil } @@ -128,3 +155,61 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request return nc, nil } + +// corkConn is a net.Conn wrapper that initially buffers all writes until uncork +// is called. If the conn is corked and a Read occurs, the Read will flush any +// buffered (corked) write. +// +// Until uncorked, Read/Write/uncork may be not called concurrently. +// +// Deadlines still work, but a corked write ignores deadlines until a Read or +// uncork goes to do that Write. +// +// Use newWriteCorkingConn to create one. +type corkConn struct { + net.Conn + corked bool + buf []byte // corked data +} + +func newWriteCorkingConn(c net.Conn) *corkConn { + return &corkConn{Conn: c, corked: true} +} + +func (c *corkConn) Write(b []byte) (int, error) { + if c.corked { + c.buf = append(c.buf, b...) + return len(b), nil + } + return c.Conn.Write(b) +} + +func (c *corkConn) Read(b []byte) (int, error) { + if c.corked { + if err := c.flush(); err != nil { + return 0, err + } + } + return c.Conn.Read(b) +} + +// uncork flushes any buffered data and uncorks the connection so future Writes +// don't buffer. It may not be called concurrently with reads or writes and +// may only be called once. +func (c *corkConn) uncork() error { + if !c.corked { + panic("usage error; uncork called twice") // worth panicking to catch misuse + } + err := c.flush() + c.corked = false + return err +} + +func (c *corkConn) flush() error { + if len(c.buf) == 0 { + return nil + } + _, err := c.Conn.Write(c.buf) + c.buf = nil + return err +}