From cf90392174c7ab1ea4f24388e936c36479edb1f4 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 30 Jul 2021 11:38:10 -0700 Subject: [PATCH] control/noise: review fixups Signed-off-by: David Anderson --- control/noise/conn.go | 59 ++++++++++------- control/noise/handshake.go | 61 ++++++++++------- control/noise/handshake_test.go | 2 +- control/noise/messages.go | 112 ++++++++++++++++++++------------ 4 files changed, 146 insertions(+), 88 deletions(-) diff --git a/control/noise/conn.go b/control/noise/conn.go index 63b4f36de..4dce79f80 100644 --- a/control/noise/conn.go +++ b/control/noise/conn.go @@ -24,18 +24,24 @@ import ( ) const ( - maxMessageSize = 4096 + // maxMessageSize is the maximum size of a protocol frame on the + // wire, including header and payload. + maxMessageSize = 4096 + // maxCiphertextSize is the maximum amount of ciphertext bytes + // that one protocol frame can carry, after framing. maxCiphertextSize = maxMessageSize - headerLen - maxPlaintextSize = maxCiphertextSize - poly1305.TagSize + // maxPlaintextSize is the maximum amount of plaintext bytes that + // one protocol frame can carry, after encryption and framing. + maxPlaintextSize = maxCiphertextSize - poly1305.TagSize ) // A Conn is a secured Noise connection. It implements the net.Conn // interface, with the unusual trait that any write error (including a -// SetWriteDeadline induced i/o timeout) cause all future writes to +// SetWriteDeadline induced i/o timeout) causes all future writes to // fail. type Conn struct { conn net.Conn - version int + version uint16 peer key.Public handshakeHash [blake2s.Size]byte rx rxState @@ -62,8 +68,10 @@ type txState struct { err error // records the first partial write error for all future calls } +// ProtocolVersion returns the protocol version that was used to +// establish this Conn. func (c *Conn) ProtocolVersion() int { - return c.version + return int(c.version) } // HandshakeHash returns the Noise handshake hash for the connection, @@ -85,7 +93,7 @@ func validNonce(nonce []byte) bool { return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce } -// readNLocked reads into c.rxBuf until rxBuf contains at least total +// readNLocked reads into c.rx.buf until buf contains at least total // bytes. Returns a slice of the available bytes in rxBuf, or an // error if fewer than total bytes are available. func (c *Conn) readNLocked(total int) ([]byte, error) { @@ -105,10 +113,8 @@ func (c *Conn) readNLocked(total int) ([]byte, error) { } } -// decryptLocked decrypts message (which is header+ciphertext) -// in-place and sets c.rx.plaintext to the decrypted bytes. Returns an -// error if the cipher is exhausted (i.e. can no longer be used -// safely) or decryption fails. +// decryptLocked decrypts msg (which is header+ciphertext) in-place +// and sets c.rx.plaintext to the decrypted bytes. func (c *Conn) decryptLocked(msg []byte) (err error) { if hdrVersion(msg) != c.version { return fmt.Errorf("received message with unexpected protocol version %d, want %d", hdrVersion(msg), c.version) @@ -116,7 +122,9 @@ func (c *Conn) decryptLocked(msg []byte) (err error) { if hdrType(msg) != msgTypeRecord { return fmt.Errorf("received message with unexpected type %d, want %d", hdrType(msg), msgTypeRecord) } - // length was already handled in caller to size msg. + // We don't check the length field here, because the caller + // already did in order to figure out how big the msg slice should + // be. ciphertext := msg[headerLen:] if !validNonce(c.rx.nonce[:]) { @@ -132,7 +140,8 @@ func (c *Conn) decryptLocked(msg []byte) (err error) { if err != nil { // Once a decryption has failed, our Conn is no longer // synchronized with our peer. Nuke the cipher state to be - // safe, so that no further decryptions are attempted. + // safe, so that no further decryptions are attempted. Future + // read attempts will return net.ErrClosed. c.rx.cipher = nil } return err @@ -148,8 +157,8 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { return nil, errCipherExhausted{} } - setHeader(c.tx.buf[:5], protocolVersion, msgTypeRecord, len(plaintext)+poly1305.TagSize) - ret := c.tx.cipher.Seal(c.tx.buf[:5], c.tx.nonce[:], plaintext, nil) + setHeader(c.tx.buf[:headerLen], protocolVersion, msgTypeRecord, len(plaintext)+poly1305.TagSize) + ret := c.tx.cipher.Seal(c.tx.buf[:headerLen], c.tx.nonce[:], plaintext, nil) // Safe to increment the nonce here, because we checked for nonce // wraparound above. @@ -169,7 +178,7 @@ func (c *Conn) wholeMessageLocked() []byte { return nil } bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := hdrLen(bs) + headerLen + totalSize := headerLen + hdrLen(bs) if len(bs) < totalSize { return nil } @@ -193,8 +202,7 @@ func (c *Conn) decryptOneLocked() error { // To simplify the read logic, move the remainder of the // buffered bytes back to the head of the buffer, so we can // grow it without worrying about wraparound. - copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) - c.rx.n -= c.rx.next + c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) c.rx.next = 0 } @@ -224,8 +232,10 @@ func (c *Conn) Read(bs []byte) (int, error) { if c.rx.cipher == nil { return 0, net.ErrClosed } - // Loop to handle receiving a zero-byte Noise message. Just skip - // over it and keep decrypting until we find some bytes. + // If no plaintext is buffered, decrypt incoming frames until we + // have some plaintext. Zero-byte Noise frames are allowed in this + // protocol, which is why we have to loop here rather than decrypt + // a single additional frame. for len(c.rx.plaintext) == 0 { if err := c.decryptOneLocked(); err != nil { return 0, err @@ -276,15 +286,15 @@ func (c *Conn) Write(bs []byte) (n int, err error) { return 0, err } - if n, err := c.conn.Write(ciphertext); err != nil { - sent += n + n, err := c.conn.Write(ciphertext) + sent += n + if err != nil { // Return the raw error on the Write that actually // failed. For future writes, return that error wrapped in // a desync error. c.tx.err = errPartialWrite{err} return sent, err } - sent += len(toSend) } return sent, nil } @@ -292,6 +302,11 @@ func (c *Conn) Write(bs []byte) (n int, err error) { // Close implements io.Closer. func (c *Conn) Close() error { closeErr := c.conn.Close() // unblocks any waiting reads or writes + + // Remove references to live cipher state. Strictly speaking this + // is unnecessary, but we want to try and hand the active cipher + // state to the garbage collector promptly, to preserve perfect + // forward secrecy as much as we can. c.rx.Lock() c.rx.cipher = nil c.rx.Unlock() diff --git a/control/noise/handshake.go b/control/noise/handshake.go index 1cc0af85c..8e64b1dab 100644 --- a/control/noise/handshake.go +++ b/control/noise/handshake.go @@ -31,7 +31,7 @@ const ( protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" // protocolVersion is the version of the Tailscale base // protocol that Client will use when initiating a handshake. - protocolVersion = 1 + protocolVersion uint16 = 1 // protocolVersionPrefix is the name portion of the protocol // name+version string that gets mixed into the Noise handshake as // a prologue. @@ -44,7 +44,7 @@ const ( invalidNonce = ^uint64(0) ) -func protocolVersionPrologue(version int) []byte { +func protocolVersionPrologue(version uint16) []byte { ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. ret = append(ret, protocolVersionPrefix...) return strconv.AppendUint(ret, uint64(version), 10) @@ -54,7 +54,7 @@ func protocolVersionPrologue(version int) []byte { // Noise connection. // // The context deadline, if any, covers the entire handshaking -// process. +// process. Any preexisting Conn deadline is removed. func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlKey key.Public) (*Conn, error) { if deadline, ok := ctx.Deadline(); ok { if err := conn.SetDeadline(deadline); err != nil { @@ -111,7 +111,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK if _, err := io.ReadFull(conn, msg); err != nil { return nil, err } - return nil, fmt.Errorf("server error: %s", string(msg)) + return nil, fmt.Errorf("server error: %q", msg) } if resp.Length() != len(resp.Payload()) { return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) @@ -139,7 +139,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK return nil, fmt.Errorf("finalizing handshake: %w", err) } - return &Conn{ + c := &Conn{ conn: conn, version: protocolVersion, peer: controlKey, @@ -150,7 +150,8 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK rx: rxState{ cipher: c2, }, - }, nil + } + return c, nil } // Server initiates a Noise server handshake, returning the resulting @@ -179,10 +180,10 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, if _, err := conn.Write(hdr[:]); err != nil { return fmt.Errorf("sending %q error to client: %w", msg, err) } - if _, err := conn.Write([]byte(msg)); err != nil { + if _, err := io.WriteString(conn, msg); err != nil { return fmt.Errorf("sending %q error to client: %w", msg, err) } - return fmt.Errorf("refused client handshake: %s", msg) + return fmt.Errorf("refused client handshake: %q", msg) } var s symmetricState @@ -255,7 +256,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, return nil, err } - return &Conn{ + c := &Conn{ conn: conn, version: protocolVersion, peer: machineKey, @@ -266,13 +267,16 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, rx: rxState{ cipher: c1, }, - }, nil + } + return c, nil } // symmetricState is the SymmetricState object from the Noise protocol // spec. It contains all the symmetric cipher state of an in-flight // handshake. Field names match the variable names in the spec. type symmetricState struct { + finished bool + h [blake2s.Size]byte ck [blake2s.Size]byte @@ -282,9 +286,16 @@ type symmetricState struct { mixer hash.Hash // for updating h } +func (s *symmetricState) checkFinished() { + if s.finished { + panic("attempted to use symmetricState after Split was called") + } +} + // Initialize sets s to the initial handshake state, prior to // processing any Noise messages. func (s *symmetricState) Initialize() { + s.checkFinished() if s.mixer != nil { panic("symmetricState cannot be reused") } @@ -298,10 +309,11 @@ func (s *symmetricState) Initialize() { // MixHash updates s.h to be BLAKE2s(s.h || data), where || is // concatenation. func (s *symmetricState) MixHash(data []byte) { + s.checkFinished() s.mixer.Reset() s.mixer.Write(s.h[:]) s.mixer.Write(data) - s.mixer.Sum(s.h[:0]) // TODO: check this actually updates s.h correctly... + s.mixer.Sum(s.h[:0]) } // MixDH updates s.ck and s.k with the result of X25519(priv, pub). @@ -312,16 +324,7 @@ func (s *symmetricState) MixHash(data []byte) { // two private keys, or two public keys), and thus producing the wrong // calculation. func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error { - // TODO(danderson): check that this operation is correct. The docs - // for X25519 say that the 2nd arg must be either Basepoint or the - // output of another X25519 call. - // - // I think this is correct, because pub is the result of a - // ScalarBaseMult on the private key, and our private key - // generation code clamps keys to avoid low order points. I - // believe that makes pub equivalent to the output of - // X25519(privateKey, Basepoint), and so the contract is - // respected. + s.checkFinished() keyData, err := curve25519.X25519(priv[:], pub[:]) if err != nil { return fmt.Errorf("computing X25519: %w", err) @@ -342,6 +345,7 @@ func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error { // the correct size to hold the encrypted plaintext) using the current // s.k, mixes the ciphertext into s.h, and returns the ciphertext. func (s *symmetricState) EncryptAndHash(ciphertext, plaintext []byte) { + s.checkFinished() if s.n == invalidNonce { // Noise in general permits writing "ciphertext" without a // key, but in IK it cannot happen. @@ -352,6 +356,8 @@ func (s *symmetricState) EncryptAndHash(ciphertext, plaintext []byte) { } aead := newCHP(s.k) var nonce [chp.NonceSize]byte + // chacha20poly1305 nonces are 96 bits, but we use a 64-bit + // counter. Therefore, the leading 4 bytes are always zero. binary.BigEndian.PutUint64(nonce[4:], s.n) s.n++ ret := aead.Seal(ciphertext[:0], nonce[:], plaintext, s.h[:]) @@ -363,6 +369,7 @@ func (s *symmetricState) EncryptAndHash(ciphertext, plaintext []byte) { // the current s.k. If decryption is successful, it mixes the // ciphertext into s.h. func (s *symmetricState) DecryptAndHash(plaintext, ciphertext []byte) error { + s.checkFinished() if s.n == invalidNonce { // Noise in general permits "ciphertext" without a key, but in // IK it cannot happen. @@ -373,6 +380,8 @@ func (s *symmetricState) DecryptAndHash(plaintext, ciphertext []byte) error { } aead := newCHP(s.k) var nonce [chp.NonceSize]byte + // chacha20poly1305 nonces are 96 bits, but we use a 64-bit + // counter. Therefore, the leading 4 bytes are always zero. binary.BigEndian.PutUint64(nonce[4:], s.n) s.n++ if _, err := aead.Open(plaintext[:0], nonce[:], ciphertext, s.h[:]); err != nil { @@ -383,9 +392,11 @@ func (s *symmetricState) DecryptAndHash(plaintext, ciphertext []byte) error { } // Split returns two ChaCha20Poly1305 ciphers with keys derived from -// the current handshake state. Methods on s must not be used again -// after calling Split(). +// the current handshake state. Methods on s cannot be used again +// after calling Split. func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { + s.finished = true + var k1, k2 [chp.KeySize]byte r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) if _, err := io.ReadFull(r, k1[:]); err != nil { @@ -412,7 +423,7 @@ func newBLAKE2s() hash.Hash { if err != nil { // Should never happen, errors only happen when using BLAKE2s // in MAC mode with a key. - panic(fmt.Sprintf("blake2s construction: %v", err)) + panic(err) } return h } @@ -424,7 +435,7 @@ func newCHP(key [chp.KeySize]byte) cipher.AEAD { if err != nil { // Can only happen if we passed a key of the wrong length. The // function signature prevents that. - panic(fmt.Sprintf("chacha20poly1305 construction: %v", err)) + panic(err) } return aead } diff --git a/control/noise/handshake_test.go b/control/noise/handshake_test.go index 8d97807e6..043dfd8aa 100644 --- a/control/noise/handshake_test.go +++ b/control/noise/handshake_test.go @@ -42,7 +42,7 @@ func TestHandshake(t *testing.T) { t.Fatal("client and server disagree on handshake hash") } - if client.ProtocolVersion() != protocolVersion { + if client.ProtocolVersion() != int(protocolVersion) { t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion) } if client.ProtocolVersion() != server.ProtocolVersion() { diff --git a/control/noise/messages.go b/control/noise/messages.go index abfa0520b..381038915 100644 --- a/control/noise/messages.go +++ b/control/noise/messages.go @@ -6,32 +6,68 @@ package noise import "encoding/binary" -const ( - msgTypeInitiation = 1 - msgTypeResponse = 2 - msgTypeError = 3 - msgTypeRecord = 4 -) - -// headerLen is the size of the cleartext message header that gets -// prepended to Noise messages. +// The transport protocol is mostly Noise messages encapsulated in a +// small header describing the payload's type and length. The one +// place we deviate from pure Noise+header is that we also support +// sending an unauthenticated plaintext error as payload, to provide +// an explanation for a connection error that happens before the +// handshake completes. +// +// All frames in our protocol have a 5-byte header: +// +// +------+------+------+------+------+ +// | version | type | length | +// +------+------+------+------+------+ // // 2b: protocol version // 1b: message type -// 2b: payload length (not including this header) +// 2b: payload length (not including the header) +// +// Multibyte values are all big-endian on the wire, as is traditional +// for network protocols. +// +// The protocol version is 2 bytes in order to encourage frequent +// revving of the protocol as needed, without fear of running out of +// version numbers. At minimum, the version number must change +// whenever any particulars of the Noise handshake change +// (e.g. switching from Noise IK to Noise IKpsk1 or Noise XX), and +// when security-critical aspects of the "uppper" protocol within the +// Noise frames change (e.g. how further authentication data is bound +// to the underlying Noise session). + +// headerLen is the size of the header that gets prepended to Noise +// messages. const headerLen = 5 -func setHeader(bs []byte, version int, msgType byte, length int) { +const ( + // msgTypeInitiation frames carry a Noise IK handshake initiation message. + msgTypeInitiation = 1 + // msgTypeResponse frames carry a Noise IK handshake response message. + msgTypeResponse = 2 + // msgTypeError frames carry an unauthenticated human-readable + // error message. + // + // Errors reported in this message type must be treated as public + // hints only. They are not encrypted or authenticated, and so can + // be seen and tampered with on the wire. + msgTypeError = 3 + // msgTypeRecord frames carry a Noise transport message (i.e. "user data"). + msgTypeRecord = 4 +) + +func setHeader(bs []byte, version uint16, msgType byte, length int) { binary.LittleEndian.PutUint16(bs[:2], uint16(version)) bs[2] = msgType binary.LittleEndian.PutUint16(bs[3:5], uint16(length)) } -func hdrVersion(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[:2])) } -func hdrType(bs []byte) byte { return bs[2] } -func hdrLen(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[3:5])) } +func hdrVersion(bs []byte) uint16 { return binary.LittleEndian.Uint16(bs[:2]) } +func hdrType(bs []byte) byte { return bs[2] } +func hdrLen(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[3:5])) } // initiationMessage is the Noise protocol message sent from a client -// machine to a control server. +// machine to a control server. Aside from the message header, the +// values are as specified in the Noise specification for the IK +// handshake pattern. // // 5b: header (see headerLen for fields) // 32b: client ephemeral public key (cleartext) @@ -41,47 +77,43 @@ type initiationMessage [101]byte func mkInitiationMessage() initiationMessage { var ret initiationMessage - binary.LittleEndian.PutUint16(ret[:2], protocolVersion) - ret[2] = msgTypeInitiation - binary.LittleEndian.PutUint16(ret[3:5], 96) + setHeader(ret[:], protocolVersion, msgTypeInitiation, len(ret.Payload())) return ret } -func (m *initiationMessage) Header() []byte { return m[:5] } -func (m *initiationMessage) Payload() []byte { return m[5:] } +func (m *initiationMessage) Header() []byte { return m[:headerLen] } +func (m *initiationMessage) Payload() []byte { return m[headerLen:] } -func (m *initiationMessage) Version() int { return hdrVersion(m.Header()) } -func (m *initiationMessage) Type() byte { return hdrType(m.Header()) } -func (m *initiationMessage) Length() int { return hdrLen(m.Header()) } +func (m *initiationMessage) Version() uint16 { return hdrVersion(m.Header()) } +func (m *initiationMessage) Type() byte { return hdrType(m.Header()) } +func (m *initiationMessage) Length() int { return hdrLen(m.Header()) } -func (m *initiationMessage) EphemeralPub() []byte { return m[5:37] } -func (m *initiationMessage) MachinePub() []byte { return m[37:85] } -func (m *initiationMessage) Tag() []byte { return m[85:] } +func (m *initiationMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } +func (m *initiationMessage) MachinePub() []byte { return m[headerLen+32 : headerLen+32+48] } +func (m *initiationMessage) Tag() []byte { return m[headerLen+32+48:] } // responseMessage is the Noise protocol message sent from a control -// server to a client machine. +// server to a client machine. Aside from the message header, the +// values are as specified in the Noise specification for the IK +// handshake pattern. // -// 2b: little-endian protocol version -// 1b: message type -// 2b: little-endian size of message (not including this header) +// 5b: header (see headerLen for fields) // 32b: control ephemeral public key (cleartext) // 16b: message tag (authenticates the whole message) type responseMessage [53]byte func mkResponseMessage() responseMessage { var ret responseMessage - binary.LittleEndian.PutUint16(ret[:2], protocolVersion) - ret[2] = msgTypeResponse - binary.LittleEndian.PutUint16(ret[3:5], 48) + setHeader(ret[:], protocolVersion, msgTypeResponse, len(ret.Payload())) return ret } -func (m *responseMessage) Header() []byte { return m[:5] } -func (m *responseMessage) Payload() []byte { return m[5:] } +func (m *responseMessage) Header() []byte { return m[:headerLen] } +func (m *responseMessage) Payload() []byte { return m[headerLen:] } -func (m *responseMessage) Version() int { return hdrVersion(m.Header()) } -func (m *responseMessage) Type() byte { return hdrType(m.Header()) } -func (m *responseMessage) Length() int { return hdrLen(m.Header()) } +func (m *responseMessage) Version() uint16 { return hdrVersion(m.Header()) } +func (m *responseMessage) Type() byte { return hdrType(m.Header()) } +func (m *responseMessage) Length() int { return hdrLen(m.Header()) } -func (m *responseMessage) EphemeralPub() []byte { return m[5:37] } -func (m *responseMessage) Tag() []byte { return m[37:] } +func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } +func (m *responseMessage) Tag() []byte { return m[headerLen+32:] }