control/noise: adjust implementation to match revised spec.

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/3293/head
David Anderson 3 years ago committed by Dave Anderson
parent 89a68a4c22
commit 0b392dbaf7

@ -24,9 +24,9 @@ import (
) )
const ( const (
maxPlaintextSize = 4096 maxMessageSize = 4096
maxCiphertextSize = maxPlaintextSize + poly1305.TagSize maxCiphertextSize = maxMessageSize - headerLen
maxPacketSize = maxCiphertextSize + 2 // ciphertext + length header maxPlaintextSize = maxCiphertextSize - poly1305.TagSize
) )
// A Conn is a secured Noise connection. It implements the net.Conn // A Conn is a secured Noise connection. It implements the net.Conn
@ -35,6 +35,7 @@ const (
// fail. // fail.
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
version int
peer key.Public peer key.Public
handshakeHash [blake2s.Size]byte handshakeHash [blake2s.Size]byte
rx rxState rx rxState
@ -46,7 +47,7 @@ type rxState struct {
sync.Mutex sync.Mutex
cipher cipher.AEAD cipher cipher.AEAD
nonce [chp.NonceSize]byte nonce [chp.NonceSize]byte
buf [maxPacketSize]byte buf [maxMessageSize]byte
n int // number of valid bytes in buf n int // number of valid bytes in buf
next int // offset of next undecrypted packet next int // offset of next undecrypted packet
plaintext []byte // slice into buf of decrypted bytes plaintext []byte // slice into buf of decrypted bytes
@ -57,10 +58,14 @@ type txState struct {
sync.Mutex sync.Mutex
cipher cipher.AEAD cipher cipher.AEAD
nonce [chp.NonceSize]byte nonce [chp.NonceSize]byte
buf [maxPacketSize]byte buf [maxMessageSize]byte
err error // records the first partial write error for all future calls err error // records the first partial write error for all future calls
} }
func (c *Conn) ProtocolVersion() int {
return c.version
}
// HandshakeHash returns the Noise handshake hash for the connection, // HandshakeHash returns the Noise handshake hash for the connection,
// which can be used to bind other messages to this connection // which can be used to bind other messages to this connection
// (i.e. to ensure that the message wasn't replayed from a different // (i.e. to ensure that the message wasn't replayed from a different
@ -84,7 +89,7 @@ func validNonce(nonce []byte) bool {
// bytes. Returns a slice of the available bytes in rxBuf, or an // bytes. Returns a slice of the available bytes in rxBuf, or an
// error if fewer than total bytes are available. // error if fewer than total bytes are available.
func (c *Conn) readNLocked(total int) ([]byte, error) { func (c *Conn) readNLocked(total int) ([]byte, error) {
if total > maxPacketSize { if total > maxMessageSize {
return nil, errReadTooBig{total} return nil, errReadTooBig{total}
} }
for { for {
@ -100,10 +105,20 @@ func (c *Conn) readNLocked(total int) ([]byte, error) {
} }
} }
// decryptLocked decrypts ciphertext in-place and sets c.rx.plaintext // decryptLocked decrypts message (which is header+ciphertext)
// to the decrypted bytes. Returns an error if the cipher is exhausted // in-place and sets c.rx.plaintext to the decrypted bytes. Returns an
// (i.e. can no longer be used safely) or decryption fails. // error if the cipher is exhausted (i.e. can no longer be used
func (c *Conn) decryptLocked(ciphertext []byte) (err error) { // safely) or decryption fails.
func (c *Conn) decryptLocked(msg []byte) (err error) {
if hdrVersion(msg) != c.version {
return fmt.Errorf("received message with unexpected protocol version %d, want %d", hdrVersion(msg), c.version)
}
if hdrType(msg) != msgTypeRecord {
return fmt.Errorf("received message with unexpected type %d, want %d", hdrType(msg), msgTypeRecord)
}
// length was already handled in caller to size msg.
ciphertext := msg[headerLen:]
if !validNonce(c.rx.nonce[:]) { if !validNonce(c.rx.nonce[:]) {
return errCipherExhausted{} return errCipherExhausted{}
} }
@ -124,8 +139,8 @@ func (c *Conn) decryptLocked(ciphertext []byte) (err error) {
} }
// encryptLocked encrypts plaintext into c.tx.buf (including the // encryptLocked encrypts plaintext into c.tx.buf (including the
// 2-byte length header) and returns a slice of the ciphertext, or an // packet header) and returns a slice of the ciphertext, or an error
// error if the cipher is exhausted (i.e. can no longer be used safely). // if the cipher is exhausted (i.e. can no longer be used safely).
func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) { func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
if !validNonce(c.tx.nonce[:]) { if !validNonce(c.tx.nonce[:]) {
// Received 2^64-1 messages on this cipher state. Connection // Received 2^64-1 messages on this cipher state. Connection
@ -133,8 +148,8 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
return nil, errCipherExhausted{} return nil, errCipherExhausted{}
} }
binary.BigEndian.PutUint16(c.tx.buf[:2], uint16(len(plaintext)+poly1305.TagSize)) setHeader(c.tx.buf[:5], protocolVersion, msgTypeRecord, len(plaintext)+poly1305.TagSize)
ret := c.tx.cipher.Seal(c.tx.buf[:2], c.tx.nonce[:], plaintext, nil) ret := c.tx.cipher.Seal(c.tx.buf[:5], c.tx.nonce[:], plaintext, nil)
// Safe to increment the nonce here, because we checked for nonce // Safe to increment the nonce here, because we checked for nonce
// wraparound above. // wraparound above.
@ -143,18 +158,18 @@ func (c *Conn) encryptLocked(plaintext []byte) ([]byte, error) {
return ret, nil return ret, nil
} }
// wholeCiphertextLocked returns a slice of one whole Noise frame from // wholeMessageLocked returns a slice of one whole Noise transport
// c.rx.buf, if one whole ciphertext is available, and advances the // message from c.rx.buf, if one whole message is available, and
// read state to the next Noise frame in the buffer. Returns nil // advances the read state to the next Noise message in the
// without advancing read state if there's not one whole ciphertext in // buffer. Returns nil without advancing read state if there isn't one
// c.rx.buf. // whole message in c.rx.buf.
func (c *Conn) wholeCiphertextLocked() []byte { func (c *Conn) wholeMessageLocked() []byte {
available := c.rx.n - c.rx.next available := c.rx.n - c.rx.next
if available < 2 { if available < headerLen {
return nil return nil
} }
bs := c.rx.buf[c.rx.next:c.rx.n] bs := c.rx.buf[c.rx.next:c.rx.n]
totalSize := int(binary.BigEndian.Uint16(bs[:2])) + 2 totalSize := hdrLen(bs) + headerLen
if len(bs) < totalSize { if len(bs) < totalSize {
return nil return nil
} }
@ -162,16 +177,16 @@ func (c *Conn) wholeCiphertextLocked() []byte {
return bs[:totalSize] return bs[:totalSize]
} }
// decryptOneLocked decrypts one Noise frame, reading from c.conn as needed, // decryptOneLocked decrypts one Noise transport message, reading from
// and sets c.rx.plaintext to point to the decrypted // c.conn as needed, and sets c.rx.plaintext to point to the decrypted
// bytes. c.rx.plaintext is only valid if err == nil. // bytes. c.rx.plaintext is only valid if err == nil.
func (c *Conn) decryptOneLocked() error { func (c *Conn) decryptOneLocked() error {
c.rx.plaintext = nil c.rx.plaintext = nil
// Fast path: do we have one whole ciphertext frame buffered // Fast path: do we have one whole ciphertext frame buffered
// already? // already?
if bs := c.wholeCiphertextLocked(); bs != nil { if bs := c.wholeMessageLocked(); bs != nil {
return c.decryptLocked(bs[2:]) return c.decryptLocked(bs)
} }
if c.rx.next != 0 { if c.rx.next != 0 {
@ -183,18 +198,20 @@ func (c *Conn) decryptOneLocked() error {
c.rx.next = 0 c.rx.next = 0
} }
bs, err := c.readNLocked(2) bs, err := c.readNLocked(headerLen)
if err != nil { if err != nil {
return err return err
} }
totalLen := int(binary.BigEndian.Uint16(bs[:2])) + 2 // The rest of the header (besides the length field) gets verified
bs, err = c.readNLocked(totalLen) // in decryptLocked, not here.
messageLen := headerLen + hdrLen(bs)
bs, err = c.readNLocked(messageLen)
if err != nil { if err != nil {
return err return err
} }
bs = bs[:messageLen]
c.rx.next = totalLen c.rx.next = len(bs)
bs = bs[2:totalLen]
return c.decryptLocked(bs) return c.decryptLocked(bs)
} }

@ -12,6 +12,7 @@ import (
"hash" "hash"
"io" "io"
"net" "net"
"strconv"
"time" "time"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
@ -23,15 +24,32 @@ import (
) )
const ( const (
// protocolName is the name of the specific instantiation of the
// Noise protocol we're using. Each field is defined in the Noise
// spec, and shouldn't be changed unless we're switching to a
// different Noise protocol instance.
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
// protocolVersion is the version string that gets included as the // protocolVersion is the version of the Tailscale base
// Noise "prologue" in the handshake. It exists so that we can // protocol that Client will use when initiating a handshake.
// ensure that peer have agreed on the protocol version they're protocolVersion = 1
// executing, to defeat some MITM protocol downgrade attacks. // protocolVersionPrefix is the name portion of the protocol
protocolVersion = "Tailscale Control Protocol v1" // name+version string that gets mixed into the Noise handshake as
// a prologue.
//
// This mixing verifies that both clients agree that
// they're executing the Tailscale control protocol at a specific
// version that matches the advertised version in the cleartext
// packet header.
protocolVersionPrefix = "Tailscale Control Protocol v"
invalidNonce = ^uint64(0) invalidNonce = ^uint64(0)
) )
func protocolVersionPrologue(version int) []byte {
ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
ret = append(ret, protocolVersionPrefix...)
return strconv.AppendUint(ret, uint64(version), 10)
}
// Client initiates a Noise client handshake, returning the resulting // Client initiates a Noise client handshake, returning the resulting
// Noise connection. // Noise connection.
// //
@ -50,15 +68,18 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK
var s symmetricState var s symmetricState
s.Initialize() s.Initialize()
// prologue
s.MixHash(protocolVersionPrologue(protocolVersion))
// <- s // <- s
// ... // ...
s.MixHash(controlKey[:]) s.MixHash(controlKey[:])
// -> e, es, s, ss // -> e, es, s, ss
var init initiationMessage init := mkInitiationMessage()
machineEphemeral := key.NewPrivate() machineEphemeral := key.NewPrivate()
machineEphemeralPub := machineEphemeral.Public() machineEphemeralPub := machineEphemeral.Public()
copy(init.MachineEphemeralPub(), machineEphemeralPub[:]) copy(init.EphemeralPub(), machineEphemeralPub[:])
s.MixHash(machineEphemeralPub[:]) s.MixHash(machineEphemeralPub[:])
if err := s.MixDH(machineEphemeral, controlKey); err != nil { if err := s.MixDH(machineEphemeral, controlKey); err != nil {
return nil, fmt.Errorf("computing es: %w", err) return nil, fmt.Errorf("computing es: %w", err)
@ -74,14 +95,34 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK
return nil, fmt.Errorf("writing initiation: %w", err) return nil, fmt.Errorf("writing initiation: %w", err)
} }
// <- e, ee, se // Read in the payload and look for errors/protocol violations from the server.
var resp responseMessage var resp responseMessage
if _, err := io.ReadFull(conn, resp[:]); err != nil { if _, err := io.ReadFull(conn, resp.Header()); err != nil {
return nil, fmt.Errorf("reading response: %w", err) return nil, fmt.Errorf("reading response header: %w", err)
}
if resp.Version() != protocolVersion {
return nil, fmt.Errorf("unexpected version %d from server, want %d", resp.Version(), protocolVersion)
}
if resp.Type() != msgTypeResponse {
if resp.Type() != msgTypeError {
return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
}
msg := make([]byte, resp.Length())
if _, err := io.ReadFull(conn, msg); err != nil {
return nil, err
}
return nil, fmt.Errorf("server error: %s", string(msg))
}
if resp.Length() != len(resp.Payload()) {
return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
}
if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
return nil, err
} }
// <- e, ee, se
var controlEphemeralPub key.Public var controlEphemeralPub key.Public
copy(controlEphemeralPub[:], resp.ControlEphemeralPub()) copy(controlEphemeralPub[:], resp.EphemeralPub())
s.MixHash(controlEphemeralPub[:]) s.MixHash(controlEphemeralPub[:])
if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { if err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
return nil, fmt.Errorf("computing ee: %w", err) return nil, fmt.Errorf("computing ee: %w", err)
@ -100,6 +141,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.Private, controlK
return &Conn{ return &Conn{
conn: conn, conn: conn,
version: protocolVersion,
peer: controlKey, peer: controlKey,
handshakeHash: s.h, handshakeHash: s.h,
tx: txState{ tx: txState{
@ -126,22 +168,55 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
}() }()
} }
// Deliberately does not support formatting, so that we don't echo
// attacker-controlled input back to them.
sendErr := func(msg string) error {
if len(msg) >= 1<<16 {
msg = msg[:1<<16]
}
var hdr [headerLen]byte
setHeader(hdr[:], protocolVersion, msgTypeError, len(msg))
if _, err := conn.Write(hdr[:]); err != nil {
return fmt.Errorf("sending %q error to client: %w", msg, err)
}
if _, err := conn.Write([]byte(msg)); err != nil {
return fmt.Errorf("sending %q error to client: %w", msg, err)
}
return fmt.Errorf("refused client handshake: %s", msg)
}
var s symmetricState var s symmetricState
s.Initialize() s.Initialize()
var init initiationMessage
if _, err := io.ReadFull(conn, init.Header()); err != nil {
return nil, err
}
if init.Version() != protocolVersion {
return nil, sendErr("unsupported protocol version")
}
if init.Type() != msgTypeInitiation {
return nil, sendErr("unexpected handshake message type")
}
if init.Length() != len(init.Payload()) {
return nil, sendErr("wrong handshake initiation length")
}
if _, err := io.ReadFull(conn, init.Payload()); err != nil {
return nil, err
}
// prologue. Can only do this once we at least think the client is
// handshaking using a supported version.
s.MixHash(protocolVersionPrologue(protocolVersion))
// <- s // <- s
// ... // ...
controlKeyPub := controlKey.Public() controlKeyPub := controlKey.Public()
s.MixHash(controlKeyPub[:]) s.MixHash(controlKeyPub[:])
// -> e, es, s, ss // -> e, es, s, ss
var init initiationMessage
if _, err := io.ReadFull(conn, init[:]); err != nil {
return nil, fmt.Errorf("reading initiation: %w", err)
}
var machineEphemeralPub key.Public var machineEphemeralPub key.Public
copy(machineEphemeralPub[:], init.MachineEphemeralPub()) copy(machineEphemeralPub[:], init.EphemeralPub())
s.MixHash(machineEphemeralPub[:]) s.MixHash(machineEphemeralPub[:])
if err := s.MixDH(controlKey, machineEphemeralPub); err != nil { if err := s.MixDH(controlKey, machineEphemeralPub); err != nil {
return nil, fmt.Errorf("computing es: %w", err) return nil, fmt.Errorf("computing es: %w", err)
@ -158,10 +233,10 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
} }
// <- e, ee, se // <- e, ee, se
var resp responseMessage resp := mkResponseMessage()
controlEphemeral := key.NewPrivate() controlEphemeral := key.NewPrivate()
controlEphemeralPub := controlEphemeral.Public() controlEphemeralPub := controlEphemeral.Public()
copy(resp.ControlEphemeralPub(), controlEphemeralPub[:]) copy(resp.EphemeralPub(), controlEphemeralPub[:])
s.MixHash(controlEphemeralPub[:]) s.MixHash(controlEphemeralPub[:])
if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { if err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
return nil, fmt.Errorf("computing ee: %w", err) return nil, fmt.Errorf("computing ee: %w", err)
@ -182,6 +257,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
return &Conn{ return &Conn{
conn: conn, conn: conn,
version: protocolVersion,
peer: machineKey, peer: machineKey,
handshakeHash: s.h, handshakeHash: s.h,
tx: txState{ tx: txState{
@ -193,21 +269,6 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.Private) (*Conn,
}, nil }, nil
} }
// initiationMessage is the Noise protocol message sent from a client
// machine to a control server.
type initiationMessage [96]byte
func (m *initiationMessage) MachineEphemeralPub() []byte { return m[:32] }
func (m *initiationMessage) MachinePub() []byte { return m[32:80] }
func (m *initiationMessage) Tag() []byte { return m[80:] }
// responseMessage is the Noise protocol message sent from a control
// server to a client machine.
type responseMessage [48]byte
func (m *responseMessage) ControlEphemeralPub() []byte { return m[:32] }
func (m *responseMessage) Tag() []byte { return m[32:] }
// symmetricState is the SymmetricState object from the Noise protocol // symmetricState is the SymmetricState object from the Noise protocol
// spec. It contains all the symmetric cipher state of an in-flight // spec. It contains all the symmetric cipher state of an in-flight
// handshake. Field names match the variable names in the spec. // handshake. Field names match the variable names in the spec.
@ -232,7 +293,6 @@ func (s *symmetricState) Initialize() {
s.k = [chp.KeySize]byte{} s.k = [chp.KeySize]byte{}
s.n = invalidNonce s.n = invalidNonce
s.mixer = newBLAKE2s() s.mixer = newBLAKE2s()
s.MixHash([]byte(protocolVersion))
} }
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is // MixHash updates s.h to be BLAKE2s(s.h || data), where || is

@ -42,6 +42,12 @@ func TestHandshake(t *testing.T) {
t.Fatal("client and server disagree on handshake hash") t.Fatal("client and server disagree on handshake hash")
} }
if client.ProtocolVersion() != protocolVersion {
t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion)
}
if client.ProtocolVersion() != server.ProtocolVersion() {
t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion())
}
if client.Peer() != serverKey.Public() { if client.Peer() != serverKey.Public() {
t.Fatal("client peer key isn't serverKey") t.Fatal("client peer key isn't serverKey")
} }
@ -154,7 +160,7 @@ func (r *tamperReader) Read(bs []byte) (int, error) {
func TestTampering(t *testing.T) { func TestTampering(t *testing.T) {
// Tamper with every byte of the client initiation message. // Tamper with every byte of the client initiation message.
for i := 0; i < 96; i++ { for i := 0; i < 101; i++ {
var ( var (
clientConn, serverRaw = tsnettest.NewConn("noise", 128000) clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}} serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, i, 0}}
@ -182,7 +188,7 @@ func TestTampering(t *testing.T) {
} }
// Tamper with every byte of the server response message. // Tamper with every byte of the server response message.
for i := 0; i < 48; i++ { for i := 0; i < 53; i++ {
var ( var (
clientRaw, serverConn = tsnettest.NewConn("noise", 128000) clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}} clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, i, 0}}
@ -210,7 +216,7 @@ func TestTampering(t *testing.T) {
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
var ( var (
clientRaw, serverConn = tsnettest.NewConn("noise", 128000) clientRaw, serverConn = tsnettest.NewConn("noise", 128000)
clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 48 + i, 0}} clientConn = &readerConn{clientRaw, &tamperReader{clientRaw, 53 + i, 0}}
serverKey = key.NewPrivate() serverKey = key.NewPrivate()
clientKey = key.NewPrivate() clientKey = key.NewPrivate()
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
@ -233,7 +239,7 @@ func TestTampering(t *testing.T) {
} }
// The client needs a timeout if the tampering is hitting the length header. // The client needs a timeout if the tampering is hitting the length header.
if i == 0 || i == 1 { if i == 3 || i == 4 {
client.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) client.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
} }
@ -251,7 +257,7 @@ func TestTampering(t *testing.T) {
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
var ( var (
clientConn, serverRaw = tsnettest.NewConn("noise", 128000) clientConn, serverRaw = tsnettest.NewConn("noise", 128000)
serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 96 + i, 0}} serverConn = &readerConn{serverRaw, &tamperReader{serverRaw, 101 + i, 0}}
serverKey = key.NewPrivate() serverKey = key.NewPrivate()
clientKey = key.NewPrivate() clientKey = key.NewPrivate()
serverErr = make(chan error, 1) serverErr = make(chan error, 1)
@ -261,7 +267,7 @@ func TestTampering(t *testing.T) {
serverErr <- err serverErr <- err
var bs [100]byte var bs [100]byte
// The server needs a timeout if the tampering is hitting the length header. // The server needs a timeout if the tampering is hitting the length header.
if i == 0 || i == 1 { if i == 3 || i == 4 {
server.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) server.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
} }
n, err := server.Read(bs[:]) n, err := server.Read(bs[:])

@ -120,9 +120,14 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr
private_key: machineKey, private_key: machineKey,
public_key: machineKey.Public(), public_key: machineKey.Public(),
} }
session := InitSession(true, []byte(protocolVersion), mk, controlKey) session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, controlKey)
_, msg1 := SendMessage(&session, nil) _, msg1 := SendMessage(&session, nil)
var hdr [headerLen]byte
setHeader(hdr[:], protocolVersion, msgTypeInitiation, 96)
if _, err := conn.Write(hdr[:]); err != nil {
return nil, err
}
if _, err := conn.Write(msg1.ne[:]); err != nil { if _, err := conn.Write(msg1.ne[:]); err != nil {
return nil, err return nil, err
} }
@ -134,13 +139,15 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr
} }
var buf [1024]byte var buf [1024]byte
if _, err := io.ReadFull(conn, buf[:48]); err != nil { if _, err := io.ReadFull(conn, buf[:53]); err != nil {
return nil, err return nil, err
} }
// ignore the header for this test, we're only checking the noise
// implementation.
msg2 := messagebuffer{ msg2 := messagebuffer{
ciphertext: buf[32:48], ciphertext: buf[37:53],
} }
copy(msg2.ne[:], buf[:32]) copy(msg2.ne[:], buf[5:37])
_, p, valid := RecvMessage(&session, &msg2) _, p, valid := RecvMessage(&session, &msg2)
if !valid { if !valid {
return nil, errors.New("handshake failed") return nil, errors.New("handshake failed")
@ -150,18 +157,19 @@ func noiseExplorerClient(conn net.Conn, controlKey key.Public, machineKey key.Pr
} }
_, msg3 := SendMessage(&session, payload) _, msg3 := SendMessage(&session, payload)
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg3.ciphertext))) setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg3.ciphertext))
if _, err := conn.Write(buf[:2]); err != nil { if _, err := conn.Write(hdr[:]); err != nil {
return nil, err return nil, err
} }
if _, err := conn.Write(msg3.ciphertext); err != nil { if _, err := conn.Write(msg3.ciphertext); err != nil {
return nil, err return nil, err
} }
if _, err := io.ReadFull(conn, buf[:2]); err != nil { if _, err := io.ReadFull(conn, buf[:5]); err != nil {
return nil, err return nil, err
} }
plen := int(binary.BigEndian.Uint16(buf[:2])) // Ignore all of the header except the payload length
plen := int(binary.LittleEndian.Uint16(buf[3:5]))
if _, err := io.ReadFull(conn, buf[:plen]); err != nil { if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
return nil, err return nil, err
} }
@ -182,17 +190,18 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
private_key: controlKey, private_key: controlKey,
public_key: controlKey.Public(), public_key: controlKey.Public(),
} }
session := InitSession(false, []byte(protocolVersion), mk, [32]byte{}) session := InitSession(false, protocolVersionPrologue(protocolVersion), mk, [32]byte{})
var buf [1024]byte var buf [1024]byte
if _, err := io.ReadFull(conn, buf[:96]); err != nil { if _, err := io.ReadFull(conn, buf[:101]); err != nil {
return nil, err return nil, err
} }
// Ignore the header, we're just checking the noise implementation.
msg1 := messagebuffer{ msg1 := messagebuffer{
ns: buf[32:80], ns: buf[37:85],
ciphertext: buf[80:96], ciphertext: buf[85:101],
} }
copy(msg1.ne[:], buf[:32]) copy(msg1.ne[:], buf[5:37])
_, p, valid := RecvMessage(&session, &msg1) _, p, valid := RecvMessage(&session, &msg1)
if !valid { if !valid {
return nil, errors.New("handshake failed") return nil, errors.New("handshake failed")
@ -202,6 +211,11 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
} }
_, msg2 := SendMessage(&session, nil) _, msg2 := SendMessage(&session, nil)
var hdr [headerLen]byte
setHeader(hdr[:], protocolVersion, msgTypeResponse, 48)
if _, err := conn.Write(hdr[:]); err != nil {
return nil, err
}
if _, err := conn.Write(msg2.ne[:]); err != nil { if _, err := conn.Write(msg2.ne[:]); err != nil {
return nil, err return nil, err
} }
@ -209,10 +223,10 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
return nil, err return nil, err
} }
if _, err := io.ReadFull(conn, buf[:2]); err != nil { if _, err := io.ReadFull(conn, buf[:5]); err != nil {
return nil, err return nil, err
} }
plen := int(binary.BigEndian.Uint16(buf[:2])) plen := int(binary.LittleEndian.Uint16(buf[3:5]))
if _, err := io.ReadFull(conn, buf[:plen]); err != nil { if _, err := io.ReadFull(conn, buf[:plen]); err != nil {
return nil, err return nil, err
} }
@ -226,8 +240,8 @@ func noiseExplorerServer(conn net.Conn, controlKey key.Private, wantMachineKey k
} }
_, msg4 := SendMessage(&session, payload) _, msg4 := SendMessage(&session, payload)
binary.BigEndian.PutUint16(buf[:2], uint16(len(msg4.ciphertext))) setHeader(hdr[:], protocolVersion, msgTypeRecord, len(msg4.ciphertext))
if _, err := conn.Write(buf[:2]); err != nil { if _, err := conn.Write(hdr[:]); err != nil {
return nil, err return nil, err
} }
if _, err := conn.Write(msg4.ciphertext); err != nil { if _, err := conn.Write(msg4.ciphertext); err != nil {

@ -0,0 +1,26 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package noise
// Note that these types are deliberately separate from the types/key
// package. That package defines generic curve25519 keys, without
// consideration for how those keys are used. We don't want to
// encourage mixing machine keys, node keys, and whatever else we
// might use curve25519 for.
//
// Furthermore, the implementation in types/key does some work that is
// unnecessary for machine keys, and results in a harder to follow
// implementation. In particular, machine keys do not need to be
// clamped per the curve25519 spec because they're only used with the
// X25519 operation, and the X25519 operation defines its own clamping
// and sanity checking logic. Thus, these keys must be used only with
// this Noise protocol implementation, and the easiest way to ensure
// that is a different type.
// PrivateKey is a Tailscale machine private key.
type PrivateKey [32]byte
// PublicKey is a Tailscale machine public key.
type PublicKey [32]byte

@ -0,0 +1,87 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package noise
import "encoding/binary"
const (
msgTypeInitiation = 1
msgTypeResponse = 2
msgTypeError = 3
msgTypeRecord = 4
)
// headerLen is the size of the cleartext message header that gets
// prepended to Noise messages.
//
// 2b: protocol version
// 1b: message type
// 2b: payload length (not including this header)
const headerLen = 5
func setHeader(bs []byte, version int, msgType byte, length int) {
binary.LittleEndian.PutUint16(bs[:2], uint16(version))
bs[2] = msgType
binary.LittleEndian.PutUint16(bs[3:5], uint16(length))
}
func hdrVersion(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[:2])) }
func hdrType(bs []byte) byte { return bs[2] }
func hdrLen(bs []byte) int { return int(binary.LittleEndian.Uint16(bs[3:5])) }
// initiationMessage is the Noise protocol message sent from a client
// machine to a control server.
//
// 5b: header (see headerLen for fields)
// 32b: client ephemeral public key (cleartext)
// 48b: client machine public key (encrypted)
// 16b: message tag (authenticates the whole message)
type initiationMessage [101]byte
func mkInitiationMessage() initiationMessage {
var ret initiationMessage
binary.LittleEndian.PutUint16(ret[:2], protocolVersion)
ret[2] = msgTypeInitiation
binary.LittleEndian.PutUint16(ret[3:5], 96)
return ret
}
func (m *initiationMessage) Header() []byte { return m[:5] }
func (m *initiationMessage) Payload() []byte { return m[5:] }
func (m *initiationMessage) Version() int { return hdrVersion(m.Header()) }
func (m *initiationMessage) Type() byte { return hdrType(m.Header()) }
func (m *initiationMessage) Length() int { return hdrLen(m.Header()) }
func (m *initiationMessage) EphemeralPub() []byte { return m[5:37] }
func (m *initiationMessage) MachinePub() []byte { return m[37:85] }
func (m *initiationMessage) Tag() []byte { return m[85:] }
// responseMessage is the Noise protocol message sent from a control
// server to a client machine.
//
// 2b: little-endian protocol version
// 1b: message type
// 2b: little-endian size of message (not including this header)
// 32b: control ephemeral public key (cleartext)
// 16b: message tag (authenticates the whole message)
type responseMessage [53]byte
func mkResponseMessage() responseMessage {
var ret responseMessage
binary.LittleEndian.PutUint16(ret[:2], protocolVersion)
ret[2] = msgTypeResponse
binary.LittleEndian.PutUint16(ret[3:5], 48)
return ret
}
func (m *responseMessage) Header() []byte { return m[:5] }
func (m *responseMessage) Payload() []byte { return m[5:] }
func (m *responseMessage) Version() int { return hdrVersion(m.Header()) }
func (m *responseMessage) Type() byte { return hdrType(m.Header()) }
func (m *responseMessage) Length() int { return hdrLen(m.Header()) }
func (m *responseMessage) EphemeralPub() []byte { return m[5:37] }
func (m *responseMessage) Tag() []byte { return m[37:] }
Loading…
Cancel
Save