control/controlhttp: add AcceptHTTP hook to add coalesced Server->Client write

New plan for #5972. Instead of sending the public key in the clear
(from earlier unreleased 246274b8e9) where the client might have to
worry about it being dropped or tampered with and retrying, we'll
instead send it post-Noise handshake but before the HTTP/2 connection
begins.

This replaces the earlier extraHeaders hook with a different sort of
hook that allows us to combine two writes on the wire in one packet.

Updates #5972

Change-Id: I42cdf7c1859b53ca4dfa5610bd1b840c6986e09c
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/6105/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent c21a3c4733
commit 5e9e57ecf5

@ -37,6 +37,8 @@ type httpTestParam struct {
// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a // makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
// 101 switching protocols. // 101 switching protocols.
makeHTTPHangAfterUpgrade bool makeHTTPHangAfterUpgrade bool
doEarlyWrite bool
} }
func TestControlHTTP(t *testing.T) { func TestControlHTTP(t *testing.T) {
@ -111,6 +113,11 @@ func TestControlHTTP(t *testing.T) {
allowHTTP: true, allowHTTP: true,
}, },
}, },
// Early write
{
name: "early_write",
doEarlyWrite: true,
},
} }
for _, test := range tests { for _, test := range tests {
@ -125,9 +132,21 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
client, server := key.NewMachine(), key.NewMachine() client, server := key.NewMachine(), key.NewMachine()
const testProtocolVersion = 1 const testProtocolVersion = 1
const earlyWriteMsg = "Hello, world!"
sch := make(chan serverResult, 1) sch := make(chan serverResult, 1)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := AcceptHTTP(context.Background(), w, r, server, nil) var earlyWriteFn func(protocolVersion int, w io.Writer) error
if param.doEarlyWrite {
earlyWriteFn = func(protocolVersion int, w io.Writer) error {
if protocolVersion != testProtocolVersion {
t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
}
_, err := io.WriteString(w, earlyWriteMsg)
return err
}
}
conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }
@ -228,6 +247,15 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) { if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr) t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
} }
if param.doEarlyWrite {
buf := make([]byte, len(earlyWriteMsg))
if _, err := io.ReadFull(conn, buf); err != nil {
t.Fatalf("reading early write: %v", err)
}
if string(buf) != earlyWriteMsg {
t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
}
}
} }
type serverResult struct { type serverResult struct {

@ -9,7 +9,10 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io"
"net"
"net/http" "net/http"
"time"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
@ -18,16 +21,20 @@ import (
"tailscale.com/types/key" "tailscale.com/types/key"
) )
// AcceptHTTP upgrades the HTTP request given by w and r into a // AcceptHTTP upgrades the HTTP request given by w and r into a Tailscale
// Tailscale control protocol base transport connection. // control protocol base transport connection.
// //
// AcceptHTTP always writes an HTTP response to w. The caller must not // AcceptHTTP always writes an HTTP response to w. The caller must not attempt
// attempt their own response after calling AcceptHTTP. // their own response after calling AcceptHTTP.
// //
// extraHeader optionally specifies extra header(s) to send in the // earlyWrite optionally specifies a func to write to the noise connection
// 101 Switching Protocols Upgrade response. It must not include the "Upgrade" // (encrypted). It receives the negotiated version and a writer to write to, if
// or "Connection" headers; they will be replaced. // desired.
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, extraHeader http.Header) (*controlbase.Conn, error) { func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (*controlbase.Conn, error) {
return acceptHTTP(ctx, w, r, private, earlyWrite)
}
func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (_ *controlbase.Conn, retErr error) {
next := r.Header.Get("Upgrade") next := r.Header.Get("Upgrade")
if next == "" { if next == "" {
http.Error(w, "missing next protocol", http.StatusBadRequest) http.Error(w, "missing next protocol", http.StatusBadRequest)
@ -58,9 +65,6 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
return nil, errors.New("can't hijack client connection") return nil, errors.New("can't hijack client connection")
} }
for k, vv := range extraHeader {
w.Header()[k] = vv
}
w.Header().Set("Upgrade", upgradeHeaderValue) w.Header().Set("Upgrade", upgradeHeaderValue)
w.Header().Set("Connection", "upgrade") w.Header().Set("Connection", "upgrade")
w.WriteHeader(http.StatusSwitchingProtocols) w.WriteHeader(http.StatusSwitchingProtocols)
@ -69,18 +73,41 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
if err != nil { if err != nil {
return nil, fmt.Errorf("hijacking client connection: %w", err) return nil, fmt.Errorf("hijacking client connection: %w", err)
} }
if err := brw.Flush(); err != nil {
defer func() {
if retErr != nil {
conn.Close() conn.Close()
}
}()
if err := brw.Flush(); err != nil {
return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err) return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err)
} }
conn = netutil.NewDrainBufConn(conn, brw.Reader) conn = netutil.NewDrainBufConn(conn, brw.Reader)
nc, err := controlbase.Server(ctx, conn, private, init) cwc := newWriteCorkingConn(conn)
nc, err := controlbase.Server(ctx, cwc, private, init)
if err != nil { if err != nil {
conn.Close()
return nil, fmt.Errorf("noise handshake failed: %w", err) return nil, fmt.Errorf("noise handshake failed: %w", err)
} }
if earlyWrite != nil {
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
}
defer conn.SetDeadline(time.Time{})
}
if err := earlyWrite(nc.ProtocolVersion(), nc); err != nil {
return nil, err
}
}
if err := cwc.uncork(); err != nil {
return nil, err
}
return nc, nil return nc, nil
} }
@ -128,3 +155,61 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request
return nc, nil return nc, nil
} }
// corkConn is a net.Conn wrapper that initially buffers all writes until uncork
// is called. If the conn is corked and a Read occurs, the Read will flush any
// buffered (corked) write.
//
// Until uncorked, Read/Write/uncork may be not called concurrently.
//
// Deadlines still work, but a corked write ignores deadlines until a Read or
// uncork goes to do that Write.
//
// Use newWriteCorkingConn to create one.
type corkConn struct {
net.Conn
corked bool
buf []byte // corked data
}
func newWriteCorkingConn(c net.Conn) *corkConn {
return &corkConn{Conn: c, corked: true}
}
func (c *corkConn) Write(b []byte) (int, error) {
if c.corked {
c.buf = append(c.buf, b...)
return len(b), nil
}
return c.Conn.Write(b)
}
func (c *corkConn) Read(b []byte) (int, error) {
if c.corked {
if err := c.flush(); err != nil {
return 0, err
}
}
return c.Conn.Read(b)
}
// uncork flushes any buffered data and uncorks the connection so future Writes
// don't buffer. It may not be called concurrently with reads or writes and
// may only be called once.
func (c *corkConn) uncork() error {
if !c.corked {
panic("usage error; uncork called twice") // worth panicking to catch misuse
}
err := c.flush()
c.corked = false
return err
}
func (c *corkConn) flush() error {
if len(c.buf) == 0 {
return nil
}
_, err := c.Conn.Write(c.buf)
c.buf = nil
return err
}

Loading…
Cancel
Save