diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index f1ead9f47..fdf61562d 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -7,7 +7,10 @@ package controlclient import ( "bytes" "context" + "encoding/binary" "encoding/json" + "errors" + "io" "math" "net/http" "net/url" @@ -34,6 +37,77 @@ type noiseConn struct { id int pool *noiseClient h2cc *http2.ClientConn + + readHeaderOnce sync.Once // guards init of reader field + reader io.Reader // (effectively Conn.Reader after header) + earlyPayloadReady chan struct{} // closed after earlyPayload is set (including set to nil) + earlyPayload *tailcfg.EarlyNoise +} + +func (c *noiseConn) RoundTrip(r *http.Request) (*http.Response, error) { + return c.h2cc.RoundTrip(r) +} + +// The first 9 bytes from the server to client over Noise are either an HTTP/2 +// settings frame (a normal HTTP/2 setup) or, as we added later, an "early payload" +// header that's also 9 bytes long: 5 bytes (earlyPayloadMagic) followed by 4 bytes +// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise. +// The early payload is optional. Some servers may not send it. +const ( + hdrLen = 9 // http2 frame header size; also size of our early payload size header + earlyPayloadMagic = "\xff\xff\xffTS" +) + +// returnErrReader is an io.Reader that always returns an error. +type returnErrReader struct { + err error // the error to return +} + +func (r returnErrReader) Read([]byte) (int, error) { return 0, r.err } + +// Read is basically the same as controlbase.Conn.Read, but it first reads the +// "early payload" header from the server which may or may not be present, +// depending on the server. +func (c *noiseConn) Read(p []byte) (n int, err error) { + c.readHeaderOnce.Do(c.readHeader) + return c.reader.Read(p) +} + +// readHeader reads the optional "early payload" from the server that arrives +// after the Noise handshake but before the HTTP/2 session begins. +// +// readHeader is responsible for reading the header (if present), initializing +// c.earlyPayload, closing c.earlyPayloadReady, and initializing c.reader for +// future reads. +func (c *noiseConn) readHeader() { + var hdr [hdrLen]byte + if _, err := io.ReadFull(c.Conn, hdr[:]); err != nil { + c.reader = returnErrReader{err} + return + } + if string(hdr[:len(earlyPayloadMagic)]) != earlyPayloadMagic { + // No early payload. We have to return the 9 bytes read we already + // consumed. + close(c.earlyPayloadReady) + c.reader = io.MultiReader(bytes.NewReader(hdr[:]), c.Conn) + return + } + epLen := binary.BigEndian.Uint32(hdr[len(earlyPayloadMagic):]) + if epLen > 10<<20 { + c.reader = returnErrReader{errors.New("invalid early payload length")} + return + } + payBuf := make([]byte, epLen) + if _, err := io.ReadFull(c.Conn, payBuf); err != nil { + c.reader = returnErrReader{err} + return + } + if err := json.Unmarshal(payBuf, &c.earlyPayload); err != nil { + c.reader = returnErrReader{err} + return + } + close(c.earlyPayloadReady) + c.reader = c.Conn } func (c *noiseConn) Close() error { @@ -88,7 +162,7 @@ type noiseClient struct { // serverURL is of the form https://: (no trailing slash). // // dialPlan may be nil -func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) { +func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) { u, err := url.Parse(serverURL) if err != nil { return nil, err @@ -111,7 +185,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s } np := &noiseClient{ serverPubKey: serverPubKey, - privKey: priKey, + privKey: privKey, host: u.Hostname(), httpPort: httpPort, httpsPort: httpsPort, @@ -157,7 +231,7 @@ func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return nil, err } - return conn.h2cc.RoundTrip(req) + return conn.RoundTrip(req) } // connClosed removes the connection with the provided ID from the pool @@ -259,14 +333,12 @@ func (nc *noiseClient) dial() (*noiseConn, error) { } ncc := &noiseConn{ - Conn: clientConn.Conn, - id: connID, - pool: nc, + Conn: clientConn.Conn, + id: connID, + pool: nc, + earlyPayloadReady: make(chan struct{}), } - // TODO(bradfitz): wrap clientConn in a type that sniffs the leading bytes - // from the server to see if it has early post-Noise, pre-H2 data for us. - h2cc, err := nc.h2t.NewClientConn(ncc) if err != nil { return nil, err diff --git a/control/controlclient/noise_test.go b/control/controlclient/noise_test.go index 3c8f1c14d..469bc281d 100644 --- a/control/controlclient/noise_test.go +++ b/control/controlclient/noise_test.go @@ -13,6 +13,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "golang.org/x/net/http2" "tailscale.com/control/controlhttp" @@ -38,15 +39,32 @@ func TestNoiseVersion(t *testing.T) { } } +type noiseClientTest struct { + sendEarlyPayload bool +} + func TestNoiseClientHTTP2Upgrade(t *testing.T) { + noiseClientTest{}.run(t) +} + +func TestNoiseClientHTTP2Upgrade_earlyPayload(t *testing.T) { + noiseClientTest{ + sendEarlyPayload: true, + }.run(t) +} + +func (tt noiseClientTest) run(t *testing.T) { serverPrivate := key.NewMachine() clientPrivate := key.NewMachine() + chalPrivate := key.NewChallenge() const msg = "Hello, client" h2 := &http2.Server{} hs := httptest.NewServer(&Upgrader{ - h2srv: h2, - noiseKeyPriv: serverPrivate, + h2srv: h2, + noiseKeyPriv: serverPrivate, + sendEarlyPayload: tt.sendEarlyPayload, + challenge: chalPrivate, httpBaseConfig: &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") @@ -61,19 +79,56 @@ func TestNoiseClientHTTP2Upgrade(t *testing.T) { if err != nil { t.Fatal(err) } - res, err := nc.post(context.Background(), "/", nil) + + // Get a conn and verify it read its early payload before the http/2 + // handshake. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c, err := nc.getConn(ctx) if err != nil { t.Fatal(err) } - defer res.Body.Close() - all, err := io.ReadAll(res.Body) + select { + case <-c.earlyPayloadReady: + gotNonNil := c.earlyPayload != nil + if gotNonNil != tt.sendEarlyPayload { + t.Errorf("sendEarlyPayload = %v but got earlyPayload = %T", tt.sendEarlyPayload, c.earlyPayload) + } + if c.earlyPayload != nil { + if c.earlyPayload.NodeKeyChallenge != chalPrivate.Public() { + t.Errorf("earlyPayload.NodeKeyChallenge = %v; want %v", c.earlyPayload.NodeKeyChallenge, chalPrivate.Public()) + } + } + + case <-ctx.Done(): + t.Fatal("timed out waiting for didReadHeaderCh") + } + + checkRes := func(t *testing.T, res *http.Response) { + t.Helper() + defer res.Body.Close() + all, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(all) != msg { + t.Errorf("got response %q; want %q", all, msg) + } + } + + // And verify we can do HTTP/2 against that conn. + res, err := (&http.Client{Transport: c}).Get("https://unused.example/") if err != nil { t.Fatal(err) } - if string(all) != msg { - t.Errorf("got response %q; want %q", all, msg) - } + checkRes(t, res) + // And try using the high-level nc.post API as well. + res, err = nc.post(context.Background(), "/", nil) + if err != nil { + t.Fatal(err) + } + checkRes(t, res) } // Upgrader is an http.Handler that hijacks and upgrades POST-with-Upgrade @@ -91,6 +146,7 @@ type Upgrader struct { logf logger.Logf noiseKeyPriv key.MachinePrivate + challenge key.ChallengePrivate sendEarlyPayload bool } @@ -109,21 +165,21 @@ func (up *Upgrader) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - chalPub := key.NewChallenge() earlyWriteFn := func(protocolVersion int, w io.Writer) error { if !up.sendEarlyPayload { return nil } - earlyJSON, err := json.Marshal(struct { - NodeKeyOwnershipChallenge string - }{chalPub.Public().String()}) + earlyJSON, err := json.Marshal(&tailcfg.EarlyNoise{ + NodeKeyChallenge: up.challenge.Public(), + }) if err != nil { return err } // 5 bytes that won't be mistaken for an HTTP/2 frame: // https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not // an HTTP/2 settings frame, which isn't of type 'T') - var notH2Frame = [5]byte{0xff, 0xff, 0xff, 'T', 'S'} + var notH2Frame [5]byte + copy(notH2Frame[:], earlyPayloadMagic) var lenBuf [4]byte binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) // These writes are all buffered by caller, so fine to do them diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 187c61a34..8545207ae 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -1946,3 +1946,15 @@ type PeerChange struct { // // Mnemonic: 3.3.40 are numbers above the keys D, E, R, P. const DerpMagicIP = "127.3.3.40" + +// EarlyNoise is the early payload that's sent over Noise but before the HTTP/2 +// handshake when connecting to the coordination server. +// +// This exists to let the server push some early info to client for that +// stateful HTTP/2+Noise connection without incurring an extra round trip. (This +// would've used HTTP/2 server push, had Go's client-side APIs been available) +type EarlyNoise struct { + // NodeKeyChallenge is a random per-connection public key to be used by + // the client to prove possession of a wireguard private key. + NodeKeyChallenge key.ChallengePublic `json:"nodeKeyChallenge"` +}