From 0b392dbaf7ea4a9216b1cf55ebe45189b6fb9ea5 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Thu, 29 Jul 2021 11:59:40 -0700 Subject: [PATCH] control/noise: adjust implementation to match revised spec. Signed-off-by: David Anderson --- control/noise/conn.go | 79 +++++++++++-------- control/noise/handshake.go | 132 +++++++++++++++++++++++--------- control/noise/handshake_test.go | 18 +++-- control/noise/interop_test.go | 48 ++++++++---- control/noise/key.go | 26 +++++++ control/noise/messages.go | 87 +++++++++++++++++++++ 6 files changed, 300 insertions(+), 90 deletions(-) create mode 100644 control/noise/key.go create mode 100644 control/noise/messages.go diff --git a/control/noise/conn.go b/control/noise/conn.go index efeb538d6..63b4f36de 100644 --- a/control/noise/conn.go +++ b/control/noise/conn.go @@ -24,9 +24,9 @@ import ( ) const ( - maxPlaintextSize = 4096 - maxCiphertextSize = maxPlaintextSize + poly1305.TagSize - maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header + maxMessageSize = 4096 + maxCiphertextSize = maxMessageSize - headerLen + maxPlaintextSize = maxCiphertextSize - poly1305.TagSize ) // A Conn is a secured Noise connection. It implements the net.Conn @@ -35,6 +35,7 @@ const ( // fail. type Conn struct { conn net.Conn + version int peer key.Public handshakeHash [blake2s.Size]byte rx rxState @@ -46,7 +47,7 @@ type rxState struct { sync.Mutex cipher cipher.AEAD nonce [chp.NonceSize]byte - buf [maxPacketSize]byte + buf [maxMessageSize]byte n int // number of valid bytes in buf next int // offset of next undecrypted packet plaintext []byte // slice into buf of decrypted bytes @@ -57,10 +58,14 @@ type txState struct { sync.Mutex cipher cipher.AEAD nonce [chp.NonceSize]byte - buf [maxPacketSize]byte + buf [maxMessageSize]byte err error // records the first partial write error for all future calls } +func (c *Conn) ProtocolVersion() int { + return c.version +} + // HandshakeHash returns the Noise handshake hash for the connection, // which can be used to bind other messages to this connection // (i.e. to ensure that the message wasn't replayed from a different @@ -84,7 +89,7 @@ func validNonce(nonce []byte) bool { // bytes. Returns a slice of the available bytes in rxBuf, or an // error if fewer than total bytes are available. func (c *Conn) readNLocked(total int) ([]byte, error) { - if total > maxPacketSize { + if total > maxMessageSize { return nil, errReadTooBig{total} } for { @@ -100,10 +105,20 @@ func (c *Conn) readNLocked(total int) ([]byte, error) { } } -// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext -// to the decrypted bytes. Returns an error if the cipher is exhausted -// (i.e. can no longer be used safely) or decryption fails. -func (c *Conn) decryptLocked(ciphertext []byte) (err error) { +// decryptLocked decrypts message (which is header+ciphertext) +// in-place and sets c.rx.plaintext to the decrypted bytes. Returns an +// error if the cipher is exhausted (i.e. can no longer be used +// safely) or decryption fails. +func (c *Conn) decryptLocked(msg []byte) (err error) { + if hdrVersion(msg) != c.version { + return fmt.Errorf("received message with unexpected protocol version %d, want %d", hdrVersion(msg), c.version) + } + if hdrType(msg) != msgTypeRecord { + return fmt.Errorf("received message with unexpected type %d, want %d", hdrType(msg), msgTypeRecord) + } + // length was already handled in caller to size msg. + ciphertext := msg[headerLen:] + if !validNonce(c.rx.nonce[:]) { return errCipherExhausted{} } @@ -124,8 +139,8 @@ func (c *Conn) decryptLocked(ciphertext []byte) (err error) { } // encryptLocked encrypts plaintext into c.tx.buf (including the -// 2-byte length header) and returns a slice of the ciphertext, or an -// error if the cipher is exhausted (i.e. can no longer be used safely). +// packet header) and returns a slice of the ciphertext, or an error +// if the cipher is exhausted (i.e. can no longer be used safely). func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { if !validNonce(c.tx.nonce[:]) { // Received 2^64-1 messages on this cipher state. Connection @@ -133,8 +148,8 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { return nil, errCipherExhausted{} } - binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize)) - ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil) + setHeader(c.tx.buf[:5], protocolVersion, msgTypeRecord, len(plaintext)+poly1305.TagSize) + ret := c.tx.cipher.Seal(c.tx.buf[:5], c.tx.nonce[:], plaintext, nil) // Safe to increment the nonce here, because we checked for nonce // wraparound above. @@ -143,18 +158,18 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { return ret, nil } -// wholeCiphertextLocked returns a slice of one whole Noise frame from -// c.rx.buf, if one whole ciphertext is available, and advances the -// read state to the next Noise frame in the buffer. Returns nil -// without advancing read state if there's not one whole ciphertext in -// c.rx.buf. -func (c *Conn) wholeCiphertextLocked() []byte { +// wholeMessageLocked returns a slice of one whole Noise transport +// message from c.rx.buf, if one whole message is available, and +// advances the read state to the next Noise message in the +// buffer. Returns nil without advancing read state if there isn't one +// whole message in c.rx.buf. +func (c *Conn) wholeMessageLocked() []byte { available := c.rx.n - c.rx.next - if available < 2 { + if available < headerLen { return nil } bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2 + totalSize := hdrLen(bs) + headerLen if len(bs) < totalSize { return nil } @@ -162,16 +177,16 @@ func (c *Conn) wholeCiphertextLocked() []byte { return bs[:totalSize] } -// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed, -// and sets c.rx.plaintext to point to the decrypted +// decryptOneLocked decrypts one Noise transport message, reading from +// c.conn as needed, and sets c.rx.plaintext to point to the decrypted // bytes. c.rx.plaintext is only valid if err == nil. func (c *Conn) decryptOneLocked() error { c.rx.plaintext = nil // Fast path: do we have one whole ciphertext frame buffered // already? - if bs := c.wholeCiphertextLocked(); bs != nil { - return c.decryptLocked(bs[2:]) + if bs := c.wholeMessageLocked(); bs != nil { + return c.decryptLocked(bs) } if c.rx.next != 0 { @@ -183,18 +198,20 @@ func (c *Conn) decryptOneLocked() error { c.rx.next = 0 } - bs, err := c.readNLocked(2) + bs, err := c.readNLocked(headerLen) if err != nil { return err } - totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2 - bs, err = c.readNLocked(totalLen) + // The rest of the header (besides the length field) gets verified + // in decryptLocked, not here. + messageLen := headerLen + hdrLen(bs) + bs, err = c.readNLocked(messageLen) if err != nil { return err } + bs = bs[:messageLen] - c.rx.next = totalLen - bs = bs[2:totalLen] + c.rx.next = len(bs) return c.decryptLocked(bs) } diff --git a/control/noise/handshake.go b/control/noise/handshake.go index 6b163b6d9..1cc0af85c 100644 --- a/control/noise/handshake.go +++ b/control/noise/handshake.go @@ -12,6 +12,7 @@ import ( "hash" "io" "net" + "strconv" "time" "golang.org/x/crypto/blake2s" @@ -23,15 +24,32 @@ import ( ) const ( + // protocolName is the name of the specific instantiation of the + // Noise protocol we're using. Each field is defined in the Noise + // spec, and shouldn't be changed unless we're switching to a + // different Noise protocol instance. protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" - // protocolVersion is the version string that gets included as the - // Noise "prologue" in the handshake. It exists so that we can - // ensure that peer have agreed on the protocol version they're - // executing, to defeat some MITM protocol downgrade attacks. - protocolVersion = "Tailscale Control Protocol v1" - invalidNonce = ^uint64(0) + // protocolVersion is the version of the Tailscale base + // protocol that Client will use when initiating a handshake. + protocolVersion = 1 + // protocolVersionPrefix is the name portion of the protocol + // name+version string that gets mixed into the Noise handshake as + // a prologue. + // + // This mixing verifies that both clients agree that + // they're executing the Tailscale control protocol at a specific + // version that matches the advertised version in the cleartext + // packet header. + protocolVersionPrefix = "Tailscale Control Protocol v" + invalidNonce = ^uint64(0) ) +func protocolVersionPrologue(version int) []byte { + ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. + ret = append(ret, protocolVersionPrefix...) + return strconv.AppendUint(ret, uint64(version), 10) +} + // Client initiates a Noise client handshake, returning the resulting // Noise connection. // @@ -50,15 +68,18 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK var s symmetricState s.Initialize() + // prologue + s.MixHash(protocolVersionPrologue(protocolVersion)) + // <- s // ... s.MixHash(controlKey[:]) // -> e, es, s, ss - var init initiationMessage + init := mkInitiationMessage() machineEphemeral := key.NewPrivate() machineEphemeralPub := machineEphemeral.Public() - copy(init.MachineEphemeralPub(), machineEphemeralPub[:]) + copy(init.EphemeralPub(), machineEphemeralPub[:]) s.MixHash(machineEphemeralPub[:]) if err := s.MixDH(machineEphemeral, controlKey); err != nil { return nil, fmt.Errorf("computing es: %w", err) @@ -74,14 +95,34 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK return nil, fmt.Errorf("writing initiation: %w", err) } - // <- e, ee, se + // Read in the payload and look for errors/protocol violations from the server. var resp responseMessage - if _, err := io.ReadFull(conn, resp[:]); err != nil { - return nil, fmt.Errorf("reading response: %w", err) + if _, err := io.ReadFull(conn, resp.Header()); err != nil { + return nil, fmt.Errorf("reading response header: %w", err) + } + if resp.Version() != protocolVersion { + return nil, fmt.Errorf("unexpected version %d from server, want %d", resp.Version(), protocolVersion) + } + if resp.Type() != msgTypeResponse { + if resp.Type() != msgTypeError { + return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) + } + msg := make([]byte, resp.Length()) + if _, err := io.ReadFull(conn, msg); err != nil { + return nil, err + } + return nil, fmt.Errorf("server error: %s", string(msg)) + } + if resp.Length() != len(resp.Payload()) { + return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) + } + if _, err := io.ReadFull(conn, resp.Payload()); err != nil { + return nil, err } + // <- e, ee, se var controlEphemeralPub key.Public - copy(controlEphemeralPub[:], resp.ControlEphemeralPub()) + copy(controlEphemeralPub[:], resp.EphemeralPub()) s.MixHash(controlEphemeralPub[:]) if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { return nil, fmt.Errorf("computing ee: %w", err) @@ -100,6 +141,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK return &Conn{ conn: conn, + version: protocolVersion, peer: controlKey, handshakeHash: s.h, tx: txState{ @@ -126,22 +168,55 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, }() } + // Deliberately does not support formatting, so that we don't echo + // attacker-controlled input back to them. + sendErr := func(msg string) error { + if len(msg) >= 1<<16 { + msg = msg[:1<<16] + } + var hdr [headerLen]byte + setHeader(hdr[:], protocolVersion, msgTypeError, len(msg)) + if _, err := conn.Write(hdr[:]); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + if _, err := conn.Write([]byte(msg)); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + return fmt.Errorf("refused client handshake: %s", msg) + } + var s symmetricState s.Initialize() + var init initiationMessage + if _, err := io.ReadFull(conn, init.Header()); err != nil { + return nil, err + } + if init.Version() != protocolVersion { + return nil, sendErr("unsupported protocol version") + } + if init.Type() != msgTypeInitiation { + return nil, sendErr("unexpected handshake message type") + } + if init.Length() != len(init.Payload()) { + return nil, sendErr("wrong handshake initiation length") + } + if _, err := io.ReadFull(conn, init.Payload()); err != nil { + return nil, err + } + + // prologue. Can only do this once we at least think the client is + // handshaking using a supported version. + s.MixHash(protocolVersionPrologue(protocolVersion)) + // <- s // ... controlKeyPub := controlKey.Public() s.MixHash(controlKeyPub[:]) // -> e, es, s, ss - var init initiationMessage - if _, err := io.ReadFull(conn, init[:]); err != nil { - return nil, fmt.Errorf("reading initiation: %w", err) - } - var machineEphemeralPub key.Public - copy(machineEphemeralPub[:], init.MachineEphemeralPub()) + copy(machineEphemeralPub[:], init.EphemeralPub()) s.MixHash(machineEphemeralPub[:]) if err := s.MixDH(controlKey, machineEphemeralPub); err != nil { return nil, fmt.Errorf("computing es: %w", err) @@ -158,10 +233,10 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, } // <- e, ee, se - var resp responseMessage + resp := mkResponseMessage() controlEphemeral := key.NewPrivate() controlEphemeralPub := controlEphemeral.Public() - copy(resp.ControlEphemeralPub(), controlEphemeralPub[:]) + copy(resp.EphemeralPub(), controlEphemeralPub[:]) s.MixHash(controlEphemeralPub[:]) if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { return nil, fmt.Errorf("computing ee: %w", err) @@ -182,6 +257,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, return &Conn{ conn: conn, + version: protocolVersion, peer: machineKey, handshakeHash: s.h, tx: txState{ @@ -193,21 +269,6 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, }, nil } -// initiationMessage is the Noise protocol message sent from a client -// machine to a control server. -type initiationMessage [96]byte - -func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] } -func (m *initiationMessage) MachinePub() []byte { return m[32:80] } -func (m *initiationMessage) Tag() []byte { return m[80:] } - -// responseMessage is the Noise protocol message sent from a control -// server to a client machine. -type responseMessage [48]byte - -func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] } -func (m *responseMessage) Tag() []byte { return m[32:] } - // symmetricState is the SymmetricState object from the Noise protocol // spec. It contains all the symmetric cipher state of an in-flight // handshake. Field names match the variable names in the spec. @@ -232,7 +293,6 @@ func (s *symmetricState) Initialize() { s.k = [chp.KeySize]byte{} s.n = invalidNonce s.mixer = newBLAKE2s() - s.MixHash([]byte(protocolVersion)) } // MixHash updates s.h to be BLAKE2s(s.h || data), where || is diff --git a/control/noise/handshake_test.go b/control/noise/handshake_test.go index 172ee0ff8..8d97807e6 100644 --- a/control/noise/handshake_test.go +++ b/control/noise/handshake_test.go @@ -42,6 +42,12 @@ func TestHandshake(t *testing.T) { t.Fatal("client and server disagree on handshake hash") } + if client.ProtocolVersion() != protocolVersion { + t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion) + } + if client.ProtocolVersion() != server.ProtocolVersion() { + t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion()) + } if client.Peer() != serverKey.Public() { t.Fatal("client peer key isn't serverKey") } @@ -154,7 +160,7 @@ func (r *tamperReader) Read(bs []byte) (int, error) { func TestTampering(t *testing.T) { // Tamper with every byte of the client initiation message. - for i := 0; i < 96; i++ { + for i := 0; i < 101; i++ { var ( clientConn, serverRaw = tsnettest.NewConn("noise", 128000) serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}} @@ -182,7 +188,7 @@ func TestTampering(t *testing.T) { } // Tamper with every byte of the server response message. - for i := 0; i < 48; i++ { + for i := 0; i < 53; i++ { var ( clientRaw, serverConn = tsnettest.NewConn("noise", 128000) clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}} @@ -210,7 +216,7 @@ func TestTampering(t *testing.T) { for i := 0; i < 32; i++ { var ( clientRaw, serverConn = tsnettest.NewConn("noise", 128000) - clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}} + clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 53 + i, 0}} serverKey = key.NewPrivate() clientKey = key.NewPrivate() serverErr = make(chan error, 1) @@ -233,7 +239,7 @@ func TestTampering(t *testing.T) { } // The client needs a timeout if the tampering is hitting the length header. - if i == 0 || i == 1 { + if i == 3 || i == 4 { client.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) } @@ -251,7 +257,7 @@ func TestTampering(t *testing.T) { for i := 0; i < 32; i++ { var ( clientConn, serverRaw = tsnettest.NewConn("noise", 128000) - serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}} + serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 101 + i, 0}} serverKey = key.NewPrivate() clientKey = key.NewPrivate() serverErr = make(chan error, 1) @@ -261,7 +267,7 @@ func TestTampering(t *testing.T) { serverErr <- err var bs [100]byte // The server needs a timeout if the tampering is hitting the length header. - if i == 0 || i == 1 { + if i == 3 || i == 4 { server.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) } n, err := server.Read(bs[:]) diff --git a/control/noise/interop_test.go b/control/noise/interop_test.go index 05fe6805d..d3ae83468 100644 --- a/control/noise/interop_test.go +++ b/control/noise/interop_test.go @@ -120,9 +120,14 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr private_key: machineKey, public_key: machineKey.Public(), } - session := InitSession(true, []byte(protocolVersion), mk, controlKey) + session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, controlKey) _, msg1 := SendMessage(&session, nil) + var hdr [headerLen]byte + setHeader(hdr[:], protocolVersion, msgTypeInitiation, 96) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } if _, err := conn.Write(msg1.ne[:]); err != nil { return nil, err } @@ -134,13 +139,15 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr } var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:48]); err != nil { + if _, err := io.ReadFull(conn, buf[:53]); err != nil { return nil, err } + // ignore the header for this test, we're only checking the noise + // implementation. msg2 := messagebuffer{ - ciphertext: buf[32:48], + ciphertext: buf[37:53], } - copy(msg2.ne[:], buf[:32]) + copy(msg2.ne[:], buf[5:37]) _, p, valid := RecvMessage(&session, &msg2) if !valid { return nil, errors.New("handshake failed") @@ -150,18 +157,19 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr } _, msg3 := SendMessage(&session, payload) - binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext))) - if _, err := conn.Write(buf[:2]); err != nil { + setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg3.ciphertext)) + if _, err := conn.Write(hdr[:]); err != nil { return nil, err } if _, err := conn.Write(msg3.ciphertext); err != nil { return nil, err } - if _, err := io.ReadFull(conn, buf[:2]); err != nil { + if _, err := io.ReadFull(conn, buf[:5]); err != nil { return nil, err } - plen := int(binary.BigEndian.Uint16(buf[:2])) + // Ignore all of the header except the payload length + plen := int(binary.LittleEndian.Uint16(buf[3:5])) if _, err := io.ReadFull(conn, buf[:plen]); err != nil { return nil, err } @@ -182,17 +190,18 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k private_key: controlKey, public_key: controlKey.Public(), } - session := InitSession(false, []byte(protocolVersion), mk, [32]byte{}) + session := InitSession(false, protocolVersionPrologue(protocolVersion), mk, [32]byte{}) var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:96]); err != nil { + if _, err := io.ReadFull(conn, buf[:101]); err != nil { return nil, err } + // Ignore the header, we're just checking the noise implementation. msg1 := messagebuffer{ - ns: buf[32:80], - ciphertext: buf[80:96], + ns: buf[37:85], + ciphertext: buf[85:101], } - copy(msg1.ne[:], buf[:32]) + copy(msg1.ne[:], buf[5:37]) _, p, valid := RecvMessage(&session, &msg1) if !valid { return nil, errors.New("handshake failed") @@ -202,6 +211,11 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k } _, msg2 := SendMessage(&session, nil) + var hdr [headerLen]byte + setHeader(hdr[:], protocolVersion, msgTypeResponse, 48) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } if _, err := conn.Write(msg2.ne[:]); err != nil { return nil, err } @@ -209,10 +223,10 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k return nil, err } - if _, err := io.ReadFull(conn, buf[:2]); err != nil { + if _, err := io.ReadFull(conn, buf[:5]); err != nil { return nil, err } - plen := int(binary.BigEndian.Uint16(buf[:2])) + plen := int(binary.LittleEndian.Uint16(buf[3:5])) if _, err := io.ReadFull(conn, buf[:plen]); err != nil { return nil, err } @@ -226,8 +240,8 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k } _, msg4 := SendMessage(&session, payload) - binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext))) - if _, err := conn.Write(buf[:2]); err != nil { + setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg4.ciphertext)) + if _, err := conn.Write(hdr[:]); err != nil { return nil, err } if _, err := conn.Write(msg4.ciphertext); err != nil { diff --git a/control/noise/key.go b/control/noise/key.go new file mode 100644 index 000000000..6fe480c01 --- /dev/null +++ b/control/noise/key.go @@ -0,0 +1,26 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package noise + +// Note that these types are deliberately separate from the types/key +// package. That package defines generic curve25519 keys, without +// consideration for how those keys are used. We don't want to +// encourage mixing machine keys, node keys, and whatever else we +// might use curve25519 for. +// +// Furthermore, the implementation in types/key does some work that is +// unnecessary for machine keys, and results in a harder to follow +// implementation. In particular, machine keys do not need to be +// clamped per the curve25519 spec because they're only used with the +// X25519 operation, and the X25519 operation defines its own clamping +// and sanity checking logic. Thus, these keys must be used only with +// this Noise protocol implementation, and the easiest way to ensure +// that is a different type. + +// PrivateKey is a Tailscale machine private key. +type PrivateKey [32]byte + +// PublicKey is a Tailscale machine public key. +type PublicKey [32]byte diff --git a/control/noise/messages.go b/control/noise/messages.go new file mode 100644 index 000000000..abfa0520b --- /dev/null +++ b/control/noise/messages.go @@ -0,0 +1,87 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package noise + +import "encoding/binary" + +const ( + msgTypeInitiation = 1 + msgTypeResponse = 2 + msgTypeError = 3 + msgTypeRecord = 4 +) + +// headerLen is the size of the cleartext message header that gets +// prepended to Noise messages. +// +// 2b: protocol version +// 1b: message type +// 2b: payload length (not including this header) +const headerLen = 5 + +func setHeader(bs []byte, version int, msgType byte, length int) { + binary.LittleEndian.PutUint16(bs[:2], uint16(version)) + bs[2] = msgType + binary.LittleEndian.PutUint16(bs[3:5], uint16(length)) +} +func hdrVersion(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[:2])) } +func hdrType(bs []byte) byte { return bs[2] } +func hdrLen(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[3:5])) } + +// initiationMessage is the Noise protocol message sent from a client +// machine to a control server. +// +// 5b: header (see headerLen for fields) +// 32b: client ephemeral public key (cleartext) +// 48b: client machine public key (encrypted) +// 16b: message tag (authenticates the whole message) +type initiationMessage [101]byte + +func mkInitiationMessage() initiationMessage { + var ret initiationMessage + binary.LittleEndian.PutUint16(ret[:2], protocolVersion) + ret[2] = msgTypeInitiation + binary.LittleEndian.PutUint16(ret[3:5], 96) + return ret +} + +func (m *initiationMessage) Header() []byte { return m[:5] } +func (m *initiationMessage) Payload() []byte { return m[5:] } + +func (m *initiationMessage) Version() int { return hdrVersion(m.Header()) } +func (m *initiationMessage) Type() byte { return hdrType(m.Header()) } +func (m *initiationMessage) Length() int { return hdrLen(m.Header()) } + +func (m *initiationMessage) EphemeralPub() []byte { return m[5:37] } +func (m *initiationMessage) MachinePub() []byte { return m[37:85] } +func (m *initiationMessage) Tag() []byte { return m[85:] } + +// responseMessage is the Noise protocol message sent from a control +// server to a client machine. +// +// 2b: little-endian protocol version +// 1b: message type +// 2b: little-endian size of message (not including this header) +// 32b: control ephemeral public key (cleartext) +// 16b: message tag (authenticates the whole message) +type responseMessage [53]byte + +func mkResponseMessage() responseMessage { + var ret responseMessage + binary.LittleEndian.PutUint16(ret[:2], protocolVersion) + ret[2] = msgTypeResponse + binary.LittleEndian.PutUint16(ret[3:5], 48) + return ret +} + +func (m *responseMessage) Header() []byte { return m[:5] } +func (m *responseMessage) Payload() []byte { return m[5:] } + +func (m *responseMessage) Version() int { return hdrVersion(m.Header()) } +func (m *responseMessage) Type() byte { return hdrType(m.Header()) } +func (m *responseMessage) Length() int { return hdrLen(m.Header()) } + +func (m *responseMessage) EphemeralPub() []byte { return m[5:37] } +func (m *responseMessage) Tag() []byte { return m[37:] }