// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package noise import ( "bufio" "bytes" "context" "crypto/rand" "encoding/binary" "fmt" "io" "net" "strings" "sync" "testing" "testing/iotest" chp "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/nettest" tsnettest "tailscale.com/net/nettest" "tailscale.com/types/key" ) func TestMessageSize(t *testing.T) { // This test is a regression guard against someone looking at // maxCiphertextSize, going "huh, we could be more efficient if it // were larger, and accidentally violating the Noise spec. Do not // change this max value, it's a deliberate limitation of the // cryptographic protocol we use (see Section 3 "Message Format" // of the Noise spec). const max = 65535 if maxCiphertextSize > max { t.Fatalf("max ciphertext size is %d, which is larger than the maximum noise message size %d", maxCiphertextSize, max) } } func TestConnBasic(t *testing.T) { client, server := pair(t) sb := sinkReads(server) want := "test" if _, err := io.WriteString(client, want); err != nil { t.Fatalf("client write failed: %v", err) } client.Close() if got := sb.String(4); got != want { t.Fatalf("wrong content received: got %q, want %q", got, want) } if err := sb.Error(); err != io.EOF { t.Fatal("client close wasn't seen by server") } if sb.Total() != 4 { t.Fatalf("wrong amount of bytes received: got %d, want 4", sb.Total()) } } // bufferedWriteConn wraps a net.Conn and gives control over how // Writes get batched out. type bufferedWriteConn struct { net.Conn w *bufio.Writer manualFlush bool } func (c *bufferedWriteConn) Write(bs []byte) (int, error) { n, err := c.w.Write(bs) if err == nil && !c.manualFlush { err = c.w.Flush() } return n, err } // TestFastPath exercises the Read codepath that can receive multiple // Noise frames at once and decode each in turn without making another // syscall. func TestFastPath(t *testing.T) { s1, s2 := tsnettest.NewConn("noise", 128000) b := &bufferedWriteConn{s1, bufio.NewWriterSize(s1, 10000), false} client, server := pairWithConns(t, b, s2) b.manualFlush = true sb := sinkReads(server) const packets = 10 s := "test" for i := 0; i < packets; i++ { // Many separate writes, to force separate Noise frames that // all get buffered up and then all sent as a single slice to // the server. if _, err := io.WriteString(client, s); err != nil { t.Fatalf("client write1 failed: %v", err) } } if err := b.w.Flush(); err != nil { t.Fatalf("client flush failed: %v", err) } client.Close() want := strings.Repeat(s, packets) if got := sb.String(len(want)); got != want { t.Fatalf("wrong content received: got %q, want %q", got, want) } if err := sb.Error(); err != io.EOF { t.Fatalf("client close wasn't seen by server") } } // Writes things larger than a single Noise frame, to check the // chunking on the encoder and decoder. func TestBigData(t *testing.T) { client, server := pair(t) serverReads := sinkReads(server) clientReads := sinkReads(client) const sz = 15 * 1024 // 15KiB clientStr := strings.Repeat("abcde", sz/5) serverStr := strings.Repeat("fghij", sz/5*2) if _, err := io.WriteString(client, clientStr); err != nil { t.Fatalf("writing client>server: %v", err) } if _, err := io.WriteString(server, serverStr); err != nil { t.Fatalf("writing server>client: %v", err) } if serverGot := serverReads.String(sz); serverGot != clientStr { t.Error("server didn't receive what client sent") } if clientGot := clientReads.String(2 * sz); clientGot != serverStr { t.Error("client didn't receive what server sent") } getNonce := func(n [chp.NonceSize]byte) uint64 { if binary.BigEndian.Uint32(n[:4]) != 0 { panic("unexpected nonce") } return binary.BigEndian.Uint64(n[4:]) } // Reach into the Conns and verify the cipher nonces advanced as // expected. if getNonce(client.tx.nonce) != getNonce(server.rx.nonce) { t.Error("desynchronized client tx nonce") } if getNonce(server.tx.nonce) != getNonce(client.rx.nonce) { t.Error("desynchronized server tx nonce") } if n := getNonce(client.tx.nonce); n != 4 { t.Errorf("wrong client tx nonce, got %d want 4", n) } if n := getNonce(server.tx.nonce); n != 8 { t.Errorf("wrong client tx nonce, got %d want 8", n) } } // readerConn wraps a net.Conn and routes its Reads through a separate // io.Reader. type readerConn struct { net.Conn r io.Reader } func (c readerConn) Read(bs []byte) (int, error) { return c.r.Read(bs) } // Check that the receiver can handle not being able to read an entire // frame in a single syscall. func TestDataTrickle(t *testing.T) { s1, s2 := tsnettest.NewConn("noise", 128000) client, server := pairWithConns(t, s1, readerConn{s2, iotest.OneByteReader(s2)}) serverReads := sinkReads(server) const sz = 10000 clientStr := strings.Repeat("abcde", sz/5) if _, err := io.WriteString(client, clientStr); err != nil { t.Fatalf("writing client>server: %v", err) } serverGot := serverReads.String(sz) if serverGot != clientStr { t.Error("server didn't receive what client sent") } } func TestConnStd(t *testing.T) { // You can run this test manually, and noise.Conn should pass all // of them except for TestConn/PastTimeout, // TestConn/FutureTimeout, TestConn/ConcurrentMethods, because // those tests assume that write errors are recoverable, and // they're not on our Conn due to cipher security. t.Skip("not all tests can pass on this Conn, see https://github.com/golang/go/issues/46977") nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { s1, s2 := tsnettest.NewConn("noise", 4096) controlKey := key.NewPrivate() machineKey := key.NewPrivate() serverErr := make(chan error, 1) go func() { var err error c2, err = Server(context.Background(), s2, controlKey) serverErr <- err }() c1, err = Client(context.Background(), s1, machineKey, controlKey.Public()) if err != nil { s1.Close() s2.Close() return nil, nil, nil, fmt.Errorf("connecting client: %w", err) } if err := <-serverErr; err != nil { c1.Close() s1.Close() s2.Close() return nil, nil, nil, fmt.Errorf("connecting server: %w", err) } return c1, c2, func() { c1.Close() c2.Close() }, nil }) } // mkConns creates synthetic Noise Conns wrapping the given net.Conns. // This function is for testing just the Conn transport logic without // having to muck about with Noise handshakes. func mkConns(s1, s2 net.Conn) (*Conn, *Conn) { var k1, k2 [chp.KeySize]byte if _, err := rand.Read(k1[:]); err != nil { panic(err) } if _, err := rand.Read(k2[:]); err != nil { panic(err) } ret1 := &Conn{ conn: s1, tx: txState{cipher: newCHP(k1)}, rx: rxState{cipher: newCHP(k2)}, } ret2 := &Conn{ conn: s2, tx: txState{cipher: newCHP(k2)}, rx: rxState{cipher: newCHP(k1)}, } return ret1, ret2 } type readSink struct { r io.Reader cond *sync.Cond sync.Mutex bs bytes.Buffer err error } func sinkReads(r io.Reader) *readSink { ret := &readSink{ r: r, } ret.cond = sync.NewCond(&ret.Mutex) go func() { var buf [4096]byte for { n, err := r.Read(buf[:]) ret.Lock() ret.bs.Write(buf[:n]) if err != nil { ret.err = err } ret.cond.Broadcast() ret.Unlock() if err != nil { return } } }() return ret } func (s *readSink) String(total int) string { s.Lock() defer s.Unlock() for s.bs.Len() < total && s.err == nil { s.cond.Wait() } if s.err != nil { total = s.bs.Len() } return string(s.bs.Bytes()[:total]) } func (s *readSink) Error() error { s.Lock() defer s.Unlock() for s.err == nil { s.cond.Wait() } return s.err } func (s *readSink) Total() int { s.Lock() defer s.Unlock() return s.bs.Len() } func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) { var ( controlKey = key.NewPrivate() machineKey = key.NewPrivate() server *Conn serverErr = make(chan error, 1) ) go func() { var err error server, err = Server(context.Background(), serverConn, controlKey) serverErr <- err }() client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public()) if err != nil { t.Fatalf("client connection failed: %v", err) } if err := <-serverErr; err != nil { t.Fatalf("server connection failed: %v", err) } return client, server } func pair(t *testing.T) (*Conn, *Conn) { s1, s2 := tsnettest.NewConn("noise", 128000) return pairWithConns(t, s1, s2) }