mirror of https://github.com/tailscale/tailscale/
control/noise: implement the base transport for the 2021 control protocol.
Signed-off-by: David Anderson <danderson@tailscale.com>pull/3293/head
parent
3e1daab704
commit
da7544bcc5
@ -0,0 +1,330 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package noise implements the base transport of the Tailscale 2021
|
||||||
|
// control protocol.
|
||||||
|
//
|
||||||
|
// The base transport implements Noise IK, instantiated with
|
||||||
|
// Curve25519, ChaCha20Poly1305 and BLAKE2s.
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
chp "golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/poly1305"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxPlaintextSize = 4096
|
||||||
|
maxCiphertextSize = maxPlaintextSize + poly1305.TagSize
|
||||||
|
maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Conn is a secured Noise connection. It implements the net.Conn
|
||||||
|
// interface, with the unusual trait that any write error (including a
|
||||||
|
// SetWriteDeadline induced i/o timeout) cause all future writes to
|
||||||
|
// fail.
|
||||||
|
type Conn struct {
|
||||||
|
conn net.Conn
|
||||||
|
peer key.Public
|
||||||
|
handshakeHash [blake2s.Size]byte
|
||||||
|
rx rxState
|
||||||
|
tx txState
|
||||||
|
}
|
||||||
|
|
||||||
|
// rxState is all the Conn state that Read uses.
|
||||||
|
type rxState struct {
|
||||||
|
sync.Mutex
|
||||||
|
cipher cipher.AEAD
|
||||||
|
nonce [chp.NonceSize]byte
|
||||||
|
buf [maxPacketSize]byte
|
||||||
|
n int // number of valid bytes in buf
|
||||||
|
next int // offset of next undecrypted packet
|
||||||
|
plaintext []byte // slice into buf of decrypted bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// txState is all the Conn state that Write uses.
|
||||||
|
type txState struct {
|
||||||
|
sync.Mutex
|
||||||
|
cipher cipher.AEAD
|
||||||
|
nonce [chp.NonceSize]byte
|
||||||
|
buf [maxPacketSize]byte
|
||||||
|
err error // records the first partial write error for all future calls
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeHash returns the Noise handshake hash for the connection,
|
||||||
|
// which can be used to bind other messages to this connection
|
||||||
|
// (i.e. to ensure that the message wasn't replayed from a different
|
||||||
|
// connection).
|
||||||
|
func (c *Conn) HandshakeHash() [blake2s.Size]byte {
|
||||||
|
return c.handshakeHash
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peer returns the peer's long-term public key.
|
||||||
|
func (c *Conn) Peer() key.Public {
|
||||||
|
return c.peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// validNonce reports whether nonce is in the valid range for use: 0
|
||||||
|
// through 2^64-2.
|
||||||
|
func validNonce(nonce []byte) bool {
|
||||||
|
return binary.BigEndian.Uint32(nonce[:4]) == 0 && binary.BigEndian.Uint64(nonce[4:]) != invalidNonce
|
||||||
|
}
|
||||||
|
|
||||||
|
// readNLocked reads into c.rxBuf until rxBuf contains at least total
|
||||||
|
// bytes. Returns a slice of the available bytes in rxBuf, or an
|
||||||
|
// error if fewer than total bytes are available.
|
||||||
|
func (c *Conn) readNLocked(total int) ([]byte, error) {
|
||||||
|
if total > maxPacketSize {
|
||||||
|
return nil, errReadTooBig{total}
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
if total <= c.rx.n {
|
||||||
|
return c.rx.buf[:c.rx.n], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := c.conn.Read(c.rx.buf[c.rx.n:])
|
||||||
|
c.rx.n += n
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext
|
||||||
|
// to the decrypted bytes. Returns an error if the cipher is exhausted
|
||||||
|
// (i.e. can no longer be used safely) or decryption fails.
|
||||||
|
func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
|
||||||
|
if !validNonce(c.rx.nonce[:]) {
|
||||||
|
return errCipherExhausted{}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil)
|
||||||
|
|
||||||
|
// Safe to increment the nonce here, because we checked for nonce
|
||||||
|
// wraparound above.
|
||||||
|
binary.BigEndian.PutUint64(c.rx.nonce[4:], 1+binary.BigEndian.Uint64(c.rx.nonce[4:]))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Once a decryption has failed, our Conn is no longer
|
||||||
|
// synchronized with our peer. Nuke the cipher state to be
|
||||||
|
// safe, so that no further decryptions are attempted.
|
||||||
|
c.rx.cipher = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// encryptLocked encrypts plaintext into c.tx.buf (including the
|
||||||
|
// 2-byte length header) and returns a slice of the ciphertext, or an
|
||||||
|
// error if the cipher is exhausted (i.e. can no longer be used safely).
|
||||||
|
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
|
||||||
|
if !validNonce(c.tx.nonce[:]) {
|
||||||
|
// Received 2^64-1 messages on this cipher state. Connection
|
||||||
|
// is no longer usable.
|
||||||
|
return nil, errCipherExhausted{}
|
||||||
|
}
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize))
|
||||||
|
ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil)
|
||||||
|
|
||||||
|
// Safe to increment the nonce here, because we checked for nonce
|
||||||
|
// wraparound above.
|
||||||
|
binary.BigEndian.PutUint64(c.tx.nonce[4:], 1+binary.BigEndian.Uint64(c.tx.nonce[4:]))
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wholeCiphertextLocked returns a slice of one whole Noise frame from
|
||||||
|
// c.rx.buf, if one whole ciphertext is available, and advances the
|
||||||
|
// read state to the next Noise frame in the buffer. Returns nil
|
||||||
|
// without advancing read state if there's not one whole ciphertext in
|
||||||
|
// c.rx.buf.
|
||||||
|
func (c *Conn) wholeCiphertextLocked() []byte {
|
||||||
|
available := c.rx.n - c.rx.next
|
||||||
|
if available < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bs := c.rx.buf[c.rx.next:c.rx.n]
|
||||||
|
totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
||||||
|
if len(bs) < totalSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.rx.next += totalSize
|
||||||
|
return bs[:totalSize]
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed,
|
||||||
|
// and sets c.rx.plaintext to point to the decrypted
|
||||||
|
// bytes. c.rx.plaintext is only valid if err == nil.
|
||||||
|
func (c *Conn) decryptOneLocked() error {
|
||||||
|
c.rx.plaintext = nil
|
||||||
|
|
||||||
|
// Fast path: do we have one whole ciphertext frame buffered
|
||||||
|
// already?
|
||||||
|
if bs := c.wholeCiphertextLocked(); bs != nil {
|
||||||
|
return c.decryptLocked(bs[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.rx.next != 0 {
|
||||||
|
// To simplify the read logic, move the remainder of the
|
||||||
|
// buffered bytes back to the head of the buffer, so we can
|
||||||
|
// grow it without worrying about wraparound.
|
||||||
|
copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n])
|
||||||
|
c.rx.n -= c.rx.next
|
||||||
|
c.rx.next = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
bs, err := c.readNLocked(2)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2
|
||||||
|
bs, err = c.readNLocked(totalLen)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.rx.next = totalLen
|
||||||
|
bs = bs[2:totalLen]
|
||||||
|
|
||||||
|
return c.decryptLocked(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.Reader.
|
||||||
|
func (c *Conn) Read(bs []byte) (int, error) {
|
||||||
|
c.rx.Lock()
|
||||||
|
defer c.rx.Unlock()
|
||||||
|
|
||||||
|
if c.rx.cipher == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
// Loop to handle receiving a zero-byte Noise message. Just skip
|
||||||
|
// over it and keep decrypting until we find some bytes.
|
||||||
|
for len(c.rx.plaintext) == 0 {
|
||||||
|
if err := c.decryptOneLocked(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n := copy(bs, c.rx.plaintext)
|
||||||
|
c.rx.plaintext = c.rx.plaintext[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer.
|
||||||
|
func (c *Conn) Write(bs []byte) (n int, err error) {
|
||||||
|
c.tx.Lock()
|
||||||
|
defer c.tx.Unlock()
|
||||||
|
|
||||||
|
if c.tx.err != nil {
|
||||||
|
return 0, c.tx.err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
// All write errors are fatal for this conn, so clear the
|
||||||
|
// cipher state whenever an error happens.
|
||||||
|
c.tx.cipher = nil
|
||||||
|
}
|
||||||
|
if c.tx.err == nil {
|
||||||
|
// Only set c.tx.err if not nil so that we can return one
|
||||||
|
// error on the first failure, and a different one for
|
||||||
|
// subsequent calls. See the error handling around Write
|
||||||
|
// below for why.
|
||||||
|
c.tx.err = err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if c.tx.cipher == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
var sent int
|
||||||
|
for len(bs) > 0 {
|
||||||
|
toSend := bs
|
||||||
|
if len(toSend) > maxPlaintextSize {
|
||||||
|
toSend = bs[:maxPlaintextSize]
|
||||||
|
}
|
||||||
|
bs = bs[len(toSend):]
|
||||||
|
|
||||||
|
ciphertext, err := c.encryptLocked(toSend)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n, err := c.conn.Write(ciphertext); err != nil {
|
||||||
|
sent += n
|
||||||
|
// Return the raw error on the Write that actually
|
||||||
|
// failed. For future writes, return that error wrapped in
|
||||||
|
// a desync error.
|
||||||
|
c.tx.err = errPartialWrite{err}
|
||||||
|
return sent, err
|
||||||
|
}
|
||||||
|
sent += len(toSend)
|
||||||
|
}
|
||||||
|
return sent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements io.Closer.
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
closeErr := c.conn.Close() // unblocks any waiting reads or writes
|
||||||
|
c.rx.Lock()
|
||||||
|
c.rx.cipher = nil
|
||||||
|
c.rx.Unlock()
|
||||||
|
c.tx.Lock()
|
||||||
|
c.tx.cipher = nil
|
||||||
|
c.tx.Unlock()
|
||||||
|
return closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||||
|
func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||||
|
func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||||
|
func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||||
|
|
||||||
|
// errCipherExhausted is the error returned when we run out of nonces
|
||||||
|
// on a cipher.
|
||||||
|
type errCipherExhausted struct{}
|
||||||
|
|
||||||
|
func (errCipherExhausted) Error() string {
|
||||||
|
return "cipher exhausted, no more nonces available for current key"
|
||||||
|
}
|
||||||
|
func (errCipherExhausted) Timeout() bool { return false }
|
||||||
|
func (errCipherExhausted) Temporary() bool { return false }
|
||||||
|
|
||||||
|
// errPartialWrite is the error returned when the cipher state has
|
||||||
|
// become unusable due to a past partial write.
|
||||||
|
type errPartialWrite struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e errPartialWrite) Error() string {
|
||||||
|
return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err)
|
||||||
|
}
|
||||||
|
func (e errPartialWrite) Unwrap() error { return e.err }
|
||||||
|
func (e errPartialWrite) Temporary() bool { return false }
|
||||||
|
func (e errPartialWrite) Timeout() bool { return false }
|
||||||
|
|
||||||
|
// errReadTooBig is the error returned when the peer sent an
|
||||||
|
// unacceptably large Noise frame.
|
||||||
|
type errReadTooBig struct {
|
||||||
|
requested int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e errReadTooBig) Error() string {
|
||||||
|
return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested)
|
||||||
|
}
|
||||||
|
func (e errReadTooBig) Temporary() bool {
|
||||||
|
// permanent error because this error only occurs when our peer
|
||||||
|
// sends us a frame so large we're unwilling to ever decode it.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (e errReadTooBig) Timeout() bool { return false }
|
@ -0,0 +1,339 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"testing/iotest"
|
||||||
|
|
||||||
|
chp "golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/net/nettest"
|
||||||
|
tsnettest "tailscale.com/net/nettest"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMessageSize(t *testing.T) {
|
||||||
|
// This test is a regression guard against someone looking at
|
||||||
|
// maxCiphertextSize, going "huh, we could be more efficient if it
|
||||||
|
// were larger, and accidentally violating the Noise spec. Do not
|
||||||
|
// change this max value, it's a deliberate limitation of the
|
||||||
|
// cryptographic protocol we use (see Section 3 "Message Format"
|
||||||
|
// of the Noise spec).
|
||||||
|
const max = 65535
|
||||||
|
if maxCiphertextSize > max {
|
||||||
|
t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnBasic(t *testing.T) {
|
||||||
|
client, server := pair(t)
|
||||||
|
|
||||||
|
sb := sinkReads(server)
|
||||||
|
|
||||||
|
want := "test"
|
||||||
|
if _, err := io.WriteString(client, want); err != nil {
|
||||||
|
t.Fatalf("client write failed: %v", err)
|
||||||
|
}
|
||||||
|
client.Close()
|
||||||
|
|
||||||
|
if got := sb.String(4); got != want {
|
||||||
|
t.Fatalf("wrong content received: got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if err := sb.Error(); err != io.EOF {
|
||||||
|
t.Fatal("client close wasn't seen by server")
|
||||||
|
}
|
||||||
|
if sb.Total() != 4 {
|
||||||
|
t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bufferedWriteConn wraps a net.Conn and gives control over how
|
||||||
|
// Writes get batched out.
|
||||||
|
type bufferedWriteConn struct {
|
||||||
|
net.Conn
|
||||||
|
w *bufio.Writer
|
||||||
|
manualFlush bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *bufferedWriteConn) Write(bs []byte) (int, error) {
|
||||||
|
n, err := c.w.Write(bs)
|
||||||
|
if err == nil && !c.manualFlush {
|
||||||
|
err = c.w.Flush()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFastPath exercises the Read codepath that can receive multiple
|
||||||
|
// Noise frames at once and decode each in turn without making another
|
||||||
|
// syscall.
|
||||||
|
func TestFastPath(t *testing.T) {
|
||||||
|
s1, s2 := tsnettest.NewConn("noise", 128000)
|
||||||
|
b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false}
|
||||||
|
client, server := pairWithConns(t, b, s2)
|
||||||
|
|
||||||
|
b.manualFlush = true
|
||||||
|
|
||||||
|
sb := sinkReads(server)
|
||||||
|
|
||||||
|
const packets = 10
|
||||||
|
s := "test"
|
||||||
|
for i := 0; i < packets; i++ {
|
||||||
|
// Many separate writes, to force separate Noise frames that
|
||||||
|
// all get buffered up and then all sent as a single slice to
|
||||||
|
// the server.
|
||||||
|
if _, err := io.WriteString(client, s); err != nil {
|
||||||
|
t.Fatalf("client write1 failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := b.w.Flush(); err != nil {
|
||||||
|
t.Fatalf("client flush failed: %v", err)
|
||||||
|
}
|
||||||
|
client.Close()
|
||||||
|
|
||||||
|
want := strings.Repeat(s, packets)
|
||||||
|
if got := sb.String(len(want)); got != want {
|
||||||
|
t.Fatalf("wrong content received: got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if err := sb.Error(); err != io.EOF {
|
||||||
|
t.Fatalf("client close wasn't seen by server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes things larger than a single Noise frame, to check the
|
||||||
|
// chunking on the encoder and decoder.
|
||||||
|
func TestBigData(t *testing.T) {
|
||||||
|
client, server := pair(t)
|
||||||
|
|
||||||
|
serverReads := sinkReads(server)
|
||||||
|
clientReads := sinkReads(client)
|
||||||
|
|
||||||
|
const sz = 15 * 1024 // 15KiB
|
||||||
|
clientStr := strings.Repeat("abcde", sz/5)
|
||||||
|
serverStr := strings.Repeat("fghij", sz/5*2)
|
||||||
|
|
||||||
|
if _, err := io.WriteString(client, clientStr); err != nil {
|
||||||
|
t.Fatalf("writing client>server: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(server, serverStr); err != nil {
|
||||||
|
t.Fatalf("writing server>client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if serverGot := serverReads.String(sz); serverGot != clientStr {
|
||||||
|
t.Error("server didn't receive what client sent")
|
||||||
|
}
|
||||||
|
if clientGot := clientReads.String(2 * sz); clientGot != serverStr {
|
||||||
|
t.Error("client didn't receive what server sent")
|
||||||
|
}
|
||||||
|
|
||||||
|
getNonce := func(n [chp.NonceSize]byte) uint64 {
|
||||||
|
if binary.BigEndian.Uint32(n[:4]) != 0 {
|
||||||
|
panic("unexpected nonce")
|
||||||
|
}
|
||||||
|
return binary.BigEndian.Uint64(n[4:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reach into the Conns and verify the cipher nonces advanced as
|
||||||
|
// expected.
|
||||||
|
if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) {
|
||||||
|
t.Error("desynchronized client tx nonce")
|
||||||
|
}
|
||||||
|
if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) {
|
||||||
|
t.Error("desynchronized server tx nonce")
|
||||||
|
}
|
||||||
|
if n := getNonce(client.tx.nonce); n != 4 {
|
||||||
|
t.Errorf("wrong client tx nonce, got %d want 4", n)
|
||||||
|
}
|
||||||
|
if n := getNonce(server.tx.nonce); n != 8 {
|
||||||
|
t.Errorf("wrong client tx nonce, got %d want 8", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readerConn wraps a net.Conn and routes its Reads through a separate
|
||||||
|
// io.Reader.
|
||||||
|
type readerConn struct {
|
||||||
|
net.Conn
|
||||||
|
r io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) }
|
||||||
|
|
||||||
|
// Check that the receiver can handle not being able to read an entire
|
||||||
|
// frame in a single syscall.
|
||||||
|
func TestDataTrickle(t *testing.T) {
|
||||||
|
s1, s2 := tsnettest.NewConn("noise", 128000)
|
||||||
|
client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)})
|
||||||
|
serverReads := sinkReads(server)
|
||||||
|
|
||||||
|
const sz = 10000
|
||||||
|
clientStr := strings.Repeat("abcde", sz/5)
|
||||||
|
if _, err := io.WriteString(client, clientStr); err != nil {
|
||||||
|
t.Fatalf("writing client>server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverGot := serverReads.String(sz)
|
||||||
|
if serverGot != clientStr {
|
||||||
|
t.Error("server didn't receive what client sent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnStd(t *testing.T) {
|
||||||
|
// You can run this test manually, and noise.Conn should pass all
|
||||||
|
// of them except for TestConn/PastTimeout,
|
||||||
|
// TestConn/FutureTimeout, TestConn/ConcurrentMethods, because
|
||||||
|
// those tests assume that write errors are recoverable, and
|
||||||
|
// they're not on our Conn due to cipher security.
|
||||||
|
t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977")
|
||||||
|
nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
|
||||||
|
s1, s2 := tsnettest.NewConn("noise", 4096)
|
||||||
|
controlKey := key.NewPrivate()
|
||||||
|
machineKey := key.NewPrivate()
|
||||||
|
serverErr := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
c2, err = Server(context.Background(), s2, controlKey)
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
s1.Close()
|
||||||
|
s2.Close()
|
||||||
|
return nil, nil, nil, fmt.Errorf("connecting client: %w", err)
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
c1.Close()
|
||||||
|
s1.Close()
|
||||||
|
s2.Close()
|
||||||
|
return nil, nil, nil, fmt.Errorf("connecting server: %w", err)
|
||||||
|
}
|
||||||
|
return c1, c2, func() {
|
||||||
|
c1.Close()
|
||||||
|
c2.Close()
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// mkConns creates synthetic Noise Conns wrapping the given net.Conns.
|
||||||
|
// This function is for testing just the Conn transport logic without
|
||||||
|
// having to muck about with Noise handshakes.
|
||||||
|
func mkConns(s1, s2 net.Conn) (*Conn, *Conn) {
|
||||||
|
var k1, k2 [chp.KeySize]byte
|
||||||
|
if _, err := rand.Read(k1[:]); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if _, err := rand.Read(k2[:]); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret1 := &Conn{
|
||||||
|
conn: s1,
|
||||||
|
tx: txState{cipher: newCHP(k1)},
|
||||||
|
rx: rxState{cipher: newCHP(k2)},
|
||||||
|
}
|
||||||
|
ret2 := &Conn{
|
||||||
|
conn: s2,
|
||||||
|
tx: txState{cipher: newCHP(k2)},
|
||||||
|
rx: rxState{cipher: newCHP(k1)},
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret1, ret2
|
||||||
|
}
|
||||||
|
|
||||||
|
type readSink struct {
|
||||||
|
r io.Reader
|
||||||
|
|
||||||
|
cond *sync.Cond
|
||||||
|
sync.Mutex
|
||||||
|
bs bytes.Buffer
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func sinkReads(r io.Reader) *readSink {
|
||||||
|
ret := &readSink{
|
||||||
|
r: r,
|
||||||
|
}
|
||||||
|
ret.cond = sync.NewCond(&ret.Mutex)
|
||||||
|
go func() {
|
||||||
|
var buf [4096]byte
|
||||||
|
for {
|
||||||
|
n, err := r.Read(buf[:])
|
||||||
|
ret.Lock()
|
||||||
|
ret.bs.Write(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
ret.err = err
|
||||||
|
}
|
||||||
|
ret.cond.Broadcast()
|
||||||
|
ret.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *readSink) String(total int) string {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
for s.bs.Len() < total && s.err == nil {
|
||||||
|
s.cond.Wait()
|
||||||
|
}
|
||||||
|
if s.err != nil {
|
||||||
|
total = s.bs.Len()
|
||||||
|
}
|
||||||
|
return string(s.bs.Bytes()[:total])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *readSink) Error() error {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
for s.err == nil {
|
||||||
|
s.cond.Wait()
|
||||||
|
}
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *readSink) Total() int {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
return s.bs.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) {
|
||||||
|
var (
|
||||||
|
controlKey = key.NewPrivate()
|
||||||
|
machineKey = key.NewPrivate()
|
||||||
|
server *Conn
|
||||||
|
serverErr = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
server, err = Server(context.Background(), serverConn, controlKey)
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client connection failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server connection failed: %v", err)
|
||||||
|
}
|
||||||
|
return client, server
|
||||||
|
}
|
||||||
|
|
||||||
|
func pair(t *testing.T) (*Conn, *Conn) {
|
||||||
|
s1, s2 := tsnettest.NewConn("noise", 128000)
|
||||||
|
return pairWithConns(t, s1, s2)
|
||||||
|
}
|
@ -0,0 +1,361 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
chp "golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
|
||||||
|
invalidNonce = ^uint64(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client initiates a Noise client handshake, returning the resulting
|
||||||
|
// Noise connection.
|
||||||
|
//
|
||||||
|
// The context deadline, if any, covers the entire handshaking
|
||||||
|
// process.
|
||||||
|
func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlKey key.Public) (*Conn, error) {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
if err := conn.SetDeadline(deadline); err != nil {
|
||||||
|
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
conn.SetDeadline(time.Time{})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
var s symmetricState
|
||||||
|
s.Initialize()
|
||||||
|
|
||||||
|
// <- s
|
||||||
|
// ...
|
||||||
|
s.MixHash(controlKey[:])
|
||||||
|
|
||||||
|
var init initiationMessage
|
||||||
|
// -> e, es, s, ss
|
||||||
|
machineEphemeral := key.NewPrivate()
|
||||||
|
machineEphemeralPub := machineEphemeral.Public()
|
||||||
|
copy(init.MachineEphemeralPub(), machineEphemeralPub[:])
|
||||||
|
s.MixHash(machineEphemeralPub[:])
|
||||||
|
if err := s.MixDH(machineEphemeral, controlKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing es: %w", err)
|
||||||
|
}
|
||||||
|
machineKeyPub := machineKey.Public()
|
||||||
|
copy(init.MachinePub(), s.EncryptAndHash(machineKeyPub[:]))
|
||||||
|
if err := s.MixDH(machineKey, controlKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing ss: %w", err)
|
||||||
|
}
|
||||||
|
copy(init.Tag(), s.EncryptAndHash(nil)) // empty message payload
|
||||||
|
|
||||||
|
if _, err := conn.Write(init[:]); err != nil {
|
||||||
|
return nil, fmt.Errorf("writing initiation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// <- e, ee, se
|
||||||
|
var resp responseMessage
|
||||||
|
if _, err := io.ReadFull(conn, resp[:]); err != nil {
|
||||||
|
return nil, fmt.Errorf("reading response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var controlEphemeralPub key.Public
|
||||||
|
copy(controlEphemeralPub[:], resp.ControlEphemeralPub())
|
||||||
|
s.MixHash(controlEphemeralPub[:])
|
||||||
|
if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing ee: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.MixDH(machineKey, controlEphemeralPub); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing se: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := s.DecryptAndHash(resp.Tag()); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypting payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c1, c2, err := s.Split()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("finalizing handshake: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Conn{
|
||||||
|
conn: conn,
|
||||||
|
peer: controlKey,
|
||||||
|
handshakeHash: s.h,
|
||||||
|
tx: txState{
|
||||||
|
cipher: c1,
|
||||||
|
},
|
||||||
|
rx: rxState{
|
||||||
|
cipher: c2,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server initiates a Noise server handshake, returning the resulting
|
||||||
|
// Noise connection.
|
||||||
|
//
|
||||||
|
// The context deadline, if any, covers the entire handshaking
|
||||||
|
// process.
|
||||||
|
func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn, error) {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
if err := conn.SetDeadline(deadline); err != nil {
|
||||||
|
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
conn.SetDeadline(time.Time{})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
var s symmetricState
|
||||||
|
s.Initialize()
|
||||||
|
|
||||||
|
// <- s
|
||||||
|
// ...
|
||||||
|
controlKeyPub := controlKey.Public()
|
||||||
|
s.MixHash(controlKeyPub[:])
|
||||||
|
|
||||||
|
// -> e, es, s, ss
|
||||||
|
var init initiationMessage
|
||||||
|
if _, err := io.ReadFull(conn, init[:]); err != nil {
|
||||||
|
return nil, fmt.Errorf("reading initiation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var machineEphemeralPub key.Public
|
||||||
|
copy(machineEphemeralPub[:], init.MachineEphemeralPub())
|
||||||
|
s.MixHash(machineEphemeralPub[:])
|
||||||
|
if err := s.MixDH(controlKey, machineEphemeralPub); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing es: %w", err)
|
||||||
|
}
|
||||||
|
var machineKey key.Public
|
||||||
|
rs, err := s.DecryptAndHash(init.MachinePub())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypting machine key: %w", err)
|
||||||
|
}
|
||||||
|
copy(machineKey[:], rs)
|
||||||
|
if err := s.MixDH(controlKey, machineKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing ss: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := s.DecryptAndHash(init.Tag()); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypting initiation tag: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// <- e, ee, se
|
||||||
|
var resp responseMessage
|
||||||
|
controlEphemeral := key.NewPrivate()
|
||||||
|
controlEphemeralPub := controlEphemeral.Public()
|
||||||
|
copy(resp.ControlEphemeralPub(), controlEphemeralPub[:])
|
||||||
|
s.MixHash(controlEphemeralPub[:])
|
||||||
|
if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing ee: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.MixDH(controlEphemeral, machineKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("computing se: %w", err)
|
||||||
|
}
|
||||||
|
copy(resp.Tag(), s.EncryptAndHash(nil)) // empty message payload
|
||||||
|
|
||||||
|
c1, c2, err := s.Split()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("finalizing handshake: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(resp[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Conn{
|
||||||
|
conn: conn,
|
||||||
|
peer: machineKey,
|
||||||
|
handshakeHash: s.h,
|
||||||
|
tx: txState{
|
||||||
|
cipher: c2,
|
||||||
|
},
|
||||||
|
rx: rxState{
|
||||||
|
cipher: c1,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initiationMessage is the Noise protocol message sent from a client
|
||||||
|
// machine to a control server.
|
||||||
|
type initiationMessage [96]byte
|
||||||
|
|
||||||
|
func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] }
|
||||||
|
func (m *initiationMessage) MachinePub() []byte { return m[32:80] }
|
||||||
|
func (m *initiationMessage) Tag() []byte { return m[80:] }
|
||||||
|
|
||||||
|
// responseMessage is the Noise protocol message sent from a control
|
||||||
|
// server to a client machine.
|
||||||
|
type responseMessage [48]byte
|
||||||
|
|
||||||
|
func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] }
|
||||||
|
func (m *responseMessage) Tag() []byte { return m[32:] }
|
||||||
|
|
||||||
|
// symmetricState is the SymmetricState object from the Noise protocol
|
||||||
|
// spec. It contains all the symmetric cipher state of an in-flight
|
||||||
|
// handshake. Field names match the variable names in the spec.
|
||||||
|
type symmetricState struct {
|
||||||
|
h [blake2s.Size]byte
|
||||||
|
ck [blake2s.Size]byte
|
||||||
|
|
||||||
|
k [chp.KeySize]byte
|
||||||
|
n uint64
|
||||||
|
|
||||||
|
mixer hash.Hash // for updating h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize sets s to the initial handshake state, prior to
|
||||||
|
// processing any Noise messages.
|
||||||
|
func (s *symmetricState) Initialize() {
|
||||||
|
if s.mixer != nil {
|
||||||
|
panic("symmetricState cannot be reused")
|
||||||
|
}
|
||||||
|
s.h = blake2s.Sum256([]byte(protocolName))
|
||||||
|
s.ck = s.h
|
||||||
|
s.k = [chp.KeySize]byte{}
|
||||||
|
s.n = invalidNonce
|
||||||
|
s.mixer = newBLAKE2s()
|
||||||
|
// Mix in an empty prologue.
|
||||||
|
s.MixHash(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
|
||||||
|
// concatenation.
|
||||||
|
func (s *symmetricState) MixHash(data []byte) {
|
||||||
|
s.mixer.Reset()
|
||||||
|
s.mixer.Write(s.h[:])
|
||||||
|
s.mixer.Write(data)
|
||||||
|
s.mixer.Sum(s.h[:0]) // TODO: check this actually updates s.h correctly...
|
||||||
|
}
|
||||||
|
|
||||||
|
// MixDH updates s.ck and s.k with the result of X25519(priv, pub).
|
||||||
|
//
|
||||||
|
// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
|
||||||
|
// it as a single function allows for strongly-typed arguments that
|
||||||
|
// reduce the risk of error in the caller (e.g. invoking X25519 with
|
||||||
|
// two private keys, or two public keys), and thus producing the wrong
|
||||||
|
// calculation.
|
||||||
|
func (s *symmetricState) MixDH(priv key.Private, pub key.Public) error {
|
||||||
|
// TODO(danderson): check that this operation is correct. The docs
|
||||||
|
// for X25519 say that the 2nd arg must be either Basepoint or the
|
||||||
|
// output of another X25519 call.
|
||||||
|
//
|
||||||
|
// I think this is correct, because pub is the result of a
|
||||||
|
// ScalarBaseMult on the private key, and our private key
|
||||||
|
// generation code clamps keys to avoid low order points. I
|
||||||
|
// believe that makes pub equivalent to the output of
|
||||||
|
// X25519(privateKey, Basepoint), and so the contract is
|
||||||
|
// respected.
|
||||||
|
keyData, err := curve25519.X25519(priv[:], pub[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("computing X25519: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
|
||||||
|
if _, err := io.ReadFull(r, s.ck[:]); err != nil {
|
||||||
|
return fmt.Errorf("extracting ck: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(r, s.k[:]); err != nil {
|
||||||
|
return fmt.Errorf("extracting k: %w", err)
|
||||||
|
}
|
||||||
|
s.n = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncryptAndHash encrypts the given plaintext using the current s.k,
|
||||||
|
// mixes the ciphertext into s.h, and returns the ciphertext.
|
||||||
|
func (s *symmetricState) EncryptAndHash(plaintext []byte) []byte {
|
||||||
|
if s.n == invalidNonce {
|
||||||
|
// Noise in general permits writing "ciphertext" without a
|
||||||
|
// key, but in IK it cannot happen.
|
||||||
|
panic("attempted encryption with uninitialized key")
|
||||||
|
}
|
||||||
|
aead := newCHP(s.k)
|
||||||
|
var nonce [chp.NonceSize]byte
|
||||||
|
binary.BigEndian.PutUint64(nonce[4:], s.n)
|
||||||
|
s.n++
|
||||||
|
ret := aead.Seal(nil, nonce[:], plaintext, s.h[:])
|
||||||
|
s.MixHash(ret)
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptAndHash decrypts the given ciphertext using the current
|
||||||
|
// s.k. If decryption is successful, it mixes the ciphertext into s.h
|
||||||
|
// and returns the plaintext.
|
||||||
|
func (s *symmetricState) DecryptAndHash(ciphertext []byte) ([]byte, error) {
|
||||||
|
if s.n == invalidNonce {
|
||||||
|
// Noise in general permits "ciphertext" without a key, but in
|
||||||
|
// IK it cannot happen.
|
||||||
|
panic("attempted encryption with uninitialized key")
|
||||||
|
}
|
||||||
|
aead := newCHP(s.k)
|
||||||
|
var nonce [chp.NonceSize]byte
|
||||||
|
binary.BigEndian.PutUint64(nonce[4:], s.n)
|
||||||
|
s.n++
|
||||||
|
ret, err := aead.Open(nil, nonce[:], ciphertext, s.h[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.MixHash(ciphertext)
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split returns two ChaCha20Poly1305 ciphers with keys derives from
|
||||||
|
// the current handshake state. Methods on s must not be used again
|
||||||
|
// after calling Split().
|
||||||
|
func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
|
||||||
|
var k1, k2 [chp.KeySize]byte
|
||||||
|
r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
|
||||||
|
if _, err := io.ReadFull(r, k1[:]); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("extracting k1: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(r, k2[:]); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("extracting k2: %w", err)
|
||||||
|
}
|
||||||
|
c1, err = chp.New(k1[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
|
||||||
|
}
|
||||||
|
c2, err = chp.New(k2[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
|
||||||
|
}
|
||||||
|
return c1, c2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
|
||||||
|
// error.
|
||||||
|
func newBLAKE2s() hash.Hash {
|
||||||
|
h, err := blake2s.New256(nil)
|
||||||
|
if err != nil {
|
||||||
|
// Should never happen, errors only happen when using BLAKE2s
|
||||||
|
// in MAC mode with a key.
|
||||||
|
panic(fmt.Sprintf("blake2s construction: %v", err))
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
|
||||||
|
// panics on error.
|
||||||
|
func newCHP(key [chp.KeySize]byte) cipher.AEAD {
|
||||||
|
aead, err := chp.New(key[:])
|
||||||
|
if err != nil {
|
||||||
|
// Can only happen if we passed a key of the wrong length. The
|
||||||
|
// function signature prevents that.
|
||||||
|
panic(fmt.Sprintf("chacha20poly1305 construction: %v", err))
|
||||||
|
}
|
||||||
|
return aead
|
||||||
|
}
|
@ -0,0 +1,290 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tsnettest "tailscale.com/net/nettest"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandshake(t *testing.T) {
|
||||||
|
var (
|
||||||
|
clientConn, serverConn = tsnettest.NewConn("noise", 128000)
|
||||||
|
serverKey = key.NewPrivate()
|
||||||
|
clientKey = key.NewPrivate()
|
||||||
|
server *Conn
|
||||||
|
serverErr = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
server, err = Server(context.Background(), serverConn, serverKey)
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client connection failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server connection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.HandshakeHash() != server.HandshakeHash() {
|
||||||
|
t.Fatal("client and server disagree on handshake hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.Peer() != serverKey.Public() {
|
||||||
|
t.Fatal("client peer key isn't serverKey")
|
||||||
|
}
|
||||||
|
if server.Peer() != clientKey.Public() {
|
||||||
|
t.Fatal("client peer key isn't serverKey")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that handshaking repeatedly with the same long-term keys
|
||||||
|
// result in different handshake hashes and wire traffic.
|
||||||
|
func TestNoReuse(t *testing.T) {
|
||||||
|
var (
|
||||||
|
hashes = map[[32]byte]bool{}
|
||||||
|
clientHandshakes = map[[96]byte]bool{}
|
||||||
|
serverHandshakes = map[[48]byte]bool{}
|
||||||
|
packets = map[[32]byte]bool{}
|
||||||
|
)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
var (
|
||||||
|
clientRaw, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||||
|
clientBuf, serverBuf bytes.Buffer
|
||||||
|
clientConn = &readerConn{clientRaw, io.TeeReader(clientRaw, &clientBuf)}
|
||||||
|
serverConn = &readerConn{serverRaw, io.TeeReader(serverRaw, &serverBuf)}
|
||||||
|
serverKey = key.NewPrivate()
|
||||||
|
clientKey = key.NewPrivate()
|
||||||
|
server *Conn
|
||||||
|
serverErr = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
server, err = Server(context.Background(), serverConn, serverKey)
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client connection failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server connection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var clientHS [96]byte
|
||||||
|
copy(clientHS[:], serverBuf.Bytes())
|
||||||
|
if clientHandshakes[clientHS] {
|
||||||
|
t.Fatal("client handshake seen twice")
|
||||||
|
}
|
||||||
|
clientHandshakes[clientHS] = true
|
||||||
|
|
||||||
|
var serverHS [48]byte
|
||||||
|
copy(serverHS[:], clientBuf.Bytes())
|
||||||
|
if serverHandshakes[serverHS] {
|
||||||
|
t.Fatal("server handshake seen twice")
|
||||||
|
}
|
||||||
|
serverHandshakes[serverHS] = true
|
||||||
|
|
||||||
|
clientBuf.Reset()
|
||||||
|
serverBuf.Reset()
|
||||||
|
cb := sinkReads(client)
|
||||||
|
sb := sinkReads(server)
|
||||||
|
|
||||||
|
if hashes[client.HandshakeHash()] {
|
||||||
|
t.Fatalf("handshake hash %v seen twice", client.HandshakeHash())
|
||||||
|
}
|
||||||
|
hashes[client.HandshakeHash()] = true
|
||||||
|
|
||||||
|
// Sending 14 bytes turns into 32 bytes on the wire (+16 for
|
||||||
|
// the poly1305 tag, +2 length header)
|
||||||
|
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
|
||||||
|
t.Fatalf("client>server write failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(server, strings.Repeat("b", 14)); err != nil {
|
||||||
|
t.Fatalf("server>client write failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the bytes to be read, so we know they've traveled end to end
|
||||||
|
cb.String(14)
|
||||||
|
sb.String(14)
|
||||||
|
|
||||||
|
var clientWire, serverWire [32]byte
|
||||||
|
copy(clientWire[:], clientBuf.Bytes())
|
||||||
|
copy(serverWire[:], serverBuf.Bytes())
|
||||||
|
|
||||||
|
if packets[clientWire] {
|
||||||
|
t.Fatalf("client wire traffic seen twice")
|
||||||
|
}
|
||||||
|
packets[clientWire] = true
|
||||||
|
if packets[serverWire] {
|
||||||
|
t.Fatalf("server wire traffic seen twice")
|
||||||
|
}
|
||||||
|
packets[serverWire] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tamperReader wraps a reader and mutates the Nth byte.
|
||||||
|
type tamperReader struct {
|
||||||
|
r io.Reader
|
||||||
|
n int
|
||||||
|
total int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tamperReader) Read(bs []byte) (int, error) {
|
||||||
|
n, err := r.r.Read(bs)
|
||||||
|
if off := r.n - r.total; off >= 0 && off < n {
|
||||||
|
bs[off] += 1
|
||||||
|
}
|
||||||
|
r.total += n
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTampering(t *testing.T) {
|
||||||
|
// Tamper with every byte of the client initiation message.
|
||||||
|
for i := 0; i < 96; i++ {
|
||||||
|
var (
|
||||||
|
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||||
|
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
|
||||||
|
serverKey = key.NewPrivate()
|
||||||
|
clientKey = key.NewPrivate()
|
||||||
|
serverErr = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
_, err := Server(context.Background(), serverConn, serverKey)
|
||||||
|
// If the server failed, we have to close the Conn to
|
||||||
|
// unblock the client.
|
||||||
|
if err != nil {
|
||||||
|
serverConn.Close()
|
||||||
|
}
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("client connection succeeded despite tampering")
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err == nil {
|
||||||
|
t.Fatalf("server connection succeeded despite tampering")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tamper with every byte of the server response message.
|
||||||
|
for i := 0; i < 48; i++ {
|
||||||
|
var (
|
||||||
|
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
||||||
|
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
|
||||||
|
serverKey = key.NewPrivate()
|
||||||
|
clientKey = key.NewPrivate()
|
||||||
|
serverErr = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
_, err := Server(context.Background(), serverConn, serverKey)
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("client connection succeeded despite tampering")
|
||||||
|
}
|
||||||
|
// The server shouldn't fail, because the tampering took place
|
||||||
|
// in its response.
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server connection failed despite no tampering: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tamper with every byte of the first server>client transport message.
|
||||||
|
for i := 0; i < 32; i++ {
|
||||||
|
var (
|
||||||
|
clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
|
||||||
|
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}}
|
||||||
|
serverKey = key.NewPrivate()
|
||||||
|
clientKey = key.NewPrivate()
|
||||||
|
serverErr = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
server, err := Server(context.Background(), serverConn, serverKey)
|
||||||
|
serverErr <- err
|
||||||
|
_, err = io.WriteString(server, strings.Repeat("a", 14))
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client handshake failed: %v", err)
|
||||||
|
}
|
||||||
|
// The server shouldn't fail, because the tampering took place
|
||||||
|
// in its response.
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server handshake failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The client needs a timeout if the tampering is hitting the length header.
|
||||||
|
if i == 0 || i == 1 {
|
||||||
|
client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
|
}
|
||||||
|
|
||||||
|
var bs [100]byte
|
||||||
|
n, err := client.Read(bs[:])
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("read succeeded despite tampering")
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Fatal("conn yielded some bytes despite tampering")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tamper with every byte of the first client>server transport message.
|
||||||
|
for i := 0; i < 32; i++ {
|
||||||
|
var (
|
||||||
|
clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
|
||||||
|
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}}
|
||||||
|
serverKey = key.NewPrivate()
|
||||||
|
clientKey = key.NewPrivate()
|
||||||
|
serverErr = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
server, err := Server(context.Background(), serverConn, serverKey)
|
||||||
|
serverErr <- err
|
||||||
|
var bs [100]byte
|
||||||
|
// The server needs a timeout if the tampering is hitting the length header.
|
||||||
|
if i == 0 || i == 1 {
|
||||||
|
server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
|
}
|
||||||
|
n, err := server.Read(bs[:])
|
||||||
|
if n != 0 {
|
||||||
|
panic("server got bytes despite tampering")
|
||||||
|
} else {
|
||||||
|
serverErr <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client handshake failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server handshake failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.WriteString(client, strings.Repeat("a", 14)); err != nil {
|
||||||
|
t.Fatalf("client>server write failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err == nil {
|
||||||
|
t.Fatal("server successfully received bytes despite tampering")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,238 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
tsnettest "tailscale.com/net/nettest"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Can a reference Noise IK client talk to our server?
|
||||||
|
func TestInteropClient(t *testing.T) {
|
||||||
|
var (
|
||||||
|
s1, s2 = tsnettest.NewConn("noise", 128000)
|
||||||
|
controlKey = key.NewPrivate()
|
||||||
|
machineKey = key.NewPrivate()
|
||||||
|
serverErr = make(chan error, 2)
|
||||||
|
serverBytes = make(chan []byte, 1)
|
||||||
|
c2s = "client>server"
|
||||||
|
s2c = "server>client"
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
server, err := Server(context.Background(), s2, controlKey)
|
||||||
|
serverErr <- err
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var buf [1024]byte
|
||||||
|
_, err = io.ReadFull(server, buf[:len(c2s)])
|
||||||
|
serverBytes <- buf[:len(c2s)]
|
||||||
|
if err != nil {
|
||||||
|
serverErr <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = server.Write([]byte(s2c))
|
||||||
|
serverErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed client interop: %v", err)
|
||||||
|
}
|
||||||
|
if string(gotS2C) != s2c {
|
||||||
|
t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server handshake failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-serverErr; err != nil {
|
||||||
|
t.Fatalf("server read/write failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := string(<-serverBytes); got != c2s {
|
||||||
|
t.Fatalf("server received %q, want %q", got, c2s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can our client talk to a reference Noise IK server?
|
||||||
|
func TestInteropServer(t *testing.T) {
|
||||||
|
var (
|
||||||
|
s1, s2 = tsnettest.NewConn("noise", 128000)
|
||||||
|
controlKey = key.NewPrivate()
|
||||||
|
machineKey = key.NewPrivate()
|
||||||
|
clientErr = make(chan error, 2)
|
||||||
|
clientBytes = make(chan []byte, 1)
|
||||||
|
c2s = "client>server"
|
||||||
|
s2c = "server>client"
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
client, err := Client(context.Background(), s1, machineKey, controlKey.Public())
|
||||||
|
clientErr <- err
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = client.Write([]byte(c2s))
|
||||||
|
if err != nil {
|
||||||
|
clientErr <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var buf [1024]byte
|
||||||
|
_, err = io.ReadFull(client, buf[:len(s2c)])
|
||||||
|
clientBytes <- buf[:len(s2c)]
|
||||||
|
clientErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed server interop: %v", err)
|
||||||
|
}
|
||||||
|
if string(gotC2S) != c2s {
|
||||||
|
t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := <-clientErr; err != nil {
|
||||||
|
t.Fatalf("client handshake failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-clientErr; err != nil {
|
||||||
|
t.Fatalf("client read/write failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := string(<-clientBytes); got != s2c {
|
||||||
|
t.Fatalf("client received %q, want %q", got, s2c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// noiseExplorerClient uses the Noise Explorer implementation of Noise
|
||||||
|
// IK to handshake as a Noise client on conn, transmit payload, and
|
||||||
|
// read+return a payload from the peer.
|
||||||
|
func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Private, payload []byte) ([]byte, error) {
|
||||||
|
mk := keypair{
|
||||||
|
private_key: machineKey,
|
||||||
|
public_key: machineKey.Public(),
|
||||||
|
}
|
||||||
|
session := InitSession(true, nil, mk, controlKey)
|
||||||
|
|
||||||
|
_, msg1 := SendMessage(&session, nil)
|
||||||
|
if _, err := conn.Write(msg1.ne[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := conn.Write(msg1.ns); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := conn.Write(msg1.ciphertext); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf [1024]byte
|
||||||
|
if _, err := io.ReadFull(conn, buf[:48]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg2 := messagebuffer{
|
||||||
|
ciphertext: buf[32:48],
|
||||||
|
}
|
||||||
|
copy(msg2.ne[:], buf[:32])
|
||||||
|
_, p, valid := RecvMessage(&session, &msg2)
|
||||||
|
if !valid {
|
||||||
|
return nil, errors.New("handshake failed")
|
||||||
|
}
|
||||||
|
if len(p) != 0 {
|
||||||
|
return nil, errors.New("non-empty payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, msg3 := SendMessage(&session, payload)
|
||||||
|
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext)))
|
||||||
|
if _, err := conn.Write(buf[:2]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := conn.Write(msg3.ciphertext); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
plen := int(binary.BigEndian.Uint16(buf[:2]))
|
||||||
|
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg4 := messagebuffer{
|
||||||
|
ciphertext: buf[:plen],
|
||||||
|
}
|
||||||
|
_, p, valid = RecvMessage(&session, &msg4)
|
||||||
|
if !valid {
|
||||||
|
return nil, errors.New("transport message decryption failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey key.Public, payload []byte) ([]byte, error) {
|
||||||
|
mk := keypair{
|
||||||
|
private_key: controlKey,
|
||||||
|
public_key: controlKey.Public(),
|
||||||
|
}
|
||||||
|
session := InitSession(false, nil, mk, [32]byte{})
|
||||||
|
|
||||||
|
var buf [1024]byte
|
||||||
|
if _, err := io.ReadFull(conn, buf[:96]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg1 := messagebuffer{
|
||||||
|
ns: buf[32:80],
|
||||||
|
ciphertext: buf[80:96],
|
||||||
|
}
|
||||||
|
copy(msg1.ne[:], buf[:32])
|
||||||
|
_, p, valid := RecvMessage(&session, &msg1)
|
||||||
|
if !valid {
|
||||||
|
return nil, errors.New("handshake failed")
|
||||||
|
}
|
||||||
|
if len(p) != 0 {
|
||||||
|
return nil, errors.New("non-empty payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, msg2 := SendMessage(&session, nil)
|
||||||
|
if _, err := conn.Write(msg2.ne[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := conn.Write(msg2.ciphertext[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
plen := int(binary.BigEndian.Uint16(buf[:2]))
|
||||||
|
if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg3 := messagebuffer{
|
||||||
|
ciphertext: buf[:plen],
|
||||||
|
}
|
||||||
|
_, p, valid = RecvMessage(&session, &msg3)
|
||||||
|
if !valid {
|
||||||
|
return nil, errors.New("transport message decryption failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, msg4 := SendMessage(&session, payload)
|
||||||
|
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext)))
|
||||||
|
if _, err := conn.Write(buf[:2]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := conn.Write(msg4.ciphertext); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
@ -0,0 +1,475 @@
|
|||||||
|
// This file contains the implementation of Noise IK from
|
||||||
|
// https://noiseexplorer.com/ . Unlike the rest of this repository,
|
||||||
|
// this file is licensed under the terms of the GNU GPL v3. See
|
||||||
|
// https://source.symbolic.software/noiseexplorer/noiseexplorer for
|
||||||
|
// more information.
|
||||||
|
//
|
||||||
|
// This file is used here to verify that Tailscale's implementation of
|
||||||
|
// Noise IK is interoperable with another implementation.
|
||||||
|
//lint:file-ignore SA4006 not our code.
|
||||||
|
|
||||||
|
/*
|
||||||
|
IK:
|
||||||
|
<- s
|
||||||
|
...
|
||||||
|
-> e, es, s, ss
|
||||||
|
<- e, ee, se
|
||||||
|
->
|
||||||
|
<-
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Implementation Version: 1.0.2
|
||||||
|
|
||||||
|
/* ---------------------------------------------------------------- *
|
||||||
|
* PARAMETERS *
|
||||||
|
* ---------------------------------------------------------------- */
|
||||||
|
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/binary"
|
||||||
|
"hash"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* ---------------------------------------------------------------- *
|
||||||
|
* TYPES *
|
||||||
|
* ---------------------------------------------------------------- */
|
||||||
|
|
||||||
|
type keypair struct {
|
||||||
|
public_key [32]byte
|
||||||
|
private_key [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type messagebuffer struct {
|
||||||
|
ne [32]byte
|
||||||
|
ns []byte
|
||||||
|
ciphertext []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type cipherstate struct {
|
||||||
|
k [32]byte
|
||||||
|
n uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type symmetricstate struct {
|
||||||
|
cs cipherstate
|
||||||
|
ck [32]byte
|
||||||
|
h [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshakestate struct {
|
||||||
|
ss symmetricstate
|
||||||
|
s keypair
|
||||||
|
e keypair
|
||||||
|
rs [32]byte
|
||||||
|
re [32]byte
|
||||||
|
psk [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type noisesession struct {
|
||||||
|
hs handshakestate
|
||||||
|
h [32]byte
|
||||||
|
cs1 cipherstate
|
||||||
|
cs2 cipherstate
|
||||||
|
mc uint64
|
||||||
|
i bool
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------------------------------------------------------------- *
|
||||||
|
* CONSTANTS *
|
||||||
|
* ---------------------------------------------------------------- */
|
||||||
|
|
||||||
|
var emptyKey = [32]byte{
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
}
|
||||||
|
|
||||||
|
var minNonce = uint32(0)
|
||||||
|
|
||||||
|
/* ---------------------------------------------------------------- *
|
||||||
|
* UTILITY FUNCTIONS *
|
||||||
|
* ---------------------------------------------------------------- */
|
||||||
|
|
||||||
|
func getPublicKey(kp *keypair) [32]byte {
|
||||||
|
return kp.public_key
|
||||||
|
}
|
||||||
|
|
||||||
|
func isEmptyKey(k [32]byte) bool {
|
||||||
|
return subtle.ConstantTimeCompare(k[:], emptyKey[:]) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePublicKey(k []byte) bool {
|
||||||
|
forbiddenCurveValues := [12][]byte{
|
||||||
|
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
{224, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 0},
|
||||||
|
{95, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 87},
|
||||||
|
{236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
|
||||||
|
{237, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
|
||||||
|
{238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127},
|
||||||
|
{205, 235, 122, 124, 59, 65, 184, 174, 22, 86, 227, 250, 241, 159, 196, 106, 218, 9, 141, 235, 156, 50, 177, 253, 134, 98, 5, 22, 95, 73, 184, 128},
|
||||||
|
{76, 156, 149, 188, 163, 80, 140, 36, 177, 208, 177, 85, 156, 131, 239, 91, 4, 68, 92, 196, 88, 28, 142, 134, 216, 34, 78, 221, 208, 159, 17, 215},
|
||||||
|
{217, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
|
||||||
|
{218, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
|
||||||
|
{219, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 25},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testValue := range forbiddenCurveValues {
|
||||||
|
if subtle.ConstantTimeCompare(k[:], testValue[:]) == 1 {
|
||||||
|
panic("Invalid public key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------------------------------------------------------------- *
|
||||||
|
* PRIMITIVES *
|
||||||
|
* ---------------------------------------------------------------- */
|
||||||
|
|
||||||
|
func incrementNonce(n uint32) uint32 {
|
||||||
|
return n + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func dh(private_key [32]byte, public_key [32]byte) [32]byte {
|
||||||
|
var ss [32]byte
|
||||||
|
curve25519.ScalarMult(&ss, &private_key, &public_key)
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateKeypair() keypair {
|
||||||
|
var public_key [32]byte
|
||||||
|
var private_key [32]byte
|
||||||
|
_, _ = rand.Read(private_key[:])
|
||||||
|
curve25519.ScalarBaseMult(&public_key, &private_key)
|
||||||
|
if validatePublicKey(public_key[:]) {
|
||||||
|
return keypair{public_key, private_key}
|
||||||
|
}
|
||||||
|
return generateKeypair()
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatePublicKey(private_key [32]byte) [32]byte {
|
||||||
|
var public_key [32]byte
|
||||||
|
curve25519.ScalarBaseMult(&public_key, &private_key)
|
||||||
|
return public_key
|
||||||
|
}
|
||||||
|
|
||||||
|
func encrypt(k [32]byte, n uint32, ad []byte, plaintext []byte) []byte {
|
||||||
|
var nonce [12]byte
|
||||||
|
var ciphertext []byte
|
||||||
|
enc, _ := chacha20poly1305.New(k[:])
|
||||||
|
binary.LittleEndian.PutUint32(nonce[4:], n)
|
||||||
|
ciphertext = enc.Seal(nil, nonce[:], plaintext, ad)
|
||||||
|
return ciphertext
|
||||||
|
}
|
||||||
|
|
||||||
|
func decrypt(k [32]byte, n uint32, ad []byte, ciphertext []byte) (bool, []byte, []byte) {
|
||||||
|
var nonce [12]byte
|
||||||
|
var plaintext []byte
|
||||||
|
enc, err := chacha20poly1305.New(k[:])
|
||||||
|
binary.LittleEndian.PutUint32(nonce[4:], n)
|
||||||
|
plaintext, err = enc.Open(nil, nonce[:], ciphertext, ad)
|
||||||
|
return (err == nil), ad, plaintext
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHash(a []byte, b []byte) [32]byte {
|
||||||
|
return blake2s.Sum256(append(a, b...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func hashProtocolName(protocolName []byte) [32]byte {
|
||||||
|
var h [32]byte
|
||||||
|
if len(protocolName) <= 32 {
|
||||||
|
copy(h[:], protocolName)
|
||||||
|
} else {
|
||||||
|
h = getHash(protocolName, []byte{})
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func blake2HkdfInterface() hash.Hash {
|
||||||
|
h, _ := blake2s.New256([]byte{})
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHkdf(ck [32]byte, ikm []byte) ([32]byte, [32]byte, [32]byte) {
|
||||||
|
var k1 [32]byte
|
||||||
|
var k2 [32]byte
|
||||||
|
var k3 [32]byte
|
||||||
|
output := hkdf.New(blake2HkdfInterface, ikm[:], ck[:], []byte{})
|
||||||
|
io.ReadFull(output, k1[:])
|
||||||
|
io.ReadFull(output, k2[:])
|
||||||
|
io.ReadFull(output, k3[:])
|
||||||
|
return k1, k2, k3
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------------------------------------------------------------- *
|
||||||
|
* STATE MANAGEMENT *
|
||||||
|
* ---------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/* CipherState */
|
||||||
|
func initializeKey(k [32]byte) cipherstate {
|
||||||
|
return cipherstate{k, minNonce}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasKey(cs *cipherstate) bool {
|
||||||
|
return !isEmptyKey(cs.k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setNonce(cs *cipherstate, newNonce uint32) *cipherstate {
|
||||||
|
cs.n = newNonce
|
||||||
|
return cs
|
||||||
|
}
|
||||||
|
|
||||||
|
func encryptWithAd(cs *cipherstate, ad []byte, plaintext []byte) (*cipherstate, []byte) {
|
||||||
|
e := encrypt(cs.k, cs.n, ad, plaintext)
|
||||||
|
cs = setNonce(cs, incrementNonce(cs.n))
|
||||||
|
return cs, e
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptWithAd(cs *cipherstate, ad []byte, ciphertext []byte) (*cipherstate, []byte, bool) {
|
||||||
|
valid, ad, plaintext := decrypt(cs.k, cs.n, ad, ciphertext)
|
||||||
|
cs = setNonce(cs, incrementNonce(cs.n))
|
||||||
|
return cs, plaintext, valid
|
||||||
|
}
|
||||||
|
|
||||||
|
func reKey(cs *cipherstate) *cipherstate {
|
||||||
|
e := encrypt(cs.k, math.MaxUint32, []byte{}, emptyKey[:])
|
||||||
|
copy(cs.k[:], e)
|
||||||
|
return cs
|
||||||
|
}
|
||||||
|
|
||||||
|
/* SymmetricState */
|
||||||
|
|
||||||
|
func initializeSymmetric(protocolName []byte) symmetricstate {
|
||||||
|
h := hashProtocolName(protocolName)
|
||||||
|
ck := h
|
||||||
|
cs := initializeKey(emptyKey)
|
||||||
|
return symmetricstate{cs, ck, h}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mixKey(ss *symmetricstate, ikm [32]byte) *symmetricstate {
|
||||||
|
ck, tempK, _ := getHkdf(ss.ck, ikm[:])
|
||||||
|
ss.cs = initializeKey(tempK)
|
||||||
|
ss.ck = ck
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
func mixHash(ss *symmetricstate, data []byte) *symmetricstate {
|
||||||
|
ss.h = getHash(ss.h[:], data)
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
func mixKeyAndHash(ss *symmetricstate, ikm [32]byte) *symmetricstate {
|
||||||
|
var tempH [32]byte
|
||||||
|
var tempK [32]byte
|
||||||
|
ss.ck, tempH, tempK = getHkdf(ss.ck, ikm[:])
|
||||||
|
ss = mixHash(ss, tempH[:])
|
||||||
|
ss.cs = initializeKey(tempK)
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHandshakeHash(ss *symmetricstate) [32]byte {
|
||||||
|
return ss.h
|
||||||
|
}
|
||||||
|
|
||||||
|
func encryptAndHash(ss *symmetricstate, plaintext []byte) (*symmetricstate, []byte) {
|
||||||
|
var ciphertext []byte
|
||||||
|
if hasKey(&ss.cs) {
|
||||||
|
_, ciphertext = encryptWithAd(&ss.cs, ss.h[:], plaintext)
|
||||||
|
} else {
|
||||||
|
ciphertext = plaintext
|
||||||
|
}
|
||||||
|
ss = mixHash(ss, ciphertext)
|
||||||
|
return ss, ciphertext
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptAndHash(ss *symmetricstate, ciphertext []byte) (*symmetricstate, []byte, bool) {
|
||||||
|
var plaintext []byte
|
||||||
|
var valid bool
|
||||||
|
if hasKey(&ss.cs) {
|
||||||
|
_, plaintext, valid = decryptWithAd(&ss.cs, ss.h[:], ciphertext)
|
||||||
|
} else {
|
||||||
|
plaintext, valid = ciphertext, true
|
||||||
|
}
|
||||||
|
ss = mixHash(ss, ciphertext)
|
||||||
|
return ss, plaintext, valid
|
||||||
|
}
|
||||||
|
|
||||||
|
func split(ss *symmetricstate) (cipherstate, cipherstate) {
|
||||||
|
tempK1, tempK2, _ := getHkdf(ss.ck, []byte{})
|
||||||
|
cs1 := initializeKey(tempK1)
|
||||||
|
cs2 := initializeKey(tempK2)
|
||||||
|
return cs1, cs2
|
||||||
|
}
|
||||||
|
|
||||||
|
/* HandshakeState */
|
||||||
|
|
||||||
|
func initializeInitiator(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
|
||||||
|
var ss symmetricstate
|
||||||
|
var e keypair
|
||||||
|
var re [32]byte
|
||||||
|
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
|
||||||
|
ss = initializeSymmetric(name)
|
||||||
|
mixHash(&ss, prologue)
|
||||||
|
mixHash(&ss, rs[:])
|
||||||
|
return handshakestate{ss, s, e, rs, re, psk}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initializeResponder(prologue []byte, s keypair, rs [32]byte, psk [32]byte) handshakestate {
|
||||||
|
var ss symmetricstate
|
||||||
|
var e keypair
|
||||||
|
var re [32]byte
|
||||||
|
name := []byte("Noise_IK_25519_ChaChaPoly_BLAKE2s")
|
||||||
|
ss = initializeSymmetric(name)
|
||||||
|
mixHash(&ss, prologue)
|
||||||
|
mixHash(&ss, s.public_key[:])
|
||||||
|
return handshakestate{ss, s, e, rs, re, psk}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMessageA(hs *handshakestate, payload []byte) (*handshakestate, messagebuffer) {
|
||||||
|
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
|
||||||
|
hs.e = generateKeypair()
|
||||||
|
ne = hs.e.public_key
|
||||||
|
mixHash(&hs.ss, ne[:])
|
||||||
|
/* No PSK, so skipping mixKey */
|
||||||
|
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
|
||||||
|
spk := make([]byte, len(hs.s.public_key))
|
||||||
|
copy(spk[:], hs.s.public_key[:])
|
||||||
|
_, ns = encryptAndHash(&hs.ss, spk)
|
||||||
|
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
|
||||||
|
_, ciphertext = encryptAndHash(&hs.ss, payload)
|
||||||
|
messageBuffer := messagebuffer{ne, ns, ciphertext}
|
||||||
|
return hs, messageBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMessageB(hs *handshakestate, payload []byte) ([32]byte, messagebuffer, cipherstate, cipherstate) {
|
||||||
|
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
|
||||||
|
hs.e = generateKeypair()
|
||||||
|
ne = hs.e.public_key
|
||||||
|
mixHash(&hs.ss, ne[:])
|
||||||
|
/* No PSK, so skipping mixKey */
|
||||||
|
mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
|
||||||
|
mixKey(&hs.ss, dh(hs.e.private_key, hs.rs))
|
||||||
|
_, ciphertext = encryptAndHash(&hs.ss, payload)
|
||||||
|
messageBuffer := messagebuffer{ne, ns, ciphertext}
|
||||||
|
cs1, cs2 := split(&hs.ss)
|
||||||
|
return hs.ss.h, messageBuffer, cs1, cs2
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeMessageRegular(cs *cipherstate, payload []byte) (*cipherstate, messagebuffer) {
|
||||||
|
ne, ns, ciphertext := emptyKey, []byte{}, []byte{}
|
||||||
|
cs, ciphertext = encryptWithAd(cs, []byte{}, payload)
|
||||||
|
messageBuffer := messagebuffer{ne, ns, ciphertext}
|
||||||
|
return cs, messageBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMessageA(hs *handshakestate, message *messagebuffer) (*handshakestate, []byte, bool) {
|
||||||
|
valid1 := true
|
||||||
|
if validatePublicKey(message.ne[:]) {
|
||||||
|
hs.re = message.ne
|
||||||
|
}
|
||||||
|
mixHash(&hs.ss, hs.re[:])
|
||||||
|
/* No PSK, so skipping mixKey */
|
||||||
|
mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
|
||||||
|
_, ns, valid1 := decryptAndHash(&hs.ss, message.ns)
|
||||||
|
if valid1 && len(ns) == 32 && validatePublicKey(message.ns[:]) {
|
||||||
|
copy(hs.rs[:], ns)
|
||||||
|
}
|
||||||
|
mixKey(&hs.ss, dh(hs.s.private_key, hs.rs))
|
||||||
|
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
|
||||||
|
return hs, plaintext, (valid1 && valid2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMessageB(hs *handshakestate, message *messagebuffer) ([32]byte, []byte, bool, cipherstate, cipherstate) {
|
||||||
|
valid1 := true
|
||||||
|
if validatePublicKey(message.ne[:]) {
|
||||||
|
hs.re = message.ne
|
||||||
|
}
|
||||||
|
mixHash(&hs.ss, hs.re[:])
|
||||||
|
/* No PSK, so skipping mixKey */
|
||||||
|
mixKey(&hs.ss, dh(hs.e.private_key, hs.re))
|
||||||
|
mixKey(&hs.ss, dh(hs.s.private_key, hs.re))
|
||||||
|
_, plaintext, valid2 := decryptAndHash(&hs.ss, message.ciphertext)
|
||||||
|
cs1, cs2 := split(&hs.ss)
|
||||||
|
return hs.ss.h, plaintext, (valid1 && valid2), cs1, cs2
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMessageRegular(cs *cipherstate, message *messagebuffer) (*cipherstate, []byte, bool) {
|
||||||
|
/* No encrypted keys */
|
||||||
|
_, plaintext, valid2 := decryptWithAd(cs, []byte{}, message.ciphertext)
|
||||||
|
return cs, plaintext, valid2
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------------------------------------------------------------- *
|
||||||
|
* PROCESSES *
|
||||||
|
* ---------------------------------------------------------------- */
|
||||||
|
|
||||||
|
func InitSession(initiator bool, prologue []byte, s keypair, rs [32]byte) noisesession {
|
||||||
|
var session noisesession
|
||||||
|
psk := emptyKey
|
||||||
|
if initiator {
|
||||||
|
session.hs = initializeInitiator(prologue, s, rs, psk)
|
||||||
|
} else {
|
||||||
|
session.hs = initializeResponder(prologue, s, rs, psk)
|
||||||
|
}
|
||||||
|
session.i = initiator
|
||||||
|
session.mc = 0
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendMessage(session *noisesession, message []byte) (*noisesession, messagebuffer) {
|
||||||
|
var messageBuffer messagebuffer
|
||||||
|
if session.mc == 0 {
|
||||||
|
_, messageBuffer = writeMessageA(&session.hs, message)
|
||||||
|
}
|
||||||
|
if session.mc == 1 {
|
||||||
|
session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message)
|
||||||
|
session.hs = handshakestate{}
|
||||||
|
}
|
||||||
|
if session.mc > 1 {
|
||||||
|
if session.i {
|
||||||
|
_, messageBuffer = writeMessageRegular(&session.cs1, message)
|
||||||
|
} else {
|
||||||
|
_, messageBuffer = writeMessageRegular(&session.cs2, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
session.mc = session.mc + 1
|
||||||
|
return session, messageBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func RecvMessage(session *noisesession, message *messagebuffer) (*noisesession, []byte, bool) {
|
||||||
|
var plaintext []byte
|
||||||
|
var valid bool
|
||||||
|
if session.mc == 0 {
|
||||||
|
_, plaintext, valid = readMessageA(&session.hs, message)
|
||||||
|
}
|
||||||
|
if session.mc == 1 {
|
||||||
|
session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message)
|
||||||
|
session.hs = handshakestate{}
|
||||||
|
}
|
||||||
|
if session.mc > 1 {
|
||||||
|
if session.i {
|
||||||
|
_, plaintext, valid = readMessageRegular(&session.cs2, message)
|
||||||
|
} else {
|
||||||
|
_, plaintext, valid = readMessageRegular(&session.cs1, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
session.mc = session.mc + 1
|
||||||
|
return session, plaintext, valid
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {}
|
Loading…
Reference in New Issue