control/noise: factor out nonce checking and incrementing into a type.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/3293/head
David Anderson 3 years ago committed by Dave Anderson
parent d3acd35a90
commit a34350ffda

@ -51,7 +51,7 @@ type Conn struct {
type rxState struct { type rxState struct {
sync.Mutex sync.Mutex
cipher cipher.AEAD cipher cipher.AEAD
nonce [chp.NonceSize]byte nonce nonce
buf [maxMessageSize]byte buf [maxMessageSize]byte
n int // number of valid bytes in buf n int // number of valid bytes in buf
next int // offset of next undecrypted packet next int // offset of next undecrypted packet
@ -62,7 +62,7 @@ type rxState struct {
type txState struct { type txState struct {
sync.Mutex sync.Mutex
cipher cipher.AEAD cipher cipher.AEAD
nonce [chp.NonceSize]byte nonce nonce
buf [maxMessageSize]byte buf [maxMessageSize]byte
err error // records the first partial write error for all future calls err error // records the first partial write error for all future calls
} }
@ -86,12 +86,6 @@ func (c *Conn) Peer() key.MachinePublic {
return c.peer return c.peer
} }
// validNonce reports whether nonce is in the valid range for use: 0
// through 2^64-2.
func validNonce(nonce []byte) bool {
return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce
}
// readNLocked reads into c.rx.buf until buf contains at least total // readNLocked reads into c.rx.buf until buf contains at least total
// bytes. Returns a slice of the available bytes in rxBuf, or an // bytes. Returns a slice of the available bytes in rxBuf, or an
// error if fewer than total bytes are available. // error if fewer than total bytes are available.
@ -123,15 +117,12 @@ func (c *Conn) decryptLocked(msg []byte) (err error) {
// be. // be.
ciphertext := msg[headerLen:] ciphertext := msg[headerLen:]
if !validNonce(c.rx.nonce[:]) { if !c.rx.nonce.Valid() {
return errCipherExhausted{} return errCipherExhausted{}
} }
c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
c.rx.nonce.Increment()
// Safe to increment the nonce here, because we checked for nonce
// wraparound above.
binary.BigEndian.PutUint64(c.rx.nonce[4:], 1+binary.BigEndian.Uint64(c.rx.nonce[4:]))
if err != nil { if err != nil {
// Once a decryption has failed, our Conn is no longer // Once a decryption has failed, our Conn is no longer
@ -147,7 +138,7 @@ func (c *Conn) decryptLocked(msg []byte) (err error) {
// packet header) and returns a slice of the ciphertext, or an error // packet header) and returns a slice of the ciphertext, or an error
// if the cipher is exhausted (i.e. can no longer be used safely). // if the cipher is exhausted (i.e. can no longer be used safely).
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
if !validNonce(c.tx.nonce[:]) { if !c.tx.nonce.Valid() {
// Received 2^64-1 messages on this cipher state. Connection // Received 2^64-1 messages on this cipher state. Connection
// is no longer usable. // is no longer usable.
return nil, errCipherExhausted{} return nil, errCipherExhausted{}
@ -156,10 +147,7 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
c.tx.buf[0] = msgTypeRecord c.tx.buf[0] = msgTypeRecord
binary.BigEndian.PutUint16(c.tx.buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) binary.BigEndian.PutUint16(c.tx.buf[1:headerLen], uint16(len(plaintext)+chp.Overhead))
ret := c.tx.cipher.Seal(c.tx.buf[:headerLen], c.tx.nonce[:], plaintext, nil) ret := c.tx.cipher.Seal(c.tx.buf[:headerLen], c.tx.nonce[:], plaintext, nil)
c.tx.nonce.Increment()
// Safe to increment the nonce here, because we checked for nonce
// wraparound above.
binary.BigEndian.PutUint64(c.tx.nonce[4:], 1+binary.BigEndian.Uint64(c.tx.nonce[4:]))
return ret, nil return ret, nil
} }
@ -357,3 +345,16 @@ func (e errReadTooBig) Temporary() bool {
return false return false
} }
func (e errReadTooBig) Timeout() bool { return false } func (e errReadTooBig) Timeout() bool { return false }
type nonce [chp.NonceSize]byte
func (n *nonce) Valid() bool {
return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce
}
func (n *nonce) Increment() {
if !n.Valid() {
panic("increment of invalid nonce")
}
binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:]))
}

Loading…
Cancel
Save