diff --git a/control/noise/handshake.go b/control/noise/handshake.go index 8e64b1dab..142a9e6a5 100644 --- a/control/noise/handshake.go +++ b/control/noise/handshake.go @@ -7,7 +7,6 @@ package noise import ( "context" "crypto/cipher" - "encoding/binary" "fmt" "hash" "io" @@ -81,15 +80,17 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK machineEphemeralPub := machineEphemeral.Public() copy(init.EphemeralPub(), machineEphemeralPub[:]) s.MixHash(machineEphemeralPub[:]) - if err := s.MixDH(machineEphemeral, controlKey); err != nil { + cipher, err := s.MixDH(machineEphemeral, controlKey) + if err != nil { return nil, fmt.Errorf("computing es: %w", err) } machineKeyPub := machineKey.Public() - s.EncryptAndHash(init.MachinePub(), machineKeyPub[:]) - if err := s.MixDH(machineKey, controlKey); err != nil { + s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub[:]) + cipher, err = s.MixDH(machineKey, controlKey) + if err != nil { return nil, fmt.Errorf("computing ss: %w", err) } - s.EncryptAndHash(init.Tag(), nil) // empty message payload + s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload if _, err := conn.Write(init[:]); err != nil { return nil, fmt.Errorf("writing initiation: %w", err) @@ -124,13 +125,14 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK var controlEphemeralPub key.Public copy(controlEphemeralPub[:], resp.EphemeralPub()) s.MixHash(controlEphemeralPub[:]) - if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { + if _, err = s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { return nil, fmt.Errorf("computing ee: %w", err) } - if err := s.MixDH(machineKey, controlEphemeralPub); err != nil { + cipher, err = s.MixDH(machineKey, controlEphemeralPub) + if err != nil { return nil, fmt.Errorf("computing se: %w", err) } - if err := s.DecryptAndHash(nil, resp.Tag()); err != nil { + if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { return nil, fmt.Errorf("decrypting payload: %w", err) } @@ -219,17 +221,19 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, var machineEphemeralPub key.Public copy(machineEphemeralPub[:], init.EphemeralPub()) s.MixHash(machineEphemeralPub[:]) - if err := s.MixDH(controlKey, machineEphemeralPub); err != nil { + cipher, err := s.MixDH(controlKey, machineEphemeralPub) + if err != nil { return nil, fmt.Errorf("computing es: %w", err) } var machineKey key.Public - if err := s.DecryptAndHash(machineKey[:], init.MachinePub()); err != nil { + if err := s.DecryptAndHash(cipher, machineKey[:], init.MachinePub()); err != nil { return nil, fmt.Errorf("decrypting machine key: %w", err) } - if err := s.MixDH(controlKey, machineKey); err != nil { + cipher, err = s.MixDH(controlKey, machineKey) + if err != nil { return nil, fmt.Errorf("computing ss: %w", err) } - if err := s.DecryptAndHash(nil, init.Tag()); err != nil { + if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { return nil, fmt.Errorf("decrypting initiation tag: %w", err) } @@ -239,13 +243,14 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, controlEphemeralPub := controlEphemeral.Public() copy(resp.EphemeralPub(), controlEphemeralPub[:]) s.MixHash(controlEphemeralPub[:]) - if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { + if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { return nil, fmt.Errorf("computing ee: %w", err) } - if err := s.MixDH(controlEphemeral, machineKey); err != nil { + cipher, err = s.MixDH(controlEphemeral, machineKey) + if err != nil { return nil, fmt.Errorf("computing se: %w", err) } - s.EncryptAndHash(resp.Tag(), nil) // empty message payload + s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload c1, c2, err := s.Split() if err != nil { @@ -280,9 +285,6 @@ type symmetricState struct { h [blake2s.Size]byte ck [blake2s.Size]byte - k [chp.KeySize]byte - n uint64 - mixer hash.Hash // for updating h } @@ -301,8 +303,6 @@ func (s *symmetricState) Initialize() { } s.h = blake2s.Sum256([]byte(protocolName)) s.ck = s.h - s.k = [chp.KeySize]byte{} - s.n = invalidNonce s.mixer = newBLAKE2s() } @@ -316,75 +316,55 @@ func (s *symmetricState) MixHash(data []byte) { s.mixer.Sum(s.h[:0]) } -// MixDH updates s.ck and s.k with the result of X25519(priv, pub). +// MixDH updates s.ck with the result of X25519(priv, pub) and returns +// a singleUseCHP that can be used to encrypt or decrypt handshake +// data. // // MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing // it as a single function allows for strongly-typed arguments that // reduce the risk of error in the caller (e.g. invoking X25519 with // two private keys, or two public keys), and thus producing the wrong // calculation. -func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error { +func (s *symmetricState) MixDH(priv key.Private, pub key.Public) (*singleUseCHP, error) { s.checkFinished() keyData, err := curve25519.X25519(priv[:], pub[:]) if err != nil { - return fmt.Errorf("computing X25519: %w", err) + return nil, fmt.Errorf("computing X25519: %w", err) } r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) if _, err := io.ReadFull(r, s.ck[:]); err != nil { - return fmt.Errorf("extracting ck: %w", err) + return nil, fmt.Errorf("extracting ck: %w", err) } - if _, err := io.ReadFull(r, s.k[:]); err != nil { - return fmt.Errorf("extracting k: %w", err) + var k [chp.KeySize]byte + if _, err := io.ReadFull(r, k[:]); err != nil { + return nil, fmt.Errorf("extracting k: %w", err) } - s.n = 0 - return nil + return newSingleUseCHP(k), nil } // EncryptAndHash encrypts plaintext into ciphertext (which must be -// the correct size to hold the encrypted plaintext) using the current -// s.k, mixes the ciphertext into s.h, and returns the ciphertext. -func (s *symmetricState) EncryptAndHash(ciphertext, plaintext []byte) { +// the correct size to hold the encrypted plaintext) using cipher, +// mixes the ciphertext into s.h, and returns the ciphertext. +func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { s.checkFinished() - if s.n == invalidNonce { - // Noise in general permits writing "ciphertext" without a - // key, but in IK it cannot happen. - panic("attempted encryption with uninitialized key") - } if len(ciphertext) != len(plaintext)+poly1305.TagSize { panic("ciphertext is wrong size for given plaintext") } - aead := newCHP(s.k) - var nonce [chp.NonceSize]byte - // chacha20poly1305 nonces are 96 bits, but we use a 64-bit - // counter. Therefore, the leading 4 bytes are always zero. - binary.BigEndian.PutUint64(nonce[4:], s.n) - s.n++ - ret := aead.Seal(ciphertext[:0], nonce[:], plaintext, s.h[:]) + ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) s.MixHash(ret) } // DecryptAndHash decrypts the given ciphertext into plaintext (which // must be the correct size to hold the decrypted ciphertext) using -// the current s.k. If decryption is successful, it mixes the -// ciphertext into s.h. -func (s *symmetricState) DecryptAndHash(plaintext, ciphertext []byte) error { +// cipher. If decryption is successful, it mixes the ciphertext into +// s.h. +func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { s.checkFinished() - if s.n == invalidNonce { - // Noise in general permits "ciphertext" without a key, but in - // IK it cannot happen. - panic("attempted encryption with uninitialized key") - } if len(ciphertext) != len(plaintext)+poly1305.TagSize { panic("plaintext is wrong size for given ciphertext") } - aead := newCHP(s.k) - var nonce [chp.NonceSize]byte - // chacha20poly1305 nonces are 96 bits, but we use a 64-bit - // counter. Therefore, the leading 4 bytes are always zero. - binary.BigEndian.PutUint64(nonce[4:], s.n) - s.n++ - if _, err := aead.Open(plaintext[:0], nonce[:], ciphertext, s.h[:]); err != nil { + if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { return err } s.MixHash(ciphertext) @@ -439,3 +419,35 @@ func newCHP(key [chp.KeySize]byte) cipher.AEAD { } return aead } + +// singleUseCHP is an instance of ChaCha20Poly1305 that can be used +// only once, either for encrypting or decrypting, but not both. The +// chosen operation is always executed with an all-zeros +// nonce. Subsequent calls to either Seal or Open panic. +type singleUseCHP struct { + c cipher.AEAD +} + +func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { + return &singleUseCHP{newCHP(key)} +} + +func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Seal(dst, nonce[:], plaintext, additionalData) +} + +func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Open(dst, nonce[:], ciphertext, additionalData) +}