control/controlbase: enable asynchronous client handshaking.

With this change, the client can obtain the initial handshake message
separately from the rest of the handshake, for embedding into another
protocol. This enables things like RTT reduction by stuffing the
handshake initiation message into an HTTP header.

Similarly, the server API optionally accepts a pre-read Noise initiation
message, in addition to reading the message directly off a net.Conn.

Updates #3488

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/3759/head
David Anderson 2 years ago committed by Dave Anderson
parent 6cd180746f
commit d5a7eabcd0

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package noise implements the base transport of the Tailscale 2021 // Package controlbase implements the base transport of the Tailscale
// control protocol. // 2021 control protocol.
// //
// The base transport implements Noise IK, instantiated with // The base transport implements Noise IK, instantiated with
// Curve25519, ChaCha20Poly1305 and BLAKE2s. // Curve25519, ChaCha20Poly1305 and BLAKE2s.

@ -202,7 +202,7 @@ func TestConnStd(t *testing.T) {
serverErr := make(chan error, 1) serverErr := make(chan error, 1)
go func() { go func() {
var err error var err error
c2, err = Server(context.Background(), s2, controlKey) c2, err = Server(context.Background(), s2, controlKey, nil)
serverErr <- err serverErr <- err
}() }()
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public()) c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
@ -319,7 +319,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn)
) )
go func() { go func() {
var err error var err error
server, err = Server(context.Background(), serverConn, controlKey) server, err = Server(context.Background(), serverConn, controlKey, nil)
serverErr <- err serverErr <- err
}() }()

@ -50,21 +50,23 @@ func protocolVersionPrologue(version uint16) []byte {
return strconv.AppendUint(ret, uint64(version), 10) return strconv.AppendUint(ret, uint64(version), 10)
} }
// Client initiates a control client handshake, returning the resulting // HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
// control connection. // is assumed to have already sent the client>server handshake
// initiation message.
type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
// ClientDeferred initiates a control client handshake, returning the
// initial message to send to the server and a continuation to
// finalize the handshake.
// //
// The context deadline, if any, covers the entire handshaking // ClientDeferred is split in this way for RTT reduction: we run this
// process. Any preexisting Conn deadline is removed. // protocol after negotiating a protocol switch from HTTP/HTTPS. If we
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) { // completely serialized the negotiation followed by the handshake,
if deadline, ok := ctx.Deadline(); ok { // we'd pay an extra RTT to transmit the handshake initiation after
if err := conn.SetDeadline(deadline); err != nil { // protocol switching. By splitting the handshake into an initial
return nil, fmt.Errorf("setting conn deadline: %w", err) // message and a continuation, we can embed the handshake initiation
} // into the HTTP protocol switching request and avoid a bit of delay.
defer func() { func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
conn.SetDeadline(time.Time{})
}()
}
var s symmetricState var s symmetricState
s.Initialize() s.Initialize()
@ -83,18 +85,53 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
s.MixHash(machineEphemeralPub.UntypedBytes()) s.MixHash(machineEphemeralPub.UntypedBytes())
cipher, err := s.MixDH(machineEphemeral, controlKey) cipher, err := s.MixDH(machineEphemeral, controlKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("computing es: %w", err) return nil, nil, fmt.Errorf("computing es: %w", err)
} }
machineKeyPub := machineKey.Public() machineKeyPub := machineKey.Public()
s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
cipher, err = s.MixDH(machineKey, controlKey) cipher, err = s.MixDH(machineKey, controlKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("computing ss: %w", err) return nil, nil, fmt.Errorf("computing ss: %w", err)
} }
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
if _, err := conn.Write(init[:]); err != nil { cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
return nil, fmt.Errorf("writing initiation: %w", err) return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey)
}
return init[:], cont, nil
}
// Client wraps ClientDeferred and immediately invokes the returned
// continuation with conn.
//
// This is a helper for when you don't need the fancy
// continuation-style handshake, and just want to synchronously
// upgrade a net.Conn to a secure transport.
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
init, cont, err := ClientDeferred(machineKey, controlKey)
if err != nil {
return nil, err
}
if _, err := conn.Write(init); err != nil {
return nil, err
}
return cont(ctx, conn)
}
func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
// No matter what, this function can only run once per s. Ensure
// attempted reuse causes a panic.
defer func() {
s.finished = true
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
}
defer func() {
conn.SetDeadline(time.Time{})
}()
} }
// Read in the payload and look for errors/protocol violations from the server. // Read in the payload and look for errors/protocol violations from the server.
@ -122,10 +159,10 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
// <- e, ee, se // <- e, ee, se
controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
s.MixHash(controlEphemeralPub.UntypedBytes()) s.MixHash(controlEphemeralPub.UntypedBytes())
if _, err = s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
return nil, fmt.Errorf("computing ee: %w", err) return nil, fmt.Errorf("computing ee: %w", err)
} }
cipher, err = s.MixDH(machineKey, controlEphemeralPub) cipher, err := s.MixDH(machineKey, controlEphemeralPub)
if err != nil { if err != nil {
return nil, fmt.Errorf("computing se: %w", err) return nil, fmt.Errorf("computing se: %w", err)
} }
@ -156,9 +193,13 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
// Server initiates a control server handshake, returning the resulting // Server initiates a control server handshake, returning the resulting
// control connection. // control connection.
// //
// optionalInit can be the client's initial handshake message as
// returned by ClientDeferred, or nil in which case the initial
// message is read from conn.
//
// The context deadline, if any, covers the entire handshaking // The context deadline, if any, covers the entire handshaking
// process. // process.
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (*Conn, error) { func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil { if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err) return nil, fmt.Errorf("setting conn deadline: %w", err)
@ -190,7 +231,12 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
s.Initialize() s.Initialize()
var init initiationMessage var init initiationMessage
if _, err := io.ReadFull(conn, init.Header()); err != nil { if optionalInit != nil {
if len(optionalInit) != len(init) {
return nil, sendErr("wrong handshake initiation size")
}
copy(init[:], optionalInit)
} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
return nil, err return nil, err
} }
if init.Version() != protocolVersion { if init.Version() != protocolVersion {
@ -202,8 +248,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate) (
if init.Length() != len(init.Payload()) { if init.Length() != len(init.Payload()) {
return nil, sendErr("wrong handshake initiation length") return nil, sendErr("wrong handshake initiation length")
} }
if _, err := io.ReadFull(conn, init.Payload()); err != nil { // if optionalInit was provided, we have the payload already.
return nil, err if optionalInit == nil {
if _, err := io.ReadFull(conn, init.Payload()); err != nil {
return nil, err
}
} }
// prologue. Can only do this once we at least think the client is // prologue. Can only do this once we at least think the client is

@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) {
) )
go func() { go func() {
var err error var err error
server, err = Server(context.Background(), serverConn, serverKey) server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
}() }()
@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) {
) )
go func() { go func() {
var err error var err error
server, err = Server(context.Background(), serverConn, serverKey) server, err = Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
}() }()
@ -172,7 +172,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
_, err := Server(context.Background(), serverConn, serverKey) _, err := Server(context.Background(), serverConn, serverKey, nil)
// If the server failed, we have to close the Conn to // If the server failed, we have to close the Conn to
// unblock the client. // unblock the client.
if err != nil { if err != nil {
@ -200,7 +200,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
_, err := Server(context.Background(), serverConn, serverKey) _, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
}() }()
@ -225,7 +225,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
server, err := Server(context.Background(), serverConn, serverKey) server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
_, err = io.WriteString(server, strings.Repeat("a", 14)) _, err = io.WriteString(server, strings.Repeat("a", 14))
serverErr <- err serverErr <- err
@ -266,7 +266,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
) )
go func() { go func() {
server, err := Server(context.Background(), serverConn, serverKey) server, err := Server(context.Background(), serverConn, serverKey, nil)
serverErr <- err serverErr <- err
var bs [100]byte var bs [100]byte
// The server needs a timeout if the tampering is hitting the length header. // The server needs a timeout if the tampering is hitting the length header.

@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) {
) )
go func() { go func() {
server, err := Server(context.Background(), s2, controlKey) server, err := Server(context.Background(), s2, controlKey, nil)
serverErr <- err serverErr <- err
if err != nil { if err != nil {
return return

Loading…
Cancel
Save