mirror of https://github.com/tailscale/tailscale/
control/noise: implement the base transport for the 2021 control protocol.
Signed-off-by: David Anderson <danderson@tailscale.com>pull/3293/head
parent
3e1daab704
commit
da7544bcc5
@ -0,0 +1,330 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package noise implements the base transport of the Tailscale 2021
|
||||
// control protocol.
|
||||
//
|
||||
// The base transport implements Noise IK, instantiated with
|
||||
// Curve25519, ChaCha20Poly1305 and BLAKE2s.
|
||||
package noise
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
chp "golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const (
|
||||
maxPlaintextSize = 4096
|
||||
maxCiphertextSize = maxPlaintextSize + poly1305.TagSize
|
||||
maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header
|
||||
)
|
||||
|
||||
// A Conn is a secured Noise connection. It implements the net.Conn
|
||||
// interface, with the unusual trait that any write error (including a
|
||||
// SetWriteDeadline induced i/o timeout) cause all future writes to
|
||||
// fail.
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
peer key.Public
|
||||
handshakeHash [blake2s.Size]byte
|
||||
rx rxState
|
||||
tx txState
|
||||
}
|
||||
|
||||
// rxState is all the Conn state that Read uses.
|
||||
type rxState struct {
|
||||
sync.Mutex
|
||||
cipher cipher.AEAD
|
||||
nonce [chp.NonceSize]byte
|
||||
buf [maxPacketSize]byte
|
||||
n int // number of valid bytes in buf
|
||||
next int // offset of next undecrypted packet
|
||||
plaintext []byte // slice into buf of decrypted bytes
|
||||
}
|
||||
|
||||
// txState is all the Conn state that Write uses.
|
||||
type txState struct {
|
||||
sync.Mutex
|
||||
cipher cipher.AEAD
|
||||
nonce [chp.NonceSize]byte
|
||||
buf [maxPacketSize]byte
|
||||
err error // records the first partial write error for all future calls
|
||||
}
|
||||
|
||||
// HandshakeHash returns the Noise handshake hash for the connection,
|
||||
// which can be used to bind other messages to this connection
|
||||
// (i.e. to ensure that the message wasn't replayed from a different
|
||||
// connection).
|
||||
func (c *Conn) HandshakeHash() [blake2s.Size]byte {
|
||||
return c.handshakeHash
|
||||
}
|
||||
|
||||
// Peer returns the peer's long-term public key.
|
||||
func (c *Conn) Peer() key.Public {
|
||||
return c.peer
|
||||
}
|
||||
|
||||
// validNonce reports whether nonce is in the valid range for use: 0
|
||||
// through 2^64-2.
|
||||
func validNonce(nonce []byte) bool {
|
||||
return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce
|
||||
}
|
||||
|
||||
// readNLocked reads into c.rxBuf until rxBuf contains at least total
|
||||
// bytes. Returns a slice of the available bytes in rxBuf, or an
|
||||
// error if fewer than total bytes are available.
|
||||
func (c *Conn) readNLocked(total int) ([]byte, error) {
|
||||
if total > maxPacketSize {
|
||||
return nil, errReadTooBig{total}
|
||||
}
|
||||
for {
|
||||
if total <= c.rx.n {
|
||||
return c.rx.buf[:c.rx.n], nil
|
||||
}
|
||||
|
||||
n, err := c.conn.Read(c.rx.buf[c.rx.n:])
|
||||
c.rx.n += n
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext
|
||||
// to the decrypted bytes. Returns an error if the cipher is exhausted
|
||||
// (i.e. can no longer be used safely) or decryption fails.
|
||||
func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
|
||||
if !validNonce(c.rx.nonce[:]) {
|
||||
return errCipherExhausted{}
|
||||
}
|
||||
|
||||
c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
|
||||
|
||||
// Safe to increment the nonce here, because we checked for nonce
|
||||
// wraparound above.
|
||||
binary.BigEndian.PutUint64(c.rx.nonce[4:], 1+binary.BigEndian.Uint64(c.rx.nonce[4:]))
|
||||
|
||||
if err != nil {
|
||||
// Once a decryption has failed, our Conn is no longer
|
||||
// synchronized with our peer. Nuke the cipher state to be
|
||||
// safe, so that no further decryptions are attempted.
|
||||
c.rx.cipher = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// encryptLocked encrypts plaintext into c.tx.buf (including the
|
||||
// 2-byte length header) and returns a slice of the ciphertext, or an
|
||||
// error if the cipher is exhausted (i.e. can no longer be used safely).
|
||||
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
|
||||
if !validNonce(c.tx.nonce[:]) {
|
||||
// Received 2^64-1 messages on this cipher state. Connection
|
||||
// is no longer usable.
|
||||
return nil, errCipherExhausted{}
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize))
|
||||
ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil)
|
||||
|
||||
// Safe to increment the nonce here, because we checked for nonce
|
||||
// wraparound above.
|
||||
binary.BigEndian.PutUint64(c.tx.nonce[4:], 1+binary.BigEndian.Uint64(c.tx.nonce[4:]))
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// wholeCiphertextLocked returns a slice of one whole Noise frame from
|
||||
// c.rx.buf, if one whole ciphertext is available, and advances the
|
||||
// read state to the next Noise frame in the buffer. Returns nil
|
||||
// without advancing read state if there's not one whole ciphertext in
|
||||
// c.rx.buf.
|
||||
func (c *Conn) wholeCiphertextLocked() []byte {
|
||||
available := c.rx.n - c.rx.next
|
||||
if available < 2 {
|
||||
return nil
|
||||
}
|
||||
bs := c.rx.buf[c.rx.next:c.rx.n]
|
||||
totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
||||
if len(bs) < totalSize {
|
||||
return nil
|
||||
}
|
||||
c.rx.next += totalSize
|
||||
return bs[:totalSize]
|
||||
}
|
||||
|
||||
// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed,
|
||||
// and sets c.rx.plaintext to point to the decrypted
|
||||
// bytes. c.rx.plaintext is only valid if err == nil.
|
||||
func (c *Conn) decryptOneLocked() error {
|
||||
c.rx.plaintext = nil
|
||||
|
||||
// Fast path: do we have one whole ciphertext frame buffered
|
||||
// already?
|
||||
if bs := c.wholeCiphertextLocked(); bs != nil {
|
||||
return c.decryptLocked(bs[2:])
|
||||
}
|
||||
|
||||
if c.rx.next != 0 {
|
||||
// To simplify the read logic, move the remainder of the
|
||||
// buffered bytes back to the head of the buffer, so we can
|
||||
// grow it without worrying about wraparound.
|
||||
copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
|
||||
c.rx.n -= c.rx.next
|
||||
c.rx.next = 0
|
||||
}
|
||||
|
||||
bs, err := c.readNLocked(2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
||||
bs, err = c.readNLocked(totalLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.rx.next = totalLen
|
||||
bs = bs[2:totalLen]
|
||||
|
||||
return c.decryptLocked(bs)
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (c *Conn) Read(bs []byte) (int, error) {
|
||||
c.rx.Lock()
|
||||
defer c.rx.Unlock()
|
||||
|
||||
if c.rx.cipher == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
// Loop to handle receiving a zero-byte Noise message. Just skip
|
||||
// over it and keep decrypting until we find some bytes.
|
||||
for len(c.rx.plaintext) == 0 {
|
||||
if err := c.decryptOneLocked(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
n := copy(bs, c.rx.plaintext)
|
||||
c.rx.plaintext = c.rx.plaintext[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (c *Conn) Write(bs []byte) (n int, err error) {
|
||||
c.tx.Lock()
|
||||
defer c.tx.Unlock()
|
||||
|
||||
if c.tx.err != nil {
|
||||
return 0, c.tx.err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
// All write errors are fatal for this conn, so clear the
|
||||
// cipher state whenever an error happens.
|
||||
c.tx.cipher = nil
|
||||
}
|
||||
if c.tx.err == nil {
|
||||
// Only set c.tx.err if not nil so that we can return one
|
||||
// error on the first failure, and a different one for
|
||||
// subsequent calls. See the error handling around Write
|
||||
// below for why.
|
||||
c.tx.err = err
|
||||
}
|
||||
}()
|
||||
|
||||
if c.tx.cipher == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
var sent int
|
||||
for len(bs) > 0 {
|
||||
toSend := bs
|
||||
if len(toSend) > maxPlaintextSize {
|
||||
toSend = bs[:maxPlaintextSize]
|
||||
}
|
||||
bs = bs[len(toSend):]
|
||||
|
||||
ciphertext, err := c.encryptLocked(toSend)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if n, err := c.conn.Write(ciphertext); err != nil {
|
||||
sent += n
|
||||
// Return the raw error on the Write that actually
|
||||
// failed. For future writes, return that error wrapped in
|
||||
// a desync error.
|
||||
c.tx.err = errPartialWrite{err}
|
||||
return sent, err
|
||||
}
|
||||
sent += len(toSend)
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
// Close implements io.Closer.
|
||||
func (c *Conn) Close() error {
|
||||
closeErr := c.conn.Close() // unblocks any waiting reads or writes
|
||||
c.rx.Lock()
|
||||
c.rx.cipher = nil
|
||||
c.rx.Unlock()
|
||||
c.tx.Lock()
|
||||
c.tx.cipher = nil
|
||||
c.tx.Unlock()
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||||
func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||
|
||||
// errCipherExhausted is the error returned when we run out of nonces
|
||||
// on a cipher.
|
||||
type errCipherExhausted struct{}
|
||||
|
||||
func (errCipherExhausted) Error() string {
|
||||
return "cipher exhausted, no more nonces available for current key"
|
||||
}
|
||||
func (errCipherExhausted) Timeout() bool { return false }
|
||||
func (errCipherExhausted) Temporary() bool { return false }
|
||||
|
||||
// errPartialWrite is the error returned when the cipher state has
|
||||
// become unusable due to a past partial write.
|
||||
type errPartialWrite struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e errPartialWrite) Error() string {
|
||||
return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
|
||||
}
|
||||
func (e errPartialWrite) Unwrap() error { return e.err }
|
||||
func (e errPartialWrite) Temporary() bool { return false }
|
||||
func (e errPartialWrite) Timeout() bool { return false }
|
||||
|
||||
// errReadTooBig is the error returned when the peer sent an
|
||||
// unacceptably large Noise frame.
|
||||
type errReadTooBig struct {
|
||||
requested int
|
||||
}
|
||||
|
||||
func (e errReadTooBig) Error() string {
|
||||
return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
|
||||
}
|
||||
func (e errReadTooBig) Temporary() bool {
|
||||
// permanent error because this error only occurs when our peer
|
||||
// sends us a frame so large we're unwilling to ever decode it.
|
||||
return false
|
||||
}
|
||||
func (e errReadTooBig) Timeout() bool { return false }
|
@ -0,0 +1,339 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
chp "golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/nettest"
|
||||
tsnettest "tailscale.com/net/nettest"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestMessageSize(t *testing.T) {
|
||||
// This test is a regression guard against someone looking at
|
||||
// maxCiphertextSize, going "huh, we could be more efficient if it
|
||||
// were larger, and accidentally violating the Noise spec. Do not
|
||||
// change this max value, it's a deliberate limitation of the
|
||||
// cryptographic protocol we use (see Section 3 "Message Format"
|
||||
// of the Noise spec).
|
||||
const max = 65535
|
||||
if maxCiphertextSize > max {
|
||||
t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnBasic(t *testing.T) {
|
||||
client, server := pair(t)
|
||||
|
||||
sb := sinkReads(server)
|
||||
|
||||
want := "test"
|
||||
if _, err := io.WriteString(client, want); err != nil {
|
||||
t.Fatalf("client write failed: %v", err)
|
||||
}
|
||||
client.Close()
|
||||
|
||||
if got := sb.String(4); got != want {
|
||||
t.Fatalf("wrong content received: got %q, want %q", got, want)
|
||||
}
|
||||
if err := sb.Error(); err != io.EOF {
|
||||
t.Fatal("client close wasn't seen by server")
|
||||
}
|
||||
if sb.Total() != 4 {
|
||||
t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total())
|
||||
}
|
||||
}
|
||||
|
||||
// bufferedWriteConn wraps a net.Conn and gives control over how
|
||||
// Writes get batched out.
|
||||
type bufferedWriteConn struct {
|
||||
net.Conn
|
||||
w *bufio.Writer
|
||||
manualFlush bool
|
||||
}
|
||||
|
||||
func (c *bufferedWriteConn) Write(bs []byte) (int, error) {
|
||||
n, err := c.w.Write(bs)
|
||||
if err == nil && !c.manualFlush {
|
||||
err = c.w.Flush()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// TestFastPath exercises the Read codepath that can receive multiple
|
||||
// Noise frames at once and decode each in turn without making another
|
||||
// syscall.
|
||||
func TestFastPath(t *testing.T) {
|
||||
s1, s2 := tsnettest.NewConn("noise", 128000)
|
||||
b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false}
|
||||
client, server := pairWithConns(t, b, s2)
|
||||
|
||||
b.manualFlush = true
|
||||
|
||||
sb := sinkReads(server)
|
||||
|
||||
const packets = 10
|
||||
s := "test"
|
||||
for i := 0; i < packets; i++ {
|
||||
// Many separate writes, to force separate Noise frames that
|
||||
// all get buffered up and then all sent as a single slice to
|
||||
// the server.
|
||||
if _, err := io.WriteString(client, s); err != nil {
|
||||
t.Fatalf("client write1 failed: %v", err)
|
||||
}
|
||||
}
|
||||
if err := b.w.Flush(); err != nil {
|
||||
t.Fatalf("client flush failed: %v", err)
|
||||
}
|
||||
client.Close()
|
||||
|
||||
want := strings.Repeat(s, packets)
|
||||
if got := sb.String(len(want)); got != want {
|
||||
t.Fatalf("wrong content received: got %q, want %q", got, want)
|
||||
}
|
||||
if err := sb.Error(); err != io.EOF {
|
||||
t.Fatalf("client close wasn't seen by server")
|
||||
}
|
||||
}
|
||||
|
||||
// Writes things larger than a single Noise frame, to check the
|
||||
// chunking on the encoder and decoder.
|
||||
func TestBigData(t *testing.T) {
|
||||
client, server := pair(t)
|
||||
|
||||
serverReads := sinkReads(server)
|
||||
clientReads := sinkReads(client)
|
||||
|
||||
const sz = 15 * 1024 // 15KiB
|
||||
clientStr := strings.Repeat("abcde", sz/5)
|
||||
serverStr := strings.Repeat("fghij", sz/5*2)
|
||||
|
||||
if _, err := io.WriteString(client, clientStr); err != nil {
|
||||
t.Fatalf("writing client>server: %v", err)
|
||||
}
|
||||
if _, err := io.WriteString(server, serverStr); err != nil {
|
||||
t.Fatalf("writing server>client: %v", err)
|
||||
}
|
||||
|
||||
if serverGot := serverReads.String(sz); serverGot != clientStr {
|
||||
t.Error("server didn't receive what client sent")
|
||||
}
|
||||
if clientGot := clientReads.String(2 * sz); clientGot != serverStr {
|
||||
t.Error("client didn't receive what server sent")
|
||||
}
|
||||
|
||||
getNonce := func(n [chp.NonceSize]byte) uint64 {
|
||||
if binary.BigEndian.Uint32(n[:4]) != 0 {
|
||||
panic("unexpected nonce")
|
||||
}
|
||||
return binary.BigEndian.Uint64(n[4:])
|
||||
}
|
||||
|
||||
// Reach into the Conns and verify the cipher nonces advanced as
|
||||
// expected.
|
||||
if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) {
|
||||
t.Error("desynchronized client tx nonce")
|
||||
}
|
||||
if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) {
|
||||
t.Error("desynchronized server tx nonce")
|
||||
}
|
||||
if n := getNonce(client.tx.nonce); n != 4 {
|
||||
t.Errorf("wrong client tx nonce, got %d want 4", n)
|
||||
}
|
||||
if n := getNonce(server.tx.nonce); n != 8 {
|
||||
t.Errorf("wrong client tx nonce, got %d want 8", n)
|
||||
}
|
||||
}
|
||||
|
||||
// readerConn wraps a net.Conn and routes its Reads through a separate
|
||||
// io.Reader.
|
||||
type readerConn struct {
|
||||
net.Conn
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) }
|
||||
|
||||
// Check that the receiver can handle not being able to read an entire
|
||||
// frame in a single syscall.
|
||||
func TestDataTrickle(t *testing.T) {
|
||||
s1, s2 := tsnettest.NewConn("noise", 128000)
|
||||
client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)})
|
||||
serverReads := sinkReads(server)
|
||||
|
||||
const sz = 10000
|
||||
clientStr := strings.Repeat("abcde", sz/5)
|
||||
if _, err := io.WriteString(client, clientStr); err != nil {
|
||||
t.Fatalf("writing client>server: %v", err)
|
||||
}
|
||||
|
||||
serverGot := serverReads.String(sz)
|
||||
if serverGot != clientStr {
|
||||
t.Error("server didn't receive what client sent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnStd(t *testing.T) {
|
||||
// You can run this test manually, and noise.Conn should pass all
|
||||
// of them except for TestConn/PastTimeout,
|
||||
// TestConn/FutureTimeout, TestConn/ConcurrentMethods, because
|
||||
// those tests assume that write errors are recoverable, and
|
||||
// they're not on our Conn due to cipher security.
|
||||
t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977")
|
||||
nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
|
||||
s1, s2 := tsnettest.NewConn("noise", 4096)
|
||||
controlKey := key.NewPrivate()
|
||||
machineKey := key.NewPrivate()
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
var err error
|
||||
c2, err = Server(context.Background(), s2, controlKey)
|
||||
serverErr <- err
|
||||
}()
|
||||
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
|
||||
if err != nil {
|
||||
s1.Close()
|
||||
s2.Close()
|
||||
return nil, nil, nil, fmt.Errorf("connecting client: %w", err)
|
||||
}
|
||||
if err := <-serverErr; err != nil {
|
||||
c1.Close()
|
||||
s1.Close()
|
||||
s2.Close()
|
||||
return nil, nil, nil, fmt.Errorf("connecting server: %w", err)
|
||||
}
|
||||
return c1, c2, func() {
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// mkConns creates synthetic Noise Conns wrapping the given net.Conns.
|
||||
// This function is for testing just the Conn transport logic without
|
||||
// having to muck about with Noise handshakes.
|
||||
func mkConns(s1, s2 net.Conn) (*Conn, *Conn) {
|
||||
var k1, k2 [chp.KeySize]byte
|
||||
if _, err := rand.Read(k1[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if _, err := rand.Read(k2[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret1 := &Conn{
|
||||
conn: s1,
|
||||
tx: txState{cipher: newCHP(k1)},
|
||||
rx: rxState{cipher: newCHP(k2)},
|
||||
}
|
||||
ret2 := &Conn{
|
||||
conn: s2,
|
||||
tx: txState{cipher: newCHP(k2)},
|
||||
rx: rxState{cipher: newCHP(k1)},
|
||||
}
|
||||
|
||||
return ret1, ret2
|
||||
}
|
||||
|
||||
type readSink struct {
|
||||
r io.Reader
|
||||
|
||||
cond *sync.Cond
|
||||
sync.Mutex
|
||||
bs bytes.Buffer
|
||||
err error
|
||||
}
|
||||
|
||||
func sinkReads(r io.Reader) *readSink {
|
||||
ret := &readSink{
|
||||
r: r,
|
||||
}
|
||||
ret.cond = sync.NewCond(&ret.Mutex)
|
||||
go func() {
|
||||
var buf [4096]byte
|
||||
for {
|
||||
n, err := r.Read(buf[:])
|
||||
ret.Lock()
|
||||
ret.bs.Write(buf[:n])
|
||||
if err != nil {
|
||||
ret.err = err
|
||||
}
|
||||
ret.cond.Broadcast()
|
||||
ret.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (s *readSink) String(total int) string {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
for s.bs.Len() < total && s.err == nil {
|
||||
s.cond.Wait()
|
||||
}
|
||||
if s.err != nil {
|
||||
total = s.bs.Len()
|
||||
}
|
||||
return string(s.bs.Bytes()[:total])
|
||||
}
|
||||
|
||||
func (s *readSink) Error() error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
for s.err == nil {
|
||||
s.cond.Wait()
|
||||
}
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *readSink) Total() int {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.bs.Len()
|
||||
}
|
||||
|
||||
func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) {
|
||||
var (
|
||||
controlKey = key.NewPrivate()
|
||||
machineKey = key.NewPrivate()
|
||||
server *Conn
|
||||
serverErr = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
var err error
|
||||
server, err = Server(context.Background(), serverConn, controlKey)
|
||||
serverErr <- err
|
||||
}()
|
||||
|
||||
client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public())
|
||||
if err != nil {
|
||||
t.Fatalf("client connection failed: %v", err)
|
||||
}
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server connection failed: %v", err)
|
||||
}
|
||||
return client, server
|
||||
}
|
||||
|
||||
func pair(t *testing.T) (*Conn, *Conn) {
|
||||
s1, s2 := tsnettest.NewConn("noise", 128000)
|
||||
return pairWithConns(t, s1, s2)
|
||||
}
|
@ -0,0 +1,361 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
chp "golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const (
|
||||
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
|
||||
invalidNonce = ^uint64(0)
|
||||
)
|
||||
|
||||
// Client initiates a Noise client handshake, returning the resulting
|
||||
// Noise connection.
|
||||
//
|
||||
// The context deadline, if any, covers the entire handshaking
|
||||
// process.
|
||||
func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlKey key.Public) (*Conn, error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
conn.SetDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
|
||||
var s symmetricState
|
||||
s.Initialize()
|
||||
|
||||
// <- s
|
||||
// ...
|
||||
s.MixHash(controlKey[:])
|
||||
|
||||
var init initiationMessage
|
||||
// -> e, es, s, ss
|
||||
machineEphemeral := key.NewPrivate()
|
||||
machineEphemeralPub := machineEphemeral.Public()
|
||||
copy(init.MachineEphemeralPub(), machineEphemeralPub[:])
|
||||
s.MixHash(machineEphemeralPub[:])
|
||||
if err := s.MixDH(machineEphemeral, controlKey); err != nil {
|
||||
return nil, fmt.Errorf("computing es: %w", err)
|
||||
}
|
||||
machineKeyPub := machineKey.Public()
|
||||
copy(init.MachinePub(), s.EncryptAndHash(machineKeyPub[:]))
|
||||
if err := s.MixDH(machineKey, controlKey); err != nil {
|
||||
return nil, fmt.Errorf("computing ss: %w", err)
|
||||
}
|
||||
copy(init.Tag(), s.EncryptAndHash(nil)) // empty message payload
|
||||
|
||||
if _, err := conn.Write(init[:]); err != nil {
|
||||
return nil, fmt.Errorf("writing initiation: %w", err)
|
||||
}
|
||||
|
||||
// <- e, ee, se
|
||||
var resp responseMessage
|
||||
if _, err := io.ReadFull(conn, resp[:]); err != nil {
|
||||
return nil, fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
var controlEphemeralPub key.Public
|
||||
copy(controlEphemeralPub[:], resp.ControlEphemeralPub())
|
||||
s.MixHash(controlEphemeralPub[:])
|
||||
if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
|
||||
return nil, fmt.Errorf("computing ee: %w", err)
|
||||
}
|
||||
if err := s.MixDH(machineKey, controlEphemeralPub); err != nil {
|
||||
return nil, fmt.Errorf("computing se: %w", err)
|
||||
}
|
||||
if _, err := s.DecryptAndHash(resp.Tag()); err != nil {
|
||||
return nil, fmt.Errorf("decrypting payload: %w", err)
|
||||
}
|
||||
|
||||
c1, c2, err := s.Split()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finalizing handshake: %w", err)
|
||||
}
|
||||
|
||||
return &Conn{
|
||||
conn: conn,
|
||||
peer: controlKey,
|
||||
handshakeHash: s.h,
|
||||
tx: txState{
|
||||
cipher: c1,
|
||||
},
|
||||
rx: rxState{
|
||||
cipher: c2,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Server initiates a Noise server handshake, returning the resulting
|
||||
// Noise connection.
|
||||
//
|
||||
// The context deadline, if any, covers the entire handshaking
|
||||
// process.
|
||||
func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
conn.SetDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
|
||||
var s symmetricState
|
||||
s.Initialize()
|
||||
|
||||
// <- s
|
||||
// ...
|
||||
controlKeyPub := controlKey.Public()
|
||||
s.MixHash(controlKeyPub[:])
|
||||
|
||||
// -> e, es, s, ss
|
||||
var init initiationMessage
|
||||
if _, err := io.ReadFull(conn, init[:]); err != nil {
|
||||
return nil, fmt.Errorf("reading initiation: %w", err)
|
||||
}
|
||||
|
||||
var machineEphemeralPub key.Public
|
||||
copy(machineEphemeralPub[:], init.MachineEphemeralPub())
|
||||
s.MixHash(machineEphemeralPub[:])
|
||||
if err := s.MixDH(controlKey, machineEphemeralPub); err != nil {
|
||||
return nil, fmt.Errorf("computing es: %w", err)
|
||||
}
|
||||
var machineKey key.Public
|
||||
rs, err := s.DecryptAndHash(init.MachinePub())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting machine key: %w", err)
|
||||
}
|
||||
copy(machineKey[:], rs)
|
||||
if err := s.MixDH(controlKey, machineKey); err != nil {
|
||||
return nil, fmt.Errorf("computing ss: %w", err)
|
||||
}
|
||||
if _, err := s.DecryptAndHash(init.Tag()); err != nil {
|
||||
return nil, fmt.Errorf("decrypting initiation tag: %w", err)
|
||||
}
|
||||
|
||||
// <- e, ee, se
|
||||
var resp responseMessage
|
||||
controlEphemeral := key.NewPrivate()
|
||||
controlEphemeralPub := controlEphemeral.Public()
|
||||
copy(resp.ControlEphemeralPub(), controlEphemeralPub[:])
|
||||
s.MixHash(controlEphemeralPub[:])
|
||||
if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
|
||||
return nil, fmt.Errorf("computing ee: %w", err)
|
||||
}
|
||||
if err := s.MixDH(controlEphemeral, machineKey); err != nil {
|
||||
return nil, fmt.Errorf("computing se: %w", err)
|
||||
}
|
||||
copy(resp.Tag(), s.EncryptAndHash(nil)) // empty message payload
|
||||
|
||||
c1, c2, err := s.Split()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finalizing handshake: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.Write(resp[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Conn{
|
||||
conn: conn,
|
||||
peer: machineKey,
|
||||
handshakeHash: s.h,
|
||||
tx: txState{
|
||||
cipher: c2,
|
||||
},
|
||||
rx: rxState{
|
||||
cipher: c1,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// initiationMessage is the Noise protocol message sent from a client
|
||||
// machine to a control server.
|
||||
type initiationMessage [96]byte
|
||||
|
||||
func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] }
|
||||
func (m *initiationMessage) MachinePub() []byte { return m[32:80] }
|
||||
func (m *initiationMessage) Tag() []byte { return m[80:] }
|
||||
|
||||
// responseMessage is the Noise protocol message sent from a control
|
||||
// server to a client machine.
|
||||
type responseMessage [48]byte
|
||||
|
||||
func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] }
|
||||
func (m *responseMessage) Tag() []byte { return m[32:] }
|
||||
|
||||
// symmetricState is the SymmetricState object from the Noise protocol
|
||||
// spec. It contains all the symmetric cipher state of an in-flight
|
||||
// handshake. Field names match the variable names in the spec.
|
||||
type symmetricState struct {
|
||||
h [blake2s.Size]byte
|
||||
ck [blake2s.Size]byte
|
||||
|
||||
k [chp.KeySize]byte
|
||||
n uint64
|
||||
|
||||
mixer hash.Hash // for updating h
|
||||
}
|
||||
|
||||
// Initialize sets s to the initial handshake state, prior to
|
||||
// processing any Noise messages.
|
||||
func (s *symmetricState) Initialize() {
|
||||
if s.mixer != nil {
|
||||
panic("symmetricState cannot be reused")
|
||||
}
|
||||
s.h = blake2s.Sum256([]byte(protocolName))
|
||||
s.ck = s.h
|
||||
s.k = [chp.KeySize]byte{}
|
||||
s.n = invalidNonce
|
||||
s.mixer = newBLAKE2s()
|
||||
// Mix in an empty prologue.
|
||||
s.MixHash(nil)
|
||||
}
|
||||
|
||||
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
|
||||
// concatenation.
|
||||
func (s *symmetricState) MixHash(data []byte) {
|
||||
s.mixer.Reset()
|
||||
s.mixer.Write(s.h[:])
|
||||
s.mixer.Write(data)
|
||||
s.mixer.Sum(s.h[:0]) // TODO: check this actually updates s.h correctly...
|
||||
}
|
||||
|
||||
// MixDH updates s.ck and s.k with the result of X25519(priv, pub).
|
||||
//
|
||||
// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
|
||||
// it as a single function allows for strongly-typed arguments that
|
||||
// reduce the risk of error in the caller (e.g. invoking X25519 with
|
||||
// two private keys, or two public keys), and thus producing the wrong
|
||||
// calculation.
|
||||
func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error {
|
||||
// TODO(danderson): check that this operation is correct. The docs
|
||||
// for X25519 say that the 2nd arg must be either Basepoint or the
|
||||
// output of another X25519 call.
|
||||
//
|
||||
// I think this is correct, because pub is the result of a
|
||||
// ScalarBaseMult on the private key, and our private key
|
||||
// generation code clamps keys to avoid low order points. I
|
||||
// believe that makes pub equivalent to the output of
|
||||
// X25519(privateKey, Basepoint), and so the contract is
|
||||
// respected.
|
||||
keyData, err := curve25519.X25519(priv[:], pub[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("computing X25519: %w", err)
|
||||
}
|
||||
|
||||
r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
|
||||
if _, err := io.ReadFull(r, s.ck[:]); err != nil {
|
||||
return fmt.Errorf("extracting ck: %w", err)
|
||||
}
|
||||
if _, err := io.ReadFull(r, s.k[:]); err != nil {
|
||||
return fmt.Errorf("extracting k: %w", err)
|
||||
}
|
||||
s.n = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncryptAndHash encrypts the given plaintext using the current s.k,
|
||||
// mixes the ciphertext into s.h, and returns the ciphertext.
|
||||
func (s *symmetricState) EncryptAndHash(plaintext []byte) []byte {
|
||||
if s.n == invalidNonce {
|
||||
// Noise in general permits writing "ciphertext" without a
|
||||
// key, but in IK it cannot happen.
|
||||
panic("attempted encryption with uninitialized key")
|
||||
}
|
||||
aead := newCHP(s.k)
|
||||
var nonce [chp.NonceSize]byte
|
||||
binary.BigEndian.PutUint64(nonce[4:], s.n)
|
||||
s.n++
|
||||
ret := aead.Seal(nil, nonce[:], plaintext, s.h[:])
|
||||
s.MixHash(ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
// DecryptAndHash decrypts the given ciphertext using the current
|
||||
// s.k. If decryption is successful, it mixes the ciphertext into s.h
|
||||
// and returns the plaintext.
|
||||
func (s *symmetricState) DecryptAndHash(ciphertext []byte) ([]byte, error) {
|
||||
if s.n == invalidNonce {
|
||||
// Noise in general permits "ciphertext" without a key, but in
|
||||
// IK it cannot happen.
|
||||
panic("attempted encryption with uninitialized key")
|
||||
}
|
||||
aead := newCHP(s.k)
|
||||
var nonce [chp.NonceSize]byte
|
||||
binary.BigEndian.PutUint64(nonce[4:], s.n)
|
||||
s.n++
|
||||
ret, err := aead.Open(nil, nonce[:], ciphertext, s.h[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.MixHash(ciphertext)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Split returns two ChaCha20Poly1305 ciphers with keys derives from
|
||||
// the current handshake state. Methods on s must not be used again
|
||||
// after calling Split().
|
||||
func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
|
||||
var k1, k2 [chp.KeySize]byte
|
||||
r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
|
||||
if _, err := io.ReadFull(r, k1[:]); err != nil {
|
||||
return nil, nil, fmt.Errorf("extracting k1: %w", err)
|
||||
}
|
||||
if _, err := io.ReadFull(r, k2[:]); err != nil {
|
||||
return nil, nil, fmt.Errorf("extracting k2: %w", err)
|
||||
}
|
||||
c1, err = chp.New(k1[:])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
|
||||
}
|
||||
c2, err = chp.New(k2[:])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
|
||||
}
|
||||
return c1, c2, nil
|
||||
}
|
||||
|
||||
// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
|
||||
// error.
|
||||
func newBLAKE2s() hash.Hash {
|
||||
h, err := blake2s.New256(nil)
|
||||
if err != nil {
|
||||
// Should never happen, errors only happen when using BLAKE2s
|
||||
// in MAC mode with a key.
|
||||
panic(fmt.Sprintf("blake2s construction: %v", err))
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
|
||||
// panics on error.
|
||||
func newCHP(key [chp.KeySize]byte) cipher.AEAD {
|
||||
aead, err := chp.New(key[:])
|
||||
if err != nil {
|
||||
// Can only happen if we passed a key of the wrong length. The
|
||||
// function signature prevents that.
|
||||
panic(fmt.Sprintf("chacha20poly1305 construction: %v", err))
|
||||
}
|
||||
return aead
|
||||
}
|
@ -0,0 +1,290 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tsnettest "tailscale.com/net/nettest"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
var (
|
||||
clientConn, serverConn = tsnettest.NewConn("noise", 128000)
|
||||
serverKey = key.NewPrivate()
|
||||
clientKey = key.NewPrivate()
|
||||
server *Conn
|
||||
serverErr = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
var err error
|
||||
server, err = Server(context.Background(), serverConn, serverKey)
|
||||
serverErr <- err
|
||||
}()
|
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||
if err != nil {
|
||||
t.Fatalf("client connection failed: %v", err)
|
||||
}
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server connection failed: %v", err)
|
||||
}
|
||||
|
||||
if client.HandshakeHash() != server.HandshakeHash() {
|
||||
t.Fatal("client and server disagree on handshake hash")
|
||||
}
|
||||
|
||||
if client.Peer() != serverKey.Public() {
|
||||
t.Fatal("client peer key isn't serverKey")
|
||||
}
|
||||
if server.Peer() != clientKey.Public() {
|
||||
t.Fatal("client peer key isn't serverKey")
|
||||
}
|
||||
}
|
||||
|
||||
// Check that handshaking repeatedly with the same long-term keys
|
||||
// result in different handshake hashes and wire traffic.
|
||||
func TestNoReuse(t *testing.T) {
|
||||
var (
|
||||
hashes = map[[32]byte]bool{}
|
||||
clientHandshakes = map[[96]byte]bool{}
|
||||
serverHandshakes = map[[48]byte]bool{}
|
||||
packets = map[[32]byte]bool{}
|
||||
)
|
||||
for i := 0; i < 10; i++ {
|
||||
var (
|
||||
clientRaw, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||
clientBuf, serverBuf bytes.Buffer
|
||||
clientConn = &readerConn{clientRaw, io.TeeReader(clientRaw, &clientBuf)}
|
||||
serverConn = &readerConn{serverRaw, io.TeeReader(serverRaw, &serverBuf)}
|
||||
serverKey = key.NewPrivate()
|
||||
clientKey = key.NewPrivate()
|
||||
server *Conn
|
||||
serverErr = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
var err error
|
||||
server, err = Server(context.Background(), serverConn, serverKey)
|
||||
serverErr <- err
|
||||
}()
|
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||
if err != nil {
|
||||
t.Fatalf("client connection failed: %v", err)
|
||||
}
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server connection failed: %v", err)
|
||||
}
|
||||
|
||||
var clientHS [96]byte
|
||||
copy(clientHS[:], serverBuf.Bytes())
|
||||
if clientHandshakes[clientHS] {
|
||||
t.Fatal("client handshake seen twice")
|
||||
}
|
||||
clientHandshakes[clientHS] = true
|
||||
|
||||
var serverHS [48]byte
|
||||
copy(serverHS[:], clientBuf.Bytes())
|
||||
if serverHandshakes[serverHS] {
|
||||
t.Fatal("server handshake seen twice")
|
||||
}
|
||||
serverHandshakes[serverHS] = true
|
||||
|
||||
clientBuf.Reset()
|
||||
serverBuf.Reset()
|
||||
cb := sinkReads(client)
|
||||
sb := sinkReads(server)
|
||||
|
||||
if hashes[client.HandshakeHash()] {
|
||||
t.Fatalf("handshake hash %v seen twice", client.HandshakeHash())
|
||||
}
|
||||
hashes[client.HandshakeHash()] = true
|
||||
|
||||
// Sending 14 bytes turns into 32 bytes on the wire (+16 for
|
||||
// the poly1305 tag, +2 length header)
|
||||
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
|
||||
t.Fatalf("client>server write failed: %v", err)
|
||||
}
|
||||
if _, err := io.WriteString(server, strings.Repeat("b", 14)); err != nil {
|
||||
t.Fatalf("server>client write failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the bytes to be read, so we know they've traveled end to end
|
||||
cb.String(14)
|
||||
sb.String(14)
|
||||
|
||||
var clientWire, serverWire [32]byte
|
||||
copy(clientWire[:], clientBuf.Bytes())
|
||||
copy(serverWire[:], serverBuf.Bytes())
|
||||
|
||||
if packets[clientWire] {
|
||||
t.Fatalf("client wire traffic seen twice")
|
||||
}
|
||||
packets[clientWire] = true
|
||||
if packets[serverWire] {
|
||||
t.Fatalf("server wire traffic seen twice")
|
||||
}
|
||||
packets[serverWire] = true
|
||||
}
|
||||
}
|
||||
|
||||
// tamperReader wraps a reader and mutates the Nth byte.
|
||||
type tamperReader struct {
|
||||
r io.Reader
|
||||
n int
|
||||
total int
|
||||
}
|
||||
|
||||
func (r *tamperReader) Read(bs []byte) (int, error) {
|
||||
n, err := r.r.Read(bs)
|
||||
if off := r.n - r.total; off >= 0 && off < n {
|
||||
bs[off] += 1
|
||||
}
|
||||
r.total += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func TestTampering(t *testing.T) {
|
||||
// Tamper with every byte of the client initiation message.
|
||||
for i := 0; i < 96; i++ {
|
||||
var (
|
||||
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
|
||||
serverKey = key.NewPrivate()
|
||||
clientKey = key.NewPrivate()
|
||||
serverErr = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
_, err := Server(context.Background(), serverConn, serverKey)
|
||||
// If the server failed, we have to close the Conn to
|
||||
// unblock the client.
|
||||
if err != nil {
|
||||
serverConn.Close()
|
||||
}
|
||||
serverErr <- err
|
||||
}()
|
||||
|
||||
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||
if err == nil {
|
||||
t.Fatal("client connection succeeded despite tampering")
|
||||
}
|
||||
if err := <-serverErr; err == nil {
|
||||
t.Fatalf("server connection succeeded despite tampering")
|
||||
}
|
||||
}
|
||||
|
||||
// Tamper with every byte of the server response message.
|
||||
for i := 0; i < 48; i++ {
|
||||
var (
|
||||
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
||||
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
|
||||
serverKey = key.NewPrivate()
|
||||
clientKey = key.NewPrivate()
|
||||
serverErr = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
_, err := Server(context.Background(), serverConn, serverKey)
|
||||
serverErr <- err
|
||||
}()
|
||||
|
||||
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||
if err == nil {
|
||||
t.Fatal("client connection succeeded despite tampering")
|
||||
}
|
||||
// The server shouldn't fail, because the tampering took place
|
||||
// in its response.
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server connection failed despite no tampering: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tamper with every byte of the first server>client transport message.
|
||||
for i := 0; i < 32; i++ {
|
||||
var (
|
||||
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
||||
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}}
|
||||
serverKey = key.NewPrivate()
|
||||
clientKey = key.NewPrivate()
|
||||
serverErr = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
server, err := Server(context.Background(), serverConn, serverKey)
|
||||
serverErr <- err
|
||||
_, err = io.WriteString(server, strings.Repeat("a", 14))
|
||||
serverErr <- err
|
||||
}()
|
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||
if err != nil {
|
||||
t.Fatalf("client handshake failed: %v", err)
|
||||
}
|
||||
// The server shouldn't fail, because the tampering took place
|
||||
// in its response.
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server handshake failed: %v", err)
|
||||
}
|
||||
|
||||
// The client needs a timeout if the tampering is hitting the length header.
|
||||
if i == 0 || i == 1 {
|
||||
client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||
}
|
||||
|
||||
var bs [100]byte
|
||||
n, err := client.Read(bs[:])
|
||||
if err == nil {
|
||||
t.Fatal("read succeeded despite tampering")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Fatal("conn yielded some bytes despite tampering")
|
||||
}
|
||||
}
|
||||
|
||||
// Tamper with every byte of the first client>server transport message.
|
||||
for i := 0; i < 32; i++ {
|
||||
var (
|
||||
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}}
|
||||
serverKey = key.NewPrivate()
|
||||
clientKey = key.NewPrivate()
|
||||
serverErr = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
server, err := Server(context.Background(), serverConn, serverKey)
|
||||
serverErr <- err
|
||||
var bs [100]byte
|
||||
// The server needs a timeout if the tampering is hitting the length header.
|
||||
if i == 0 || i == 1 {
|
||||
server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||
}
|
||||
n, err := server.Read(bs[:])
|
||||
if n != 0 {
|
||||
panic("server got bytes despite tampering")
|
||||
} else {
|
||||
serverErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||
if err != nil {
|
||||
t.Fatalf("client handshake failed: %v", err)
|
||||
}
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server handshake failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
|
||||
t.Fatalf("client>server write failed: %v", err)
|
||||
}
|
||||
if err := <-serverErr; err == nil {
|
||||
t.Fatal("server successfully received bytes despite tampering")
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,238 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package noise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
tsnettest "tailscale.com/net/nettest"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// Can a reference Noise IK client talk to our server?
|
||||
func TestInteropClient(t *testing.T) {
|
||||
var (
|
||||
s1, s2 = tsnettest.NewConn("noise", 128000)
|
||||
controlKey = key.NewPrivate()
|
||||
machineKey = key.NewPrivate()
|
||||
serverErr = make(chan error, 2)
|
||||
serverBytes = make(chan []byte, 1)
|
||||
c2s = "client>server"
|
||||
s2c = "server>client"
|
||||
)
|
||||
|
||||
go func() {
|
||||
server, err := Server(context.Background(), s2, controlKey)
|
||||
serverErr <- err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var buf [1024]byte
|
||||
_, err = io.ReadFull(server, buf[:len(c2s)])
|
||||
serverBytes <- buf[:len(c2s)]
|
||||
if err != nil {
|
||||
serverErr <- err
|
||||
return
|
||||
}
|
||||
_, err = server.Write([]byte(s2c))
|
||||
serverErr <- err
|
||||
}()
|
||||
|
||||
gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
|
||||
if err != nil {
|
||||
t.Fatalf("failed client interop: %v", err)
|
||||
}
|
||||
if string(gotS2C) != s2c {
|
||||
t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
|
||||
}
|
||||
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server handshake failed: %v", err)
|
||||
}
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Fatalf("server read/write failed: %v", err)
|
||||
}
|
||||
if got := string(<-serverBytes); got != c2s {
|
||||
t.Fatalf("server received %q, want %q", got, c2s)
|
||||
}
|
||||
}
|
||||
|
||||
// Can our client talk to a reference Noise IK server?
|
||||
func TestInteropServer(t *testing.T) {
|
||||
var (
|
||||
s1, s2 = tsnettest.NewConn("noise", 128000)
|
||||
controlKey = key.NewPrivate()
|
||||
machineKey = key.NewPrivate()
|
||||
clientErr = make(chan error, 2)
|
||||
clientBytes = make(chan []byte, 1)
|
||||
c2s = "client>server"
|
||||
s2c = "server>client"
|
||||
)
|
||||
|
||||
go func() {
|
||||
client, err := Client(context.Background(), s1, machineKey, controlKey.Public())
|
||||
clientErr <- err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = client.Write([]byte(c2s))
|
||||
if err != nil {
|
||||
clientErr <- err
|
||||
return
|
||||
}
|
||||
var buf [1024]byte
|
||||
_, err = io.ReadFull(client, buf[:len(s2c)])
|
||||
clientBytes <- buf[:len(s2c)]
|
||||
clientErr <- err
|
||||
}()
|
||||
|
||||
gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
|
||||
if err != nil {
|
||||
t.Fatalf("failed server interop: %v", err)
|
||||
}
|
||||
if string(gotC2S) != c2s {
|
||||
t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
|
||||
}
|
||||
|
||||
if err := <-clientErr; err != nil {
|
||||
t.Fatalf("client handshake failed: %v", err)
|
||||
}
|
||||
if err := <-clientErr; err != nil {
|
||||
t.Fatalf("client read/write failed: %v", err)
|
||||
}
|
||||
if got := string(<-clientBytes); got != s2c {
|
||||
t.Fatalf("client received %q, want %q", got, s2c)
|
||||
}
|
||||
}
|
||||
|
||||
// noiseExplorerClient uses the Noise Explorer implementation of Noise
|
||||
// IK to handshake as a Noise client on conn, transmit payload, and
|
||||
// read+return a payload from the peer.
|
||||
func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Private, payload []byte) ([]byte, error) {
|
||||
mk := keypair{
|
||||
private_key: machineKey,
|
||||
public_key: machineKey.Public(),
|
||||
}
|
||||
session := InitSession(true, nil, mk, controlKey)
|
||||
|
||||
_, msg1 := SendMessage(&session, nil)
|
||||
if _, err := conn.Write(msg1.ne[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := conn.Write(msg1.ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := conn.Write(msg1.ciphertext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buf [1024]byte
|
||||
if _, err := io.ReadFull(conn, buf[:48]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg2 := messagebuffer{
|
||||
ciphertext: buf[32:48],
|
||||
}
|
||||
copy(msg2.ne[:], buf[:32])
|
||||
_, p, valid := RecvMessage(&session, &msg2)
|
||||
if !valid {
|
||||
return nil, errors.New("handshake failed")
|
||||
}
|
||||
if len(p) != 0 {
|
||||
return nil, errors.New("non-empty payload")
|
||||
}
|
||||
|
||||
_, msg3 := SendMessage(&session, payload)
|
||||
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext)))
|
||||
if _, err := conn.Write(buf[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := conn.Write(msg3.ciphertext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plen := int(binary.BigEndian.Uint16(buf[:2]))
|
||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg4 := messagebuffer{
|
||||
ciphertext: buf[:plen],
|
||||
}
|
||||
_, p, valid = RecvMessage(&session, &msg4)
|
||||
if !valid {
|
||||
return nil, errors.New("transport message decryption failed")
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey key.Public, payload []byte) ([]byte, error) {
|
||||
mk := keypair{
|
||||
private_key: controlKey,
|
||||
public_key: controlKey.Public(),
|
||||
}
|
||||
session := InitSession(false, nil, mk, [32]byte{})
|
||||
|
||||
var buf [1024]byte
|
||||
if _, err := io.ReadFull(conn, buf[:96]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg1 := messagebuffer{
|
||||
ns: buf[32:80],
|
||||
ciphertext: buf[80:96],
|
||||
}
|
||||
copy(msg1.ne[:], buf[:32])
|
||||
_, p, valid := RecvMessage(&session, &msg1)
|
||||
if !valid {
|
||||
return nil, errors.New("handshake failed")
|
||||
}
|
||||
if len(p) != 0 {
|
||||
return nil, errors.New("non-empty payload")
|
||||
}
|
||||
|
||||
_, msg2 := SendMessage(&session, nil)
|
||||
if _, err := conn.Write(msg2.ne[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plen := int(binary.BigEndian.Uint16(buf[:2]))
|
||||
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg3 := messagebuffer{
|
||||
ciphertext: buf[:plen],
|
||||
}
|
||||
_, p, valid = RecvMessage(&session, &msg3)
|
||||
if !valid {
|
||||
return nil, errors.New("transport message decryption failed")
|
||||
}
|
||||
|
||||
_, msg4 := SendMessage(&session, payload)
|
||||
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext)))
|
||||
if _, err := conn.Write(buf[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := conn.Write(msg4.ciphertext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
@ -0,0 +1,475 @@
|
||||
// This file contains the implementation of Noise IK from
|
||||
// https://noiseexplorer.com/ . Unlike the rest of this repository,
|
||||
// this file is licensed under the terms of the GNU GPL v3. See
|
||||
// https://source.symbolic.software/noiseexplorer/noiseexplorer for
|
||||
// more information.
|
||||
//
|
||||
// This file is used here to verify that Tailscale's implementation of
|
||||
// Noise IK is interoperable with another implementation.
|
||||
//lint:file-ignore SA4006 not our code.
|
||||
|
||||
/*
|
||||
IK:
|
||||
<- s
|
||||
...
|
||||
-> e, es, s, ss
|
||||
<- e, ee, se
|
||||
->
|
||||
<-
|
||||
*/
|
||||
|
||||
// Implementation Version: 1.0.2
|
||||
|
||||
/* ---------------------------------------------------------------- *
|
||||
* PARAMETERS *
|
||||
* ---------------------------------------------------------------- */
|
||||
|
||||
package noise
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"hash"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
/* ---------------------------------------------------------------- *
|
||||
* TYPES *
|
||||
* ---------------------------------------------------------------- */
|
||||
|
||||
type keypair struct {
|
||||
public_key [32]byte
|
||||
private_key [32]byte
|
||||
}
|
||||
|
||||
type messagebuffer struct {
|
||||
ne [32]byte
|
||||
ns []byte
|
||||
ciphertext []byte
|
||||
}
|
||||
|
||||
type cipherstate struct {
|
||||
k [32]byte
|
||||
n uint32
|
||||
}
|
||||
|
||||
type symmetricstate struct {
|
||||
cs cipherstate
|
||||
ck [32]byte
|
||||
h [32]byte
|
||||
}
|
||||
|
||||
type handshakestate struct {
|
||||
ss symmetricstate
|
||||
s keypair
|
||||
e keypair
|
||||
rs [32]byte
|
||||
re [32]byte
|
||||
psk [32]byte
|
||||
}
|
||||
|
||||
type noisesession struct {
|
||||
hs handshakestate
|
||||
h [32]byte
|
||||
cs1 cipherstate
|
||||
cs2 cipherstate
|
||||
mc uint64
|
||||
i bool
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------- *
|
||||
* CONSTANTS *
|
||||
* ---------------------------------------------------------------- */
|
||||
|
||||
var emptyKey = [32]byte{
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
var minNonce = uint32(0)
|
||||
|
||||
/* ---------------------------------------------------------------- *
|
||||
* UTILITY FUNCTIONS *
|
||||
* ---------------------------------------------------------------- */
|
||||
|
||||
func getPublicKey(kp *keypair) [32]byte {
|
||||
return kp.public_key
|
||||
}
|
||||
|
||||
func isEmptyKey(k [32]byte) bool {
|
||||
return subtle.ConstantTimeCompare(k[:], emptyKey[:]) == 1
|
||||
}
|
||||
|
||||
func validatePublicKey(k []byte) bool {
|
||||
forbiddenCurveValues := [12][]byte{
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{224, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 0},
|
||||
{95, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 87},
|
||||
{236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
|
||||
{237, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
|
||||
{238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
|
||||
{205, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 128},
|
||||
{76, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 215},
|
||||
{217, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
|
||||
{218, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
|
||||
{219, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 25},
|
||||
}
|
||||
|
||||
for _, testValue := range forbiddenCurveValues {
|
||||
if subtle.ConstantTimeCompare(k[:], testValue[:]) == 1 {
|
||||
panic("Invalid public key")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------- *
|
||||
* PRIMITIVES *
|
||||
* ---------------------------------------------------------------- */
|
||||
|
||||
func incrementNonce(n uint32) uint32 {
|
||||
return n + 1
|
||||
}
|
||||
|
||||
func dh(private_key [32]byte, public_key [32]byte) [32]byte {
|
||||
var ss [32]byte
|
||||
curve25519.ScalarMult(&ss, &private_key, &public_key)
|
||||
return ss
|
||||
}
|
||||
|
||||
func generateKeypair() keypair {
|
||||
var public_key [32]byte
|
||||
var private_key [32]byte
|
||||
_, _ = rand.Read(private_key[:])
|
||||
curve25519.ScalarBaseMult(&public_key, &private_key)
|
||||
if validatePublicKey(public_key[:]) {
|
||||
return keypair{public_key, private_key}
|
||||
}
|
||||
return generateKeypair()
|
||||
}
|
||||
|
||||
func generatePublicKey(private_key [32]byte) [32]byte {
|
||||
var public_key [32]byte
|
||||
curve25519.ScalarBaseMult(&public_key, &private_key)
|
||||
return public_key
|
||||
}
|
||||
|
||||
func encrypt(k [32]byte, n uint32, ad []byte, plaintext []byte) []byte {
|
||||
var nonce [12]byte
|
||||
var ciphertext []byte
|
||||
enc, _ := chacha20poly1305.New(k[:])
|
||||
binary.LittleEndian.PutUint32(nonce[4:], n)
|
||||
ciphertext = enc.Seal(nil, nonce[:], plaintext, ad)
|
||||
return ciphertext
|
||||
}
|
||||
|
||||
func decrypt(k [32]byte, n uint32, ad []byte, ciphertext []byte) (bool, []byte, []byte) {
|
||||
var nonce [12]byte
|
||||
var plaintext []byte
|
||||
enc, err := chacha20poly1305.New(k[:])
|
||||
binary.LittleEndian.PutUint32(nonce[4:], n)
|
||||
plaintext, err = enc.Open(nil, nonce[:], ciphertext, ad)
|
||||
return (err == nil), ad, plaintext
|
||||
}
|
||||
|
||||
func getHash(a []byte, b []byte) [32]byte {
|
||||
return blake2s.Sum256(append(a, b...))
|
||||
}
|
||||
|
||||
func hashProtocolName(protocolName []byte) [32]byte {
|
||||
var h [32]byte
|
||||
if len(protocolName) <= 32 {
|
||||
copy(h[:], protocolName)
|
||||
} else {
|
||||
h = getHash(protocolName, []byte{})
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func blake2HkdfInterface() hash.Hash {
|
||||
h, _ := blake2s.New256([]byte{})
|
||||
return h
|
||||
}
|
||||
|
||||
func getHkdf(ck [32]byte, ikm []byte) ([32]byte, [32]byte, [32]byte) {
|
||||
var k1 [32]byte
|
||||
var k2 [32]byte
|
||||
var k3 [32]byte
|
||||
output := hkdf.New(blake2HkdfInterface, ikm[:], ck[:], []byte{})
|
||||
io.ReadFull(output, k1[:])
|
||||
io.ReadFull(output, k2[:])
|
||||
io.ReadFull(output, k3[:])
|
||||
return k1, k2, k3
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------- *
|
||||
* STATE MANAGEMENT *
|
||||
* ---------------------------------------------------------------- */
|
||||
|
||||
/* CipherState */
|
||||
func initializeKey(k [32]byte) cipherstate {
|
||||
return cipherstate{k, minNonce}
|
||||
}
|
||||
|
||||
func hasKey(cs *cipherstate) bool {
|
||||
return !isEmptyKey(cs.k)
|
||||
}
|
||||
|
||||
func setNonce(cs *cipherstate, newNonce uint32) *cipherstate {
|
||||
cs.n = newNonce
|
||||
return cs
|
||||
}
|
||||
|
||||
func encryptWithAd(cs *cipherstate, ad []byte, plaintext []byte) (*cipherstate, []byte) {
|
||||
e := encrypt(cs.k, cs.n, ad, plaintext)
|
||||
cs = setNonce(cs, incrementNonce(cs.n))
|
||||
return cs, e
|
||||
}
|
||||
|
||||
func decryptWithAd(cs *cipherstate, ad []byte, ciphertext []byte) (*cipherstate, []byte, bool) {
|
||||
valid, ad, plaintext := decrypt(cs.k, cs.n, ad, ciphertext)
|
||||
cs = setNonce(cs, incrementNonce(cs.n))
|
||||
return cs, plaintext, valid
|
||||
}
|
||||
|
||||
func reKey(cs *cipherstate) *cipherstate {
|
||||
e := encrypt(cs.k, math.MaxUint32, []byte{}, emptyKey[:])
|
||||
copy(cs.k[:], e)
|
||||
return cs
|
||||
}
|
||||
|
||||
/* SymmetricState */
|
||||
|
||||
func initializeSymmetric(protocolName []byte) symmetricstate {
|
||||
h := hashProtocolName(protocolName)
|
||||
ck := h
|
||||
cs := initializeKey(emptyKey)
|
||||
return symmetricstate{cs, ck, h}
|
||||
}
|
||||
|
||||
func mixKey(ss *symmetricstate, ikm [32]byte) *symmetricstate {
|
||||
ck, tempK, _ := getHkdf(ss.ck, ikm[:])
|
||||
ss.cs = initializeKey(tempK)
|
||||
ss.ck = ck
|
||||
return ss
|
||||
}
|
||||
|
||||
func mixHash(ss *symmetricstate, data []byte) *symmetricstate {
|
||||
ss.h = getHash(ss.h[:], data)
|
||||
return ss
|
||||
}
|
||||
|
||||
func mixKeyAndHash(ss *symmetricstate, ikm [32]byte) *symmetricstate {
|
||||
var tempH [32]byte
|
||||
var tempK [32]byte
|
||||
ss.ck, tempH, tempK = getHkdf(ss.ck, ikm[:])
|
||||
ss = mixHash(ss, tempH[:])
|
||||
ss.cs = initializeKey(tempK)
|
||||
return ss
|
||||
}
|
||||
|
||||
func getHandshakeHash(ss *symmetricstate) [32]byte {
|
||||
return ss.h
|
||||
}
|
||||
|
||||
func encryptAndHash(ss *symmetricstate, plaintext []byte) (*symmetricstate, []byte) {
|
||||
var ciphertext []byte
|
||||
if hasKey(&ss.cs) {
|
||||
_, ciphertext = encryptWithAd(&ss.cs, ss.h[:], plaintext)
|
||||
} else {
|
||||
ciphertext = plaintext
|
||||
}
|
||||
ss = mixHash(ss, ciphertext)
|
||||
return ss, ciphertext
|
||||
}
|
||||
|
||||
func decryptAndHash(ss *symmetricstate, ciphertext []byte) (*symmetricstate, []byte, bool) {
|
||||
var plaintext []byte
|
||||
var valid bool
|
||||
if hasKey(&ss.cs) {
|
||||
_, plaintext, valid = decryptWithAd(&ss.cs, ss.h[:], ciphertext)
|
||||
} else {
|
||||
plaintext, valid = ciphertext, true
|
||||
}
|
||||
ss = mixHash(ss, ciphertext)
|
||||
return ss, plaintext, valid
|
||||
}
|
||||
|
||||
func split(ss *symmetricstate) (cipherstate, cipherstate) {
|
||||
tempK1, tempK2, _ := getHkdf(ss.ck, []byte{})
|
||||
cs1 := initializeKey(tempK1)
|
||||
cs2 := initializeKey(tempK2)
|
||||
return cs1, cs2
|
||||
}
|
||||
|
||||
/* HandshakeState */
|
||||
|
||||
func initializeInitiator(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
|
||||
var ss symmetricstate
|
||||
var e keypair
|
||||
var re [32]byte
|
||||
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
|
||||
ss = initializeSymmetric(name)
|
||||
mixHash(&ss, prologue)
|
||||
mixHash(&ss, rs[:])
|
||||
return handshakestate{ss, s, e, rs, re, psk}
|
||||
}
|
||||
|
||||
func initializeResponder(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
|
||||
var ss symmetricstate
|
||||
var e keypair
|
||||
var re [32]byte
|
||||
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
|
||||
ss = initializeSymmetric(name)
|
||||
mixHash(&ss, prologue)
|
||||
mixHash(&ss, s.public_key[:])
|
||||
return handshakestate{ss, s, e, rs, re, psk}
|
||||
}
|
||||
|
||||
func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, messagebuffer) {
|
||||
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
|
||||
hs.e = generateKeypair()
|
||||
ne = hs.e.public_key
|
||||
mixHash(&hs.ss, ne[:])
|
||||
/* No PSK, so skipping mixKey */
|
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
|
||||
spk := make([]byte, len(hs.s.public_key))
|
||||
copy(spk[:], hs.s.public_key[:])
|
||||
_, ns = encryptAndHash(&hs.ss, spk)
|
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
|
||||
_, ciphertext = encryptAndHash(&hs.ss, payload)
|
||||
messageBuffer := messagebuffer{ne, ns, ciphertext}
|
||||
return hs, messageBuffer
|
||||
}
|
||||
|
||||
func writeMessageB(hs *handshakestate, payload []byte) ([32]byte, messagebuffer, cipherstate, cipherstate) {
|
||||
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
|
||||
hs.e = generateKeypair()
|
||||
ne = hs.e.public_key
|
||||
mixHash(&hs.ss, ne[:])
|
||||
/* No PSK, so skipping mixKey */
|
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
|
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
|
||||
_, ciphertext = encryptAndHash(&hs.ss, payload)
|
||||
messageBuffer := messagebuffer{ne, ns, ciphertext}
|
||||
cs1, cs2 := split(&hs.ss)
|
||||
return hs.ss.h, messageBuffer, cs1, cs2
|
||||
}
|
||||
|
||||
func writeMessageRegular(cs *cipherstate, payload []byte) (*cipherstate, messagebuffer) {
|
||||
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
|
||||
cs, ciphertext = encryptWithAd(cs, []byte{}, payload)
|
||||
messageBuffer := messagebuffer{ne, ns, ciphertext}
|
||||
return cs, messageBuffer
|
||||
}
|
||||
|
||||
func readMessageA(hs *handshakestate, message *messagebuffer) (*handshakestate, []byte, bool) {
|
||||
valid1 := true
|
||||
if validatePublicKey(message.ne[:]) {
|
||||
hs.re = message.ne
|
||||
}
|
||||
mixHash(&hs.ss, hs.re[:])
|
||||
/* No PSK, so skipping mixKey */
|
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
|
||||
_, ns, valid1 := decryptAndHash(&hs.ss, message.ns)
|
||||
if valid1 && len(ns) == 32 && validatePublicKey(message.ns[:]) {
|
||||
copy(hs.rs[:], ns)
|
||||
}
|
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
|
||||
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
|
||||
return hs, plaintext, (valid1 && valid2)
|
||||
}
|
||||
|
||||
func readMessageB(hs *handshakestate, message *messagebuffer) ([32]byte, []byte, bool, cipherstate, cipherstate) {
|
||||
valid1 := true
|
||||
if validatePublicKey(message.ne[:]) {
|
||||
hs.re = message.ne
|
||||
}
|
||||
mixHash(&hs.ss, hs.re[:])
|
||||
/* No PSK, so skipping mixKey */
|
||||
mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
|
||||
mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
|
||||
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
|
||||
cs1, cs2 := split(&hs.ss)
|
||||
return hs.ss.h, plaintext, (valid1 && valid2), cs1, cs2
|
||||
}
|
||||
|
||||
func readMessageRegular(cs *cipherstate, message *messagebuffer) (*cipherstate, []byte, bool) {
|
||||
/* No encrypted keys */
|
||||
_, plaintext, valid2 := decryptWithAd(cs, []byte{}, message.ciphertext)
|
||||
return cs, plaintext, valid2
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------- *
|
||||
* PROCESSES *
|
||||
* ---------------------------------------------------------------- */
|
||||
|
||||
func InitSession(initiator bool, prologue []byte, s keypair, rs [32]byte) noisesession {
|
||||
var session noisesession
|
||||
psk := emptyKey
|
||||
if initiator {
|
||||
session.hs = initializeInitiator(prologue, s, rs, psk)
|
||||
} else {
|
||||
session.hs = initializeResponder(prologue, s, rs, psk)
|
||||
}
|
||||
session.i = initiator
|
||||
session.mc = 0
|
||||
return session
|
||||
}
|
||||
|
||||
func SendMessage(session *noisesession, message []byte) (*noisesession, messagebuffer) {
|
||||
var messageBuffer messagebuffer
|
||||
if session.mc == 0 {
|
||||
_, messageBuffer = writeMessageA(&session.hs, message)
|
||||
}
|
||||
if session.mc == 1 {
|
||||
session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message)
|
||||
session.hs = handshakestate{}
|
||||
}
|
||||
if session.mc > 1 {
|
||||
if session.i {
|
||||
_, messageBuffer = writeMessageRegular(&session.cs1, message)
|
||||
} else {
|
||||
_, messageBuffer = writeMessageRegular(&session.cs2, message)
|
||||
}
|
||||
}
|
||||
session.mc = session.mc + 1
|
||||
return session, messageBuffer
|
||||
}
|
||||
|
||||
func RecvMessage(session *noisesession, message *messagebuffer) (*noisesession, []byte, bool) {
|
||||
var plaintext []byte
|
||||
var valid bool
|
||||
if session.mc == 0 {
|
||||
_, plaintext, valid = readMessageA(&session.hs, message)
|
||||
}
|
||||
if session.mc == 1 {
|
||||
session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message)
|
||||
session.hs = handshakestate{}
|
||||
}
|
||||
if session.mc > 1 {
|
||||
if session.i {
|
||||
_, plaintext, valid = readMessageRegular(&session.cs2, message)
|
||||
} else {
|
||||
_, plaintext, valid = readMessageRegular(&session.cs1, message)
|
||||
}
|
||||
}
|
||||
session.mc = session.mc + 1
|
||||
return session, plaintext, valid
|
||||
}
|
||||
|
||||
func main() {}
|
Loading…
Reference in New Issue