From d5a7eabcd0a1aea8a5e1bcb3e329914969520d13 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Mon, 17 Jan 2022 15:30:30 -0800 Subject: [PATCH] control/controlbase: enable asynchronous client handshaking. With this change, the client can obtain the initial handshake message separately from the rest of the handshake, for embedding into another protocol. This enables things like RTT reduction by stuffing the handshake initiation message into an HTTP header. Similarly, the server API optionally accepts a pre-read Noise initiation message, in addition to reading the message directly off a net.Conn. Updates #3488 Signed-off-by: David Anderson --- control/controlbase/conn.go | 4 +- control/controlbase/conn_test.go | 4 +- control/controlbase/handshake.go | 97 ++++++++++++++++++++------- control/controlbase/handshake_test.go | 12 ++-- control/controlbase/interop_test.go | 2 +- 5 files changed, 84 insertions(+), 35 deletions(-) diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index 0e28f4d08..aba8d755e 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package noise implements the base transport of the Tailscale 2021 -// control protocol. +// Package controlbase implements the base transport of the Tailscale +// 2021 control protocol. // // The base transport implements Noise IK, instantiated with // Curve25519, ChaCha20Poly1305 and BLAKE2s. diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index a8328bd0b..c0dfa9940 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -202,7 +202,7 @@ func TestConnStd(t *testing.T) { serverErr := make(chan error, 1) go func() { var err error - c2, err = Server(context.Background(), s2, controlKey) + c2, err = Server(context.Background(), s2, controlKey, nil) serverErr <- err }() c1, err = Client(context.Background(), s1, machineKey, controlKey.Public()) @@ -319,7 +319,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) ) go func() { var err error - server, err = Server(context.Background(), serverConn, controlKey) + server, err = Server(context.Background(), serverConn, controlKey, nil) serverErr <- err }() diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 57606581c..393576ee8 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -50,21 +50,23 @@ func protocolVersionPrologue(version uint16) []byte { return strconv.AppendUint(ret, uint64(version), 10) } -// Client initiates a control client handshake, returning the resulting -// control connection. +// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn +// is assumed to have already sent the client>server handshake +// initiation message. +type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) + +// ClientDeferred initiates a control client handshake, returning the +// initial message to send to the server and a continuation to +// finalize the handshake. // -// The context deadline, if any, covers the entire handshaking -// process. Any preexisting Conn deadline is removed. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - +// ClientDeferred is split in this way for RTT reduction: we run this +// protocol after negotiating a protocol switch from HTTP/HTTPS. If we +// completely serialized the negotiation followed by the handshake, +// we'd pay an extra RTT to transmit the handshake initiation after +// protocol switching. By splitting the handshake into an initial +// message and a continuation, we can embed the handshake initiation +// into the HTTP protocol switching request and avoid a bit of delay. +func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { var s symmetricState s.Initialize() @@ -83,18 +85,53 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c s.MixHash(machineEphemeralPub.UntypedBytes()) cipher, err := s.MixDH(machineEphemeral, controlKey) if err != nil { - return nil, fmt.Errorf("computing es: %w", err) + return nil, nil, fmt.Errorf("computing es: %w", err) } machineKeyPub := machineKey.Public() s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) cipher, err = s.MixDH(machineKey, controlKey) if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) + return nil, nil, fmt.Errorf("computing ss: %w", err) } s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - if _, err := conn.Write(init[:]); err != nil { - return nil, fmt.Errorf("writing initiation: %w", err) + cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { + return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey) + } + return init[:], cont, nil +} + +// Client wraps ClientDeferred and immediately invokes the returned +// continuation with conn. +// +// This is a helper for when you don't need the fancy +// continuation-style handshake, and just want to synchronously +// upgrade a net.Conn to a secure transport. +func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { + init, cont, err := ClientDeferred(machineKey, controlKey) + if err != nil { + return nil, err + } + if _, err := conn.Write(init); err != nil { + return nil, err + } + return cont(ctx, conn) +} + +func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { + // No matter what, this function can only run once per s. Ensure + // attempted reuse causes a panic. + defer func() { + s.finished = true + }() + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() } // Read in the payload and look for errors/protocol violations from the server. @@ -122,10 +159,10 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c // <- e, ee, se controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err = s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { + if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { return nil, fmt.Errorf("computing ee: %w", err) } - cipher, err = s.MixDH(machineKey, controlEphemeralPub) + cipher, err := s.MixDH(machineKey, controlEphemeralPub) if err != nil { return nil, fmt.Errorf("computing se: %w", err) } @@ -156,9 +193,13 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c // Server initiates a control server handshake, returning the resulting // control connection. // +// optionalInit can be the client's initial handshake message as +// returned by ClientDeferred, or nil in which case the initial +// message is read from conn. +// // The context deadline, if any, covers the entire handshaking // process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (*Conn, error) { +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { if deadline, ok := ctx.Deadline(); ok { if err := conn.SetDeadline(deadline); err != nil { return nil, fmt.Errorf("setting conn deadline: %w", err) @@ -190,7 +231,12 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) ( s.Initialize() var init initiationMessage - if _, err := io.ReadFull(conn, init.Header()); err != nil { + if optionalInit != nil { + if len(optionalInit) != len(init) { + return nil, sendErr("wrong handshake initiation size") + } + copy(init[:], optionalInit) + } else if _, err := io.ReadFull(conn, init.Header()); err != nil { return nil, err } if init.Version() != protocolVersion { @@ -202,8 +248,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) ( if init.Length() != len(init.Payload()) { return nil, sendErr("wrong handshake initiation length") } - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err + // if optionalInit was provided, we have the payload already. + if optionalInit == nil { + if _, err := io.ReadFull(conn, init.Payload()); err != nil { + return nil, err + } } // prologue. Can only do this once we at least think the client is diff --git a/control/controlbase/handshake_test.go b/control/controlbase/handshake_test.go index a5664c11a..9cdc6f5f2 100644 --- a/control/controlbase/handshake_test.go +++ b/control/controlbase/handshake_test.go @@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -172,7 +172,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey) + _, err := Server(context.Background(), serverConn, serverKey, nil) // If the server failed, we have to close the Conn to // unblock the client. if err != nil { @@ -200,7 +200,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey) + _, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -225,7 +225,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err _, err = io.WriteString(server, strings.Repeat("a", 14)) serverErr <- err @@ -266,7 +266,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err var bs [100]byte // The server needs a timeout if the tampering is hitting the length header. diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index 04bd7f41d..3417639fe 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) { ) go func() { - server, err := Server(context.Background(), s2, controlKey) + server, err := Server(context.Background(), s2, controlKey, nil) serverErr <- err if err != nil { return