From da7544bcc589d9ba1f39b0fcb6391aa3af16e575 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Sun, 27 Jun 2021 01:21:53 -0700 Subject: [PATCH] control/noise: implement the base transport for the 2021 control protocol. Signed-off-by: David Anderson --- control/noise/conn.go | 330 +++++++++++++++++++ control/noise/conn_test.go | 339 ++++++++++++++++++++ control/noise/handshake.go | 361 +++++++++++++++++++++ control/noise/handshake_test.go | 290 +++++++++++++++++ control/noise/interop_test.go | 238 ++++++++++++++ control/noise/noiseexplorer_test.go | 475 ++++++++++++++++++++++++++++ scripts/check_license_headers.sh | 6 + 7 files changed, 2039 insertions(+) create mode 100644 control/noise/conn.go create mode 100644 control/noise/conn_test.go create mode 100644 control/noise/handshake.go create mode 100644 control/noise/handshake_test.go create mode 100644 control/noise/interop_test.go create mode 100644 control/noise/noiseexplorer_test.go diff --git a/control/noise/conn.go b/control/noise/conn.go new file mode 100644 index 000000000..efeb538d6 --- /dev/null +++ b/control/noise/conn.go @@ -0,0 +1,330 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package noise implements the base transport of the Tailscale 2021 +// control protocol. +// +// The base transport implements Noise IK, instantiated with +// Curve25519, ChaCha20Poly1305 and BLAKE2s. +package noise + +import ( + "crypto/cipher" + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/poly1305" + "tailscale.com/types/key" +) + +const ( + maxPlaintextSize = 4096 + maxCiphertextSize = maxPlaintextSize + poly1305.TagSize + maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header +) + +// A Conn is a secured Noise connection. It implements the net.Conn +// interface, with the unusual trait that any write error (including a +// SetWriteDeadline induced i/o timeout) cause all future writes to +// fail. +type Conn struct { + conn net.Conn + peer key.Public + handshakeHash [blake2s.Size]byte + rx rxState + tx txState +} + +// rxState is all the Conn state that Read uses. +type rxState struct { + sync.Mutex + cipher cipher.AEAD + nonce [chp.NonceSize]byte + buf [maxPacketSize]byte + n int // number of valid bytes in buf + next int // offset of next undecrypted packet + plaintext []byte // slice into buf of decrypted bytes +} + +// txState is all the Conn state that Write uses. +type txState struct { + sync.Mutex + cipher cipher.AEAD + nonce [chp.NonceSize]byte + buf [maxPacketSize]byte + err error // records the first partial write error for all future calls +} + +// HandshakeHash returns the Noise handshake hash for the connection, +// which can be used to bind other messages to this connection +// (i.e. to ensure that the message wasn't replayed from a different +// connection). +func (c *Conn) HandshakeHash() [blake2s.Size]byte { + return c.handshakeHash +} + +// Peer returns the peer's long-term public key. +func (c *Conn) Peer() key.Public { + return c.peer +} + +// validNonce reports whether nonce is in the valid range for use: 0 +// through 2^64-2. +func validNonce(nonce []byte) bool { + return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce +} + +// readNLocked reads into c.rxBuf until rxBuf contains at least total +// bytes. Returns a slice of the available bytes in rxBuf, or an +// error if fewer than total bytes are available. +func (c *Conn) readNLocked(total int) ([]byte, error) { + if total > maxPacketSize { + return nil, errReadTooBig{total} + } + for { + if total <= c.rx.n { + return c.rx.buf[:c.rx.n], nil + } + + n, err := c.conn.Read(c.rx.buf[c.rx.n:]) + c.rx.n += n + if err != nil { + return nil, err + } + } +} + +// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext +// to the decrypted bytes. Returns an error if the cipher is exhausted +// (i.e. can no longer be used safely) or decryption fails. +func (c *Conn) decryptLocked(ciphertext []byte) (err error) { + if !validNonce(c.rx.nonce[:]) { + return errCipherExhausted{} + } + + c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) + + // Safe to increment the nonce here, because we checked for nonce + // wraparound above. + binary.BigEndian.PutUint64(c.rx.nonce[4:], 1+binary.BigEndian.Uint64(c.rx.nonce[4:])) + + if err != nil { + // Once a decryption has failed, our Conn is no longer + // synchronized with our peer. Nuke the cipher state to be + // safe, so that no further decryptions are attempted. + c.rx.cipher = nil + } + return err +} + +// encryptLocked encrypts plaintext into c.tx.buf (including the +// 2-byte length header) and returns a slice of the ciphertext, or an +// error if the cipher is exhausted (i.e. can no longer be used safely). +func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { + if !validNonce(c.tx.nonce[:]) { + // Received 2^64-1 messages on this cipher state. Connection + // is no longer usable. + return nil, errCipherExhausted{} + } + + binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize)) + ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil) + + // Safe to increment the nonce here, because we checked for nonce + // wraparound above. + binary.BigEndian.PutUint64(c.tx.nonce[4:], 1+binary.BigEndian.Uint64(c.tx.nonce[4:])) + + return ret, nil +} + +// wholeCiphertextLocked returns a slice of one whole Noise frame from +// c.rx.buf, if one whole ciphertext is available, and advances the +// read state to the next Noise frame in the buffer. Returns nil +// without advancing read state if there's not one whole ciphertext in +// c.rx.buf. +func (c *Conn) wholeCiphertextLocked() []byte { + available := c.rx.n - c.rx.next + if available < 2 { + return nil + } + bs := c.rx.buf[c.rx.next:c.rx.n] + totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2 + if len(bs) < totalSize { + return nil + } + c.rx.next += totalSize + return bs[:totalSize] +} + +// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed, +// and sets c.rx.plaintext to point to the decrypted +// bytes. c.rx.plaintext is only valid if err == nil. +func (c *Conn) decryptOneLocked() error { + c.rx.plaintext = nil + + // Fast path: do we have one whole ciphertext frame buffered + // already? + if bs := c.wholeCiphertextLocked(); bs != nil { + return c.decryptLocked(bs[2:]) + } + + if c.rx.next != 0 { + // To simplify the read logic, move the remainder of the + // buffered bytes back to the head of the buffer, so we can + // grow it without worrying about wraparound. + copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) + c.rx.n -= c.rx.next + c.rx.next = 0 + } + + bs, err := c.readNLocked(2) + if err != nil { + return err + } + totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2 + bs, err = c.readNLocked(totalLen) + if err != nil { + return err + } + + c.rx.next = totalLen + bs = bs[2:totalLen] + + return c.decryptLocked(bs) +} + +// Read implements io.Reader. +func (c *Conn) Read(bs []byte) (int, error) { + c.rx.Lock() + defer c.rx.Unlock() + + if c.rx.cipher == nil { + return 0, net.ErrClosed + } + // Loop to handle receiving a zero-byte Noise message. Just skip + // over it and keep decrypting until we find some bytes. + for len(c.rx.plaintext) == 0 { + if err := c.decryptOneLocked(); err != nil { + return 0, err + } + } + n := copy(bs, c.rx.plaintext) + c.rx.plaintext = c.rx.plaintext[n:] + return n, nil +} + +// Write implements io.Writer. +func (c *Conn) Write(bs []byte) (n int, err error) { + c.tx.Lock() + defer c.tx.Unlock() + + if c.tx.err != nil { + return 0, c.tx.err + } + defer func() { + if err != nil { + // All write errors are fatal for this conn, so clear the + // cipher state whenever an error happens. + c.tx.cipher = nil + } + if c.tx.err == nil { + // Only set c.tx.err if not nil so that we can return one + // error on the first failure, and a different one for + // subsequent calls. See the error handling around Write + // below for why. + c.tx.err = err + } + }() + + if c.tx.cipher == nil { + return 0, net.ErrClosed + } + + var sent int + for len(bs) > 0 { + toSend := bs + if len(toSend) > maxPlaintextSize { + toSend = bs[:maxPlaintextSize] + } + bs = bs[len(toSend):] + + ciphertext, err := c.encryptLocked(toSend) + if err != nil { + return 0, err + } + + if n, err := c.conn.Write(ciphertext); err != nil { + sent += n + // Return the raw error on the Write that actually + // failed. For future writes, return that error wrapped in + // a desync error. + c.tx.err = errPartialWrite{err} + return sent, err + } + sent += len(toSend) + } + return sent, nil +} + +// Close implements io.Closer. +func (c *Conn) Close() error { + closeErr := c.conn.Close() // unblocks any waiting reads or writes + c.rx.Lock() + c.rx.cipher = nil + c.rx.Unlock() + c.tx.Lock() + c.tx.cipher = nil + c.tx.Unlock() + return closeErr +} + +func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } +func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +// errCipherExhausted is the error returned when we run out of nonces +// on a cipher. +type errCipherExhausted struct{} + +func (errCipherExhausted) Error() string { + return "cipher exhausted, no more nonces available for current key" +} +func (errCipherExhausted) Timeout() bool { return false } +func (errCipherExhausted) Temporary() bool { return false } + +// errPartialWrite is the error returned when the cipher state has +// become unusable due to a past partial write. +type errPartialWrite struct { + err error +} + +func (e errPartialWrite) Error() string { + return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) +} +func (e errPartialWrite) Unwrap() error { return e.err } +func (e errPartialWrite) Temporary() bool { return false } +func (e errPartialWrite) Timeout() bool { return false } + +// errReadTooBig is the error returned when the peer sent an +// unacceptably large Noise frame. +type errReadTooBig struct { + requested int +} + +func (e errReadTooBig) Error() string { + return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) +} +func (e errReadTooBig) Temporary() bool { + // permanent error because this error only occurs when our peer + // sends us a frame so large we're unwilling to ever decode it. + return false +} +func (e errReadTooBig) Timeout() bool { return false } diff --git a/control/noise/conn_test.go b/control/noise/conn_test.go new file mode 100644 index 000000000..170a44b34 --- /dev/null +++ b/control/noise/conn_test.go @@ -0,0 +1,339 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package noise + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "net" + "strings" + "sync" + "testing" + "testing/iotest" + + chp "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/nettest" + tsnettest "tailscale.com/net/nettest" + "tailscale.com/types/key" +) + +func TestMessageSize(t *testing.T) { + // This test is a regression guard against someone looking at + // maxCiphertextSize, going "huh, we could be more efficient if it + // were larger, and accidentally violating the Noise spec. Do not + // change this max value, it's a deliberate limitation of the + // cryptographic protocol we use (see Section 3 "Message Format" + // of the Noise spec). + const max = 65535 + if maxCiphertextSize > max { + t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max) + } +} + +func TestConnBasic(t *testing.T) { + client, server := pair(t) + + sb := sinkReads(server) + + want := "test" + if _, err := io.WriteString(client, want); err != nil { + t.Fatalf("client write failed: %v", err) + } + client.Close() + + if got := sb.String(4); got != want { + t.Fatalf("wrong content received: got %q, want %q", got, want) + } + if err := sb.Error(); err != io.EOF { + t.Fatal("client close wasn't seen by server") + } + if sb.Total() != 4 { + t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total()) + } +} + +// bufferedWriteConn wraps a net.Conn and gives control over how +// Writes get batched out. +type bufferedWriteConn struct { + net.Conn + w *bufio.Writer + manualFlush bool +} + +func (c *bufferedWriteConn) Write(bs []byte) (int, error) { + n, err := c.w.Write(bs) + if err == nil && !c.manualFlush { + err = c.w.Flush() + } + return n, err +} + +// TestFastPath exercises the Read codepath that can receive multiple +// Noise frames at once and decode each in turn without making another +// syscall. +func TestFastPath(t *testing.T) { + s1, s2 := tsnettest.NewConn("noise", 128000) + b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false} + client, server := pairWithConns(t, b, s2) + + b.manualFlush = true + + sb := sinkReads(server) + + const packets = 10 + s := "test" + for i := 0; i < packets; i++ { + // Many separate writes, to force separate Noise frames that + // all get buffered up and then all sent as a single slice to + // the server. + if _, err := io.WriteString(client, s); err != nil { + t.Fatalf("client write1 failed: %v", err) + } + } + if err := b.w.Flush(); err != nil { + t.Fatalf("client flush failed: %v", err) + } + client.Close() + + want := strings.Repeat(s, packets) + if got := sb.String(len(want)); got != want { + t.Fatalf("wrong content received: got %q, want %q", got, want) + } + if err := sb.Error(); err != io.EOF { + t.Fatalf("client close wasn't seen by server") + } +} + +// Writes things larger than a single Noise frame, to check the +// chunking on the encoder and decoder. +func TestBigData(t *testing.T) { + client, server := pair(t) + + serverReads := sinkReads(server) + clientReads := sinkReads(client) + + const sz = 15 * 1024 // 15KiB + clientStr := strings.Repeat("abcde", sz/5) + serverStr := strings.Repeat("fghij", sz/5*2) + + if _, err := io.WriteString(client, clientStr); err != nil { + t.Fatalf("writing client>server: %v", err) + } + if _, err := io.WriteString(server, serverStr); err != nil { + t.Fatalf("writing server>client: %v", err) + } + + if serverGot := serverReads.String(sz); serverGot != clientStr { + t.Error("server didn't receive what client sent") + } + if clientGot := clientReads.String(2 * sz); clientGot != serverStr { + t.Error("client didn't receive what server sent") + } + + getNonce := func(n [chp.NonceSize]byte) uint64 { + if binary.BigEndian.Uint32(n[:4]) != 0 { + panic("unexpected nonce") + } + return binary.BigEndian.Uint64(n[4:]) + } + + // Reach into the Conns and verify the cipher nonces advanced as + // expected. + if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) { + t.Error("desynchronized client tx nonce") + } + if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) { + t.Error("desynchronized server tx nonce") + } + if n := getNonce(client.tx.nonce); n != 4 { + t.Errorf("wrong client tx nonce, got %d want 4", n) + } + if n := getNonce(server.tx.nonce); n != 8 { + t.Errorf("wrong client tx nonce, got %d want 8", n) + } +} + +// readerConn wraps a net.Conn and routes its Reads through a separate +// io.Reader. +type readerConn struct { + net.Conn + r io.Reader +} + +func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) } + +// Check that the receiver can handle not being able to read an entire +// frame in a single syscall. +func TestDataTrickle(t *testing.T) { + s1, s2 := tsnettest.NewConn("noise", 128000) + client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)}) + serverReads := sinkReads(server) + + const sz = 10000 + clientStr := strings.Repeat("abcde", sz/5) + if _, err := io.WriteString(client, clientStr); err != nil { + t.Fatalf("writing client>server: %v", err) + } + + serverGot := serverReads.String(sz) + if serverGot != clientStr { + t.Error("server didn't receive what client sent") + } +} + +func TestConnStd(t *testing.T) { + // You can run this test manually, and noise.Conn should pass all + // of them except for TestConn/PastTimeout, + // TestConn/FutureTimeout, TestConn/ConcurrentMethods, because + // those tests assume that write errors are recoverable, and + // they're not on our Conn due to cipher security. + t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977") + nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { + s1, s2 := tsnettest.NewConn("noise", 4096) + controlKey := key.NewPrivate() + machineKey := key.NewPrivate() + serverErr := make(chan error, 1) + go func() { + var err error + c2, err = Server(context.Background(), s2, controlKey) + serverErr <- err + }() + c1, err = Client(context.Background(), s1, machineKey, controlKey.Public()) + if err != nil { + s1.Close() + s2.Close() + return nil, nil, nil, fmt.Errorf("connecting client: %w", err) + } + if err := <-serverErr; err != nil { + c1.Close() + s1.Close() + s2.Close() + return nil, nil, nil, fmt.Errorf("connecting server: %w", err) + } + return c1, c2, func() { + c1.Close() + c2.Close() + }, nil + }) +} + +// mkConns creates synthetic Noise Conns wrapping the given net.Conns. +// This function is for testing just the Conn transport logic without +// having to muck about with Noise handshakes. +func mkConns(s1, s2 net.Conn) (*Conn, *Conn) { + var k1, k2 [chp.KeySize]byte + if _, err := rand.Read(k1[:]); err != nil { + panic(err) + } + if _, err := rand.Read(k2[:]); err != nil { + panic(err) + } + + ret1 := &Conn{ + conn: s1, + tx: txState{cipher: newCHP(k1)}, + rx: rxState{cipher: newCHP(k2)}, + } + ret2 := &Conn{ + conn: s2, + tx: txState{cipher: newCHP(k2)}, + rx: rxState{cipher: newCHP(k1)}, + } + + return ret1, ret2 +} + +type readSink struct { + r io.Reader + + cond *sync.Cond + sync.Mutex + bs bytes.Buffer + err error +} + +func sinkReads(r io.Reader) *readSink { + ret := &readSink{ + r: r, + } + ret.cond = sync.NewCond(&ret.Mutex) + go func() { + var buf [4096]byte + for { + n, err := r.Read(buf[:]) + ret.Lock() + ret.bs.Write(buf[:n]) + if err != nil { + ret.err = err + } + ret.cond.Broadcast() + ret.Unlock() + if err != nil { + return + } + } + }() + return ret +} + +func (s *readSink) String(total int) string { + s.Lock() + defer s.Unlock() + for s.bs.Len() < total && s.err == nil { + s.cond.Wait() + } + if s.err != nil { + total = s.bs.Len() + } + return string(s.bs.Bytes()[:total]) +} + +func (s *readSink) Error() error { + s.Lock() + defer s.Unlock() + for s.err == nil { + s.cond.Wait() + } + return s.err +} + +func (s *readSink) Total() int { + s.Lock() + defer s.Unlock() + return s.bs.Len() +} + +func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) { + var ( + controlKey = key.NewPrivate() + machineKey = key.NewPrivate() + server *Conn + serverErr = make(chan error, 1) + ) + go func() { + var err error + server, err = Server(context.Background(), serverConn, controlKey) + serverErr <- err + }() + + client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public()) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server connection failed: %v", err) + } + return client, server +} + +func pair(t *testing.T) (*Conn, *Conn) { + s1, s2 := tsnettest.NewConn("noise", 128000) + return pairWithConns(t, s1, s2) +} diff --git a/control/noise/handshake.go b/control/noise/handshake.go new file mode 100644 index 000000000..910a7601c --- /dev/null +++ b/control/noise/handshake.go @@ -0,0 +1,361 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package noise + +import ( + "context" + "crypto/cipher" + "encoding/binary" + "fmt" + "hash" + "io" + "net" + "time" + + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" + "tailscale.com/types/key" +) + +const ( + protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" + invalidNonce = ^uint64(0) +) + +// Client initiates a Noise client handshake, returning the resulting +// Noise connection. +// +// The context deadline, if any, covers the entire handshaking +// process. +func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlKey key.Public) (*Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + var s symmetricState + s.Initialize() + + // <- s + // ... + s.MixHash(controlKey[:]) + + var init initiationMessage + // -> e, es, s, ss + machineEphemeral := key.NewPrivate() + machineEphemeralPub := machineEphemeral.Public() + copy(init.MachineEphemeralPub(), machineEphemeralPub[:]) + s.MixHash(machineEphemeralPub[:]) + if err := s.MixDH(machineEphemeral, controlKey); err != nil { + return nil, fmt.Errorf("computing es: %w", err) + } + machineKeyPub := machineKey.Public() + copy(init.MachinePub(), s.EncryptAndHash(machineKeyPub[:])) + if err := s.MixDH(machineKey, controlKey); err != nil { + return nil, fmt.Errorf("computing ss: %w", err) + } + copy(init.Tag(), s.EncryptAndHash(nil)) // empty message payload + + if _, err := conn.Write(init[:]); err != nil { + return nil, fmt.Errorf("writing initiation: %w", err) + } + + // <- e, ee, se + var resp responseMessage + if _, err := io.ReadFull(conn, resp[:]); err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + var controlEphemeralPub key.Public + copy(controlEphemeralPub[:], resp.ControlEphemeralPub()) + s.MixHash(controlEphemeralPub[:]) + if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + if err := s.MixDH(machineKey, controlEphemeralPub); err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + if _, err := s.DecryptAndHash(resp.Tag()); err != nil { + return nil, fmt.Errorf("decrypting payload: %w", err) + } + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + return &Conn{ + conn: conn, + peer: controlKey, + handshakeHash: s.h, + tx: txState{ + cipher: c1, + }, + rx: rxState{ + cipher: c2, + }, + }, nil +} + +// Server initiates a Noise server handshake, returning the resulting +// Noise connection. +// +// The context deadline, if any, covers the entire handshaking +// process. +func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + var s symmetricState + s.Initialize() + + // <- s + // ... + controlKeyPub := controlKey.Public() + s.MixHash(controlKeyPub[:]) + + // -> e, es, s, ss + var init initiationMessage + if _, err := io.ReadFull(conn, init[:]); err != nil { + return nil, fmt.Errorf("reading initiation: %w", err) + } + + var machineEphemeralPub key.Public + copy(machineEphemeralPub[:], init.MachineEphemeralPub()) + s.MixHash(machineEphemeralPub[:]) + if err := s.MixDH(controlKey, machineEphemeralPub); err != nil { + return nil, fmt.Errorf("computing es: %w", err) + } + var machineKey key.Public + rs, err := s.DecryptAndHash(init.MachinePub()) + if err != nil { + return nil, fmt.Errorf("decrypting machine key: %w", err) + } + copy(machineKey[:], rs) + if err := s.MixDH(controlKey, machineKey); err != nil { + return nil, fmt.Errorf("computing ss: %w", err) + } + if _, err := s.DecryptAndHash(init.Tag()); err != nil { + return nil, fmt.Errorf("decrypting initiation tag: %w", err) + } + + // <- e, ee, se + var resp responseMessage + controlEphemeral := key.NewPrivate() + controlEphemeralPub := controlEphemeral.Public() + copy(resp.ControlEphemeralPub(), controlEphemeralPub[:]) + s.MixHash(controlEphemeralPub[:]) + if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + if err := s.MixDH(controlEphemeral, machineKey); err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + copy(resp.Tag(), s.EncryptAndHash(nil)) // empty message payload + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + if _, err := conn.Write(resp[:]); err != nil { + return nil, err + } + + return &Conn{ + conn: conn, + peer: machineKey, + handshakeHash: s.h, + tx: txState{ + cipher: c2, + }, + rx: rxState{ + cipher: c1, + }, + }, nil +} + +// initiationMessage is the Noise protocol message sent from a client +// machine to a control server. +type initiationMessage [96]byte + +func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] } +func (m *initiationMessage) MachinePub() []byte { return m[32:80] } +func (m *initiationMessage) Tag() []byte { return m[80:] } + +// responseMessage is the Noise protocol message sent from a control +// server to a client machine. +type responseMessage [48]byte + +func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] } +func (m *responseMessage) Tag() []byte { return m[32:] } + +// symmetricState is the SymmetricState object from the Noise protocol +// spec. It contains all the symmetric cipher state of an in-flight +// handshake. Field names match the variable names in the spec. +type symmetricState struct { + h [blake2s.Size]byte + ck [blake2s.Size]byte + + k [chp.KeySize]byte + n uint64 + + mixer hash.Hash // for updating h +} + +// Initialize sets s to the initial handshake state, prior to +// processing any Noise messages. +func (s *symmetricState) Initialize() { + if s.mixer != nil { + panic("symmetricState cannot be reused") + } + s.h = blake2s.Sum256([]byte(protocolName)) + s.ck = s.h + s.k = [chp.KeySize]byte{} + s.n = invalidNonce + s.mixer = newBLAKE2s() + // Mix in an empty prologue. + s.MixHash(nil) +} + +// MixHash updates s.h to be BLAKE2s(s.h || data), where || is +// concatenation. +func (s *symmetricState) MixHash(data []byte) { + s.mixer.Reset() + s.mixer.Write(s.h[:]) + s.mixer.Write(data) + s.mixer.Sum(s.h[:0]) // TODO: check this actually updates s.h correctly... +} + +// MixDH updates s.ck and s.k with the result of X25519(priv, pub). +// +// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing +// it as a single function allows for strongly-typed arguments that +// reduce the risk of error in the caller (e.g. invoking X25519 with +// two private keys, or two public keys), and thus producing the wrong +// calculation. +func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error { + // TODO(danderson): check that this operation is correct. The docs + // for X25519 say that the 2nd arg must be either Basepoint or the + // output of another X25519 call. + // + // I think this is correct, because pub is the result of a + // ScalarBaseMult on the private key, and our private key + // generation code clamps keys to avoid low order points. I + // believe that makes pub equivalent to the output of + // X25519(privateKey, Basepoint), and so the contract is + // respected. + keyData, err := curve25519.X25519(priv[:], pub[:]) + if err != nil { + return fmt.Errorf("computing X25519: %w", err) + } + + r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) + if _, err := io.ReadFull(r, s.ck[:]); err != nil { + return fmt.Errorf("extracting ck: %w", err) + } + if _, err := io.ReadFull(r, s.k[:]); err != nil { + return fmt.Errorf("extracting k: %w", err) + } + s.n = 0 + return nil +} + +// EncryptAndHash encrypts the given plaintext using the current s.k, +// mixes the ciphertext into s.h, and returns the ciphertext. +func (s *symmetricState) EncryptAndHash(plaintext []byte) []byte { + if s.n == invalidNonce { + // Noise in general permits writing "ciphertext" without a + // key, but in IK it cannot happen. + panic("attempted encryption with uninitialized key") + } + aead := newCHP(s.k) + var nonce [chp.NonceSize]byte + binary.BigEndian.PutUint64(nonce[4:], s.n) + s.n++ + ret := aead.Seal(nil, nonce[:], plaintext, s.h[:]) + s.MixHash(ret) + return ret +} + +// DecryptAndHash decrypts the given ciphertext using the current +// s.k. If decryption is successful, it mixes the ciphertext into s.h +// and returns the plaintext. +func (s *symmetricState) DecryptAndHash(ciphertext []byte) ([]byte, error) { + if s.n == invalidNonce { + // Noise in general permits "ciphertext" without a key, but in + // IK it cannot happen. + panic("attempted encryption with uninitialized key") + } + aead := newCHP(s.k) + var nonce [chp.NonceSize]byte + binary.BigEndian.PutUint64(nonce[4:], s.n) + s.n++ + ret, err := aead.Open(nil, nonce[:], ciphertext, s.h[:]) + if err != nil { + return nil, err + } + s.MixHash(ciphertext) + return ret, nil +} + +// Split returns two ChaCha20Poly1305 ciphers with keys derives from +// the current handshake state. Methods on s must not be used again +// after calling Split(). +func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { + var k1, k2 [chp.KeySize]byte + r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) + if _, err := io.ReadFull(r, k1[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k1: %w", err) + } + if _, err := io.ReadFull(r, k2[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k2: %w", err) + } + c1, err = chp.New(k1[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) + } + c2, err = chp.New(k2[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) + } + return c1, c2, nil +} + +// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on +// error. +func newBLAKE2s() hash.Hash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen, errors only happen when using BLAKE2s + // in MAC mode with a key. + panic(fmt.Sprintf("blake2s construction: %v", err)) + } + return h +} + +// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or +// panics on error. +func newCHP(key [chp.KeySize]byte) cipher.AEAD { + aead, err := chp.New(key[:]) + if err != nil { + // Can only happen if we passed a key of the wrong length. The + // function signature prevents that. + panic(fmt.Sprintf("chacha20poly1305 construction: %v", err)) + } + return aead +} diff --git a/control/noise/handshake_test.go b/control/noise/handshake_test.go new file mode 100644 index 000000000..172ee0ff8 --- /dev/null +++ b/control/noise/handshake_test.go @@ -0,0 +1,290 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package noise + +import ( + "bytes" + "context" + "io" + "strings" + "testing" + "time" + + tsnettest "tailscale.com/net/nettest" + "tailscale.com/types/key" +) + +func TestHandshake(t *testing.T) { + var ( + clientConn, serverConn = tsnettest.NewConn("noise", 128000) + serverKey = key.NewPrivate() + clientKey = key.NewPrivate() + server *Conn + serverErr = make(chan error, 1) + ) + go func() { + var err error + server, err = Server(context.Background(), serverConn, serverKey) + serverErr <- err + }() + + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server connection failed: %v", err) + } + + if client.HandshakeHash() != server.HandshakeHash() { + t.Fatal("client and server disagree on handshake hash") + } + + if client.Peer() != serverKey.Public() { + t.Fatal("client peer key isn't serverKey") + } + if server.Peer() != clientKey.Public() { + t.Fatal("client peer key isn't serverKey") + } +} + +// Check that handshaking repeatedly with the same long-term keys +// result in different handshake hashes and wire traffic. +func TestNoReuse(t *testing.T) { + var ( + hashes = map[[32]byte]bool{} + clientHandshakes = map[[96]byte]bool{} + serverHandshakes = map[[48]byte]bool{} + packets = map[[32]byte]bool{} + ) + for i := 0; i < 10; i++ { + var ( + clientRaw, serverRaw = tsnettest.NewConn("noise", 128000) + clientBuf, serverBuf bytes.Buffer + clientConn = &readerConn{clientRaw, io.TeeReader(clientRaw, &clientBuf)} + serverConn = &readerConn{serverRaw, io.TeeReader(serverRaw, &serverBuf)} + serverKey = key.NewPrivate() + clientKey = key.NewPrivate() + server *Conn + serverErr = make(chan error, 1) + ) + go func() { + var err error + server, err = Server(context.Background(), serverConn, serverKey) + serverErr <- err + }() + + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server connection failed: %v", err) + } + + var clientHS [96]byte + copy(clientHS[:], serverBuf.Bytes()) + if clientHandshakes[clientHS] { + t.Fatal("client handshake seen twice") + } + clientHandshakes[clientHS] = true + + var serverHS [48]byte + copy(serverHS[:], clientBuf.Bytes()) + if serverHandshakes[serverHS] { + t.Fatal("server handshake seen twice") + } + serverHandshakes[serverHS] = true + + clientBuf.Reset() + serverBuf.Reset() + cb := sinkReads(client) + sb := sinkReads(server) + + if hashes[client.HandshakeHash()] { + t.Fatalf("handshake hash %v seen twice", client.HandshakeHash()) + } + hashes[client.HandshakeHash()] = true + + // Sending 14 bytes turns into 32 bytes on the wire (+16 for + // the poly1305 tag, +2 length header) + if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil { + t.Fatalf("client>server write failed: %v", err) + } + if _, err := io.WriteString(server, strings.Repeat("b", 14)); err != nil { + t.Fatalf("server>client write failed: %v", err) + } + + // Wait for the bytes to be read, so we know they've traveled end to end + cb.String(14) + sb.String(14) + + var clientWire, serverWire [32]byte + copy(clientWire[:], clientBuf.Bytes()) + copy(serverWire[:], serverBuf.Bytes()) + + if packets[clientWire] { + t.Fatalf("client wire traffic seen twice") + } + packets[clientWire] = true + if packets[serverWire] { + t.Fatalf("server wire traffic seen twice") + } + packets[serverWire] = true + } +} + +// tamperReader wraps a reader and mutates the Nth byte. +type tamperReader struct { + r io.Reader + n int + total int +} + +func (r *tamperReader) Read(bs []byte) (int, error) { + n, err := r.r.Read(bs) + if off := r.n - r.total; off >= 0 && off < n { + bs[off] += 1 + } + r.total += n + return n, err +} + +func TestTampering(t *testing.T) { + // Tamper with every byte of the client initiation message. + for i := 0; i < 96; i++ { + var ( + clientConn, serverRaw = tsnettest.NewConn("noise", 128000) + serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}} + serverKey = key.NewPrivate() + clientKey = key.NewPrivate() + serverErr = make(chan error, 1) + ) + go func() { + _, err := Server(context.Background(), serverConn, serverKey) + // If the server failed, we have to close the Conn to + // unblock the client. + if err != nil { + serverConn.Close() + } + serverErr <- err + }() + + _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + if err == nil { + t.Fatal("client connection succeeded despite tampering") + } + if err := <-serverErr; err == nil { + t.Fatalf("server connection succeeded despite tampering") + } + } + + // Tamper with every byte of the server response message. + for i := 0; i < 48; i++ { + var ( + clientRaw, serverConn = tsnettest.NewConn("noise", 128000) + clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}} + serverKey = key.NewPrivate() + clientKey = key.NewPrivate() + serverErr = make(chan error, 1) + ) + go func() { + _, err := Server(context.Background(), serverConn, serverKey) + serverErr <- err + }() + + _, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + if err == nil { + t.Fatal("client connection succeeded despite tampering") + } + // The server shouldn't fail, because the tampering took place + // in its response. + if err := <-serverErr; err != nil { + t.Fatalf("server connection failed despite no tampering: %v", err) + } + } + + // Tamper with every byte of the first server>client transport message. + for i := 0; i < 32; i++ { + var ( + clientRaw, serverConn = tsnettest.NewConn("noise", 128000) + clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}} + serverKey = key.NewPrivate() + clientKey = key.NewPrivate() + serverErr = make(chan error, 1) + ) + go func() { + server, err := Server(context.Background(), serverConn, serverKey) + serverErr <- err + _, err = io.WriteString(server, strings.Repeat("a", 14)) + serverErr <- err + }() + + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + if err != nil { + t.Fatalf("client handshake failed: %v", err) + } + // The server shouldn't fail, because the tampering took place + // in its response. + if err := <-serverErr; err != nil { + t.Fatalf("server handshake failed: %v", err) + } + + // The client needs a timeout if the tampering is hitting the length header. + if i == 0 || i == 1 { + client.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + } + + var bs [100]byte + n, err := client.Read(bs[:]) + if err == nil { + t.Fatal("read succeeded despite tampering") + } + if n != 0 { + t.Fatal("conn yielded some bytes despite tampering") + } + } + + // Tamper with every byte of the first client>server transport message. + for i := 0; i < 32; i++ { + var ( + clientConn, serverRaw = tsnettest.NewConn("noise", 128000) + serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}} + serverKey = key.NewPrivate() + clientKey = key.NewPrivate() + serverErr = make(chan error, 1) + ) + go func() { + server, err := Server(context.Background(), serverConn, serverKey) + serverErr <- err + var bs [100]byte + // The server needs a timeout if the tampering is hitting the length header. + if i == 0 || i == 1 { + server.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + } + n, err := server.Read(bs[:]) + if n != 0 { + panic("server got bytes despite tampering") + } else { + serverErr <- err + } + }() + + client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public()) + if err != nil { + t.Fatalf("client handshake failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server handshake failed: %v", err) + } + + if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil { + t.Fatalf("client>server write failed: %v", err) + } + if err := <-serverErr; err == nil { + t.Fatal("server successfully received bytes despite tampering") + } + } +} diff --git a/control/noise/interop_test.go b/control/noise/interop_test.go new file mode 100644 index 000000000..7f9b0926a --- /dev/null +++ b/control/noise/interop_test.go @@ -0,0 +1,238 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package noise + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "testing" + + tsnettest "tailscale.com/net/nettest" + "tailscale.com/types/key" +) + +// Can a reference Noise IK client talk to our server? +func TestInteropClient(t *testing.T) { + var ( + s1, s2 = tsnettest.NewConn("noise", 128000) + controlKey = key.NewPrivate() + machineKey = key.NewPrivate() + serverErr = make(chan error, 2) + serverBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + server, err := Server(context.Background(), s2, controlKey) + serverErr <- err + if err != nil { + return + } + var buf [1024]byte + _, err = io.ReadFull(server, buf[:len(c2s)]) + serverBytes <- buf[:len(c2s)] + if err != nil { + serverErr <- err + return + } + _, err = server.Write([]byte(s2c)) + serverErr <- err + }() + + gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) + if err != nil { + t.Fatalf("failed client interop: %v", err) + } + if string(gotS2C) != s2c { + t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) + } + + if err := <-serverErr; err != nil { + t.Fatalf("server handshake failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server read/write failed: %v", err) + } + if got := string(<-serverBytes); got != c2s { + t.Fatalf("server received %q, want %q", got, c2s) + } +} + +// Can our client talk to a reference Noise IK server? +func TestInteropServer(t *testing.T) { + var ( + s1, s2 = tsnettest.NewConn("noise", 128000) + controlKey = key.NewPrivate() + machineKey = key.NewPrivate() + clientErr = make(chan error, 2) + clientBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + client, err := Client(context.Background(), s1, machineKey, controlKey.Public()) + clientErr <- err + if err != nil { + return + } + _, err = client.Write([]byte(c2s)) + if err != nil { + clientErr <- err + return + } + var buf [1024]byte + _, err = io.ReadFull(client, buf[:len(s2c)]) + clientBytes <- buf[:len(s2c)] + clientErr <- err + }() + + gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) + if err != nil { + t.Fatalf("failed server interop: %v", err) + } + if string(gotC2S) != c2s { + t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) + } + + if err := <-clientErr; err != nil { + t.Fatalf("client handshake failed: %v", err) + } + if err := <-clientErr; err != nil { + t.Fatalf("client read/write failed: %v", err) + } + if got := string(<-clientBytes); got != s2c { + t.Fatalf("client received %q, want %q", got, s2c) + } +} + +// noiseExplorerClient uses the Noise Explorer implementation of Noise +// IK to handshake as a Noise client on conn, transmit payload, and +// read+return a payload from the peer. +func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Private, payload []byte) ([]byte, error) { + mk := keypair{ + private_key: machineKey, + public_key: machineKey.Public(), + } + session := InitSession(true, nil, mk, controlKey) + + _, msg1 := SendMessage(&session, nil) + if _, err := conn.Write(msg1.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ns); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ciphertext); err != nil { + return nil, err + } + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:48]); err != nil { + return nil, err + } + msg2 := messagebuffer{ + ciphertext: buf[32:48], + } + copy(msg2.ne[:], buf[:32]) + _, p, valid := RecvMessage(&session, &msg2) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg3 := SendMessage(&session, payload) + binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext))) + if _, err := conn.Write(buf[:2]); err != nil { + return nil, err + } + if _, err := conn.Write(msg3.ciphertext); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return nil, err + } + plen := int(binary.BigEndian.Uint16(buf[:2])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg4 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg4) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + return p, nil +} + +func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey key.Public, payload []byte) ([]byte, error) { + mk := keypair{ + private_key: controlKey, + public_key: controlKey.Public(), + } + session := InitSession(false, nil, mk, [32]byte{}) + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:96]); err != nil { + return nil, err + } + msg1 := messagebuffer{ + ns: buf[32:80], + ciphertext: buf[80:96], + } + copy(msg1.ne[:], buf[:32]) + _, p, valid := RecvMessage(&session, &msg1) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg2 := SendMessage(&session, nil) + if _, err := conn.Write(msg2.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg2.ciphertext[:]); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return nil, err + } + plen := int(binary.BigEndian.Uint16(buf[:2])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg3 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg3) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + _, msg4 := SendMessage(&session, payload) + binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext))) + if _, err := conn.Write(buf[:2]); err != nil { + return nil, err + } + if _, err := conn.Write(msg4.ciphertext); err != nil { + return nil, err + } + + return p, nil +} diff --git a/control/noise/noiseexplorer_test.go b/control/noise/noiseexplorer_test.go new file mode 100644 index 000000000..cd70be713 --- /dev/null +++ b/control/noise/noiseexplorer_test.go @@ -0,0 +1,475 @@ +// This file contains the implementation of Noise IK from +// https://noiseexplorer.com/ . Unlike the rest of this repository, +// this file is licensed under the terms of the GNU GPL v3. See +// https://source.symbolic.software/noiseexplorer/noiseexplorer for +// more information. +// +// This file is used here to verify that Tailscale's implementation of +// Noise IK is interoperable with another implementation. +//lint:file-ignore SA4006 not our code. + +/* +IK: + <- s + ... + -> e, es, s, ss + <- e, ee, se + -> + <- +*/ + +// Implementation Version: 1.0.2 + +/* ---------------------------------------------------------------- * + * PARAMETERS * + * ---------------------------------------------------------------- */ + +package noise + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/binary" + "hash" + "io" + "math" + + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +/* ---------------------------------------------------------------- * + * TYPES * + * ---------------------------------------------------------------- */ + +type keypair struct { + public_key [32]byte + private_key [32]byte +} + +type messagebuffer struct { + ne [32]byte + ns []byte + ciphertext []byte +} + +type cipherstate struct { + k [32]byte + n uint32 +} + +type symmetricstate struct { + cs cipherstate + ck [32]byte + h [32]byte +} + +type handshakestate struct { + ss symmetricstate + s keypair + e keypair + rs [32]byte + re [32]byte + psk [32]byte +} + +type noisesession struct { + hs handshakestate + h [32]byte + cs1 cipherstate + cs2 cipherstate + mc uint64 + i bool +} + +/* ---------------------------------------------------------------- * + * CONSTANTS * + * ---------------------------------------------------------------- */ + +var emptyKey = [32]byte{ + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, +} + +var minNonce = uint32(0) + +/* ---------------------------------------------------------------- * + * UTILITY FUNCTIONS * + * ---------------------------------------------------------------- */ + +func getPublicKey(kp *keypair) [32]byte { + return kp.public_key +} + +func isEmptyKey(k [32]byte) bool { + return subtle.ConstantTimeCompare(k[:], emptyKey[:]) == 1 +} + +func validatePublicKey(k []byte) bool { + forbiddenCurveValues := [12][]byte{ + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {224, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 0}, + {95, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 87}, + {236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, + {237, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, + {238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, + {205, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 128}, + {76, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 215}, + {217, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, + {218, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, + {219, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 25}, + } + + for _, testValue := range forbiddenCurveValues { + if subtle.ConstantTimeCompare(k[:], testValue[:]) == 1 { + panic("Invalid public key") + } + } + return true +} + +/* ---------------------------------------------------------------- * + * PRIMITIVES * + * ---------------------------------------------------------------- */ + +func incrementNonce(n uint32) uint32 { + return n + 1 +} + +func dh(private_key [32]byte, public_key [32]byte) [32]byte { + var ss [32]byte + curve25519.ScalarMult(&ss, &private_key, &public_key) + return ss +} + +func generateKeypair() keypair { + var public_key [32]byte + var private_key [32]byte + _, _ = rand.Read(private_key[:]) + curve25519.ScalarBaseMult(&public_key, &private_key) + if validatePublicKey(public_key[:]) { + return keypair{public_key, private_key} + } + return generateKeypair() +} + +func generatePublicKey(private_key [32]byte) [32]byte { + var public_key [32]byte + curve25519.ScalarBaseMult(&public_key, &private_key) + return public_key +} + +func encrypt(k [32]byte, n uint32, ad []byte, plaintext []byte) []byte { + var nonce [12]byte + var ciphertext []byte + enc, _ := chacha20poly1305.New(k[:]) + binary.LittleEndian.PutUint32(nonce[4:], n) + ciphertext = enc.Seal(nil, nonce[:], plaintext, ad) + return ciphertext +} + +func decrypt(k [32]byte, n uint32, ad []byte, ciphertext []byte) (bool, []byte, []byte) { + var nonce [12]byte + var plaintext []byte + enc, err := chacha20poly1305.New(k[:]) + binary.LittleEndian.PutUint32(nonce[4:], n) + plaintext, err = enc.Open(nil, nonce[:], ciphertext, ad) + return (err == nil), ad, plaintext +} + +func getHash(a []byte, b []byte) [32]byte { + return blake2s.Sum256(append(a, b...)) +} + +func hashProtocolName(protocolName []byte) [32]byte { + var h [32]byte + if len(protocolName) <= 32 { + copy(h[:], protocolName) + } else { + h = getHash(protocolName, []byte{}) + } + return h +} + +func blake2HkdfInterface() hash.Hash { + h, _ := blake2s.New256([]byte{}) + return h +} + +func getHkdf(ck [32]byte, ikm []byte) ([32]byte, [32]byte, [32]byte) { + var k1 [32]byte + var k2 [32]byte + var k3 [32]byte + output := hkdf.New(blake2HkdfInterface, ikm[:], ck[:], []byte{}) + io.ReadFull(output, k1[:]) + io.ReadFull(output, k2[:]) + io.ReadFull(output, k3[:]) + return k1, k2, k3 +} + +/* ---------------------------------------------------------------- * + * STATE MANAGEMENT * + * ---------------------------------------------------------------- */ + +/* CipherState */ +func initializeKey(k [32]byte) cipherstate { + return cipherstate{k, minNonce} +} + +func hasKey(cs *cipherstate) bool { + return !isEmptyKey(cs.k) +} + +func setNonce(cs *cipherstate, newNonce uint32) *cipherstate { + cs.n = newNonce + return cs +} + +func encryptWithAd(cs *cipherstate, ad []byte, plaintext []byte) (*cipherstate, []byte) { + e := encrypt(cs.k, cs.n, ad, plaintext) + cs = setNonce(cs, incrementNonce(cs.n)) + return cs, e +} + +func decryptWithAd(cs *cipherstate, ad []byte, ciphertext []byte) (*cipherstate, []byte, bool) { + valid, ad, plaintext := decrypt(cs.k, cs.n, ad, ciphertext) + cs = setNonce(cs, incrementNonce(cs.n)) + return cs, plaintext, valid +} + +func reKey(cs *cipherstate) *cipherstate { + e := encrypt(cs.k, math.MaxUint32, []byte{}, emptyKey[:]) + copy(cs.k[:], e) + return cs +} + +/* SymmetricState */ + +func initializeSymmetric(protocolName []byte) symmetricstate { + h := hashProtocolName(protocolName) + ck := h + cs := initializeKey(emptyKey) + return symmetricstate{cs, ck, h} +} + +func mixKey(ss *symmetricstate, ikm [32]byte) *symmetricstate { + ck, tempK, _ := getHkdf(ss.ck, ikm[:]) + ss.cs = initializeKey(tempK) + ss.ck = ck + return ss +} + +func mixHash(ss *symmetricstate, data []byte) *symmetricstate { + ss.h = getHash(ss.h[:], data) + return ss +} + +func mixKeyAndHash(ss *symmetricstate, ikm [32]byte) *symmetricstate { + var tempH [32]byte + var tempK [32]byte + ss.ck, tempH, tempK = getHkdf(ss.ck, ikm[:]) + ss = mixHash(ss, tempH[:]) + ss.cs = initializeKey(tempK) + return ss +} + +func getHandshakeHash(ss *symmetricstate) [32]byte { + return ss.h +} + +func encryptAndHash(ss *symmetricstate, plaintext []byte) (*symmetricstate, []byte) { + var ciphertext []byte + if hasKey(&ss.cs) { + _, ciphertext = encryptWithAd(&ss.cs, ss.h[:], plaintext) + } else { + ciphertext = plaintext + } + ss = mixHash(ss, ciphertext) + return ss, ciphertext +} + +func decryptAndHash(ss *symmetricstate, ciphertext []byte) (*symmetricstate, []byte, bool) { + var plaintext []byte + var valid bool + if hasKey(&ss.cs) { + _, plaintext, valid = decryptWithAd(&ss.cs, ss.h[:], ciphertext) + } else { + plaintext, valid = ciphertext, true + } + ss = mixHash(ss, ciphertext) + return ss, plaintext, valid +} + +func split(ss *symmetricstate) (cipherstate, cipherstate) { + tempK1, tempK2, _ := getHkdf(ss.ck, []byte{}) + cs1 := initializeKey(tempK1) + cs2 := initializeKey(tempK2) + return cs1, cs2 +} + +/* HandshakeState */ + +func initializeInitiator(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate { + var ss symmetricstate + var e keypair + var re [32]byte + name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s") + ss = initializeSymmetric(name) + mixHash(&ss, prologue) + mixHash(&ss, rs[:]) + return handshakestate{ss, s, e, rs, re, psk} +} + +func initializeResponder(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate { + var ss symmetricstate + var e keypair + var re [32]byte + name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s") + ss = initializeSymmetric(name) + mixHash(&ss, prologue) + mixHash(&ss, s.public_key[:]) + return handshakestate{ss, s, e, rs, re, psk} +} + +func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, messagebuffer) { + ne, ns, ciphertext := emptyKey, []byte{}, []byte{} + hs.e = generateKeypair() + ne = hs.e.public_key + mixHash(&hs.ss, ne[:]) + /* No PSK, so skipping mixKey */ + mixKey(&hs.ss, dh(hs.e.private_key, hs.rs)) + spk := make([]byte, len(hs.s.public_key)) + copy(spk[:], hs.s.public_key[:]) + _, ns = encryptAndHash(&hs.ss, spk) + mixKey(&hs.ss, dh(hs.s.private_key, hs.rs)) + _, ciphertext = encryptAndHash(&hs.ss, payload) + messageBuffer := messagebuffer{ne, ns, ciphertext} + return hs, messageBuffer +} + +func writeMessageB(hs *handshakestate, payload []byte) ([32]byte, messagebuffer, cipherstate, cipherstate) { + ne, ns, ciphertext := emptyKey, []byte{}, []byte{} + hs.e = generateKeypair() + ne = hs.e.public_key + mixHash(&hs.ss, ne[:]) + /* No PSK, so skipping mixKey */ + mixKey(&hs.ss, dh(hs.e.private_key, hs.re)) + mixKey(&hs.ss, dh(hs.e.private_key, hs.rs)) + _, ciphertext = encryptAndHash(&hs.ss, payload) + messageBuffer := messagebuffer{ne, ns, ciphertext} + cs1, cs2 := split(&hs.ss) + return hs.ss.h, messageBuffer, cs1, cs2 +} + +func writeMessageRegular(cs *cipherstate, payload []byte) (*cipherstate, messagebuffer) { + ne, ns, ciphertext := emptyKey, []byte{}, []byte{} + cs, ciphertext = encryptWithAd(cs, []byte{}, payload) + messageBuffer := messagebuffer{ne, ns, ciphertext} + return cs, messageBuffer +} + +func readMessageA(hs *handshakestate, message *messagebuffer) (*handshakestate, []byte, bool) { + valid1 := true + if validatePublicKey(message.ne[:]) { + hs.re = message.ne + } + mixHash(&hs.ss, hs.re[:]) + /* No PSK, so skipping mixKey */ + mixKey(&hs.ss, dh(hs.s.private_key, hs.re)) + _, ns, valid1 := decryptAndHash(&hs.ss, message.ns) + if valid1 && len(ns) == 32 && validatePublicKey(message.ns[:]) { + copy(hs.rs[:], ns) + } + mixKey(&hs.ss, dh(hs.s.private_key, hs.rs)) + _, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext) + return hs, plaintext, (valid1 && valid2) +} + +func readMessageB(hs *handshakestate, message *messagebuffer) ([32]byte, []byte, bool, cipherstate, cipherstate) { + valid1 := true + if validatePublicKey(message.ne[:]) { + hs.re = message.ne + } + mixHash(&hs.ss, hs.re[:]) + /* No PSK, so skipping mixKey */ + mixKey(&hs.ss, dh(hs.e.private_key, hs.re)) + mixKey(&hs.ss, dh(hs.s.private_key, hs.re)) + _, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext) + cs1, cs2 := split(&hs.ss) + return hs.ss.h, plaintext, (valid1 && valid2), cs1, cs2 +} + +func readMessageRegular(cs *cipherstate, message *messagebuffer) (*cipherstate, []byte, bool) { + /* No encrypted keys */ + _, plaintext, valid2 := decryptWithAd(cs, []byte{}, message.ciphertext) + return cs, plaintext, valid2 +} + +/* ---------------------------------------------------------------- * + * PROCESSES * + * ---------------------------------------------------------------- */ + +func InitSession(initiator bool, prologue []byte, s keypair, rs [32]byte) noisesession { + var session noisesession + psk := emptyKey + if initiator { + session.hs = initializeInitiator(prologue, s, rs, psk) + } else { + session.hs = initializeResponder(prologue, s, rs, psk) + } + session.i = initiator + session.mc = 0 + return session +} + +func SendMessage(session *noisesession, message []byte) (*noisesession, messagebuffer) { + var messageBuffer messagebuffer + if session.mc == 0 { + _, messageBuffer = writeMessageA(&session.hs, message) + } + if session.mc == 1 { + session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message) + session.hs = handshakestate{} + } + if session.mc > 1 { + if session.i { + _, messageBuffer = writeMessageRegular(&session.cs1, message) + } else { + _, messageBuffer = writeMessageRegular(&session.cs2, message) + } + } + session.mc = session.mc + 1 + return session, messageBuffer +} + +func RecvMessage(session *noisesession, message *messagebuffer) (*noisesession, []byte, bool) { + var plaintext []byte + var valid bool + if session.mc == 0 { + _, plaintext, valid = readMessageA(&session.hs, message) + } + if session.mc == 1 { + session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message) + session.hs = handshakestate{} + } + if session.mc > 1 { + if session.i { + _, plaintext, valid = readMessageRegular(&session.cs2, message) + } else { + _, plaintext, valid = readMessageRegular(&session.cs1, message) + } + } + session.mc = session.mc + 1 + return session, plaintext, valid +} + +func main() {} diff --git a/scripts/check_license_headers.sh b/scripts/check_license_headers.sh index b1a265cb2..e3558efa7 100755 --- a/scripts/check_license_headers.sh +++ b/scripts/check_license_headers.sh @@ -38,6 +38,12 @@ for file in $(find $1 -name '*.go' -not -path '*/.git/*'); do $1/wgengine/router/ifconfig_windows.go) # WireGuard copyright. ;; + *_string.go) + # Generated file from go:generate stringer + ;; + $1/control/noise/noiseexplorer_test.go) + # Noiseexplorer.com copyright. + ;; *) header="$(head -3 $file)" if ! check_file "$header"; then