// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package derp import ( "bufio" "bytes" "context" "crypto/x509" "encoding/asn1" "encoding/json" "errors" "expvar" "fmt" "io" "log" "net" "os" "reflect" "strconv" "sync" "testing" "time" "go4.org/mem" "golang.org/x/time/rate" "tailscale.com/disco" "tailscale.com/net/memnet" "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/logger" ) func TestClientInfoUnmarshal(t *testing.T) { for i, in := range []string{ `{"Version":5,"MeshKey":"abc"}`, `{"version":5,"meshKey":"abc"}`, } { var got clientInfo if err := json.Unmarshal([]byte(in), &got); err != nil { t.Fatalf("[%d]: %v", i, err) } want := clientInfo{Version: 5, MeshKey: "abc"} if got != want { t.Errorf("[%d]: got %+v; want %+v", i, got, want) } } } func TestSendRecv(t *testing.T) { serverPrivateKey := key.NewNode() s := NewServer(serverPrivateKey, t.Logf) defer s.Close() const numClients = 3 var clientPrivateKeys []key.NodePrivate var clientKeys []key.NodePublic for i := 0; i < numClients; i++ { priv := key.NewNode() clientPrivateKeys = append(clientPrivateKeys, priv) clientKeys = append(clientKeys, priv.Public()) } ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } defer ln.Close() var clients []*Client var connsOut []Conn var recvChs []chan []byte errCh := make(chan error, 3) for i := 0; i < numClients; i++ { t.Logf("Connecting client %d ...", i) cout, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } defer cout.Close() connsOut = append(connsOut, cout) cin, err := ln.Accept() if err != nil { t.Fatal(err) } defer cin.Close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin)) go s.Accept(ctx, cin, brwServer, fmt.Sprintf("test-client-%d", i)) key := clientPrivateKeys[i] brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout)) c, err := NewClient(key, cout, brw, t.Logf) if err != nil { t.Fatalf("client %d: %v", i, err) } waitConnect(t, c) clients = append(clients, c) recvChs = append(recvChs, make(chan []byte)) t.Logf("Connected client %d.", i) } var peerGoneCountDisconnected expvar.Int var peerGoneCountNotHere expvar.Int t.Logf("Starting read loops") for i := 0; i < numClients; i++ { go func(i int) { for { m, err := clients[i].Recv() if err != nil { errCh <- err return } switch m := m.(type) { default: t.Errorf("unexpected message type %T", m) continue case PeerGoneMessage: switch m.Reason { case PeerGoneReasonDisconnected: peerGoneCountDisconnected.Add(1) case PeerGoneReasonNotHere: peerGoneCountNotHere.Add(1) default: t.Errorf("unexpected PeerGone reason %v", m.Reason) } case ReceivedPacket: if m.Source.IsZero() { t.Errorf("zero Source address in ReceivedPacket") } recvChs[i] <- bytes.Clone(m.Data) } } }(i) } recv := func(i int, want string) { t.Helper() select { case b := <-recvChs[i]: if got := string(b); got != want { t.Errorf("client1.Recv=%q, want %q", got, want) } case <-time.After(5 * time.Second): t.Errorf("client%d.Recv, got nothing, want %q", i, want) } } recvNothing := func(i int) { t.Helper() select { case b := <-recvChs[0]: t.Errorf("client%d.Recv=%q, want nothing", i, string(b)) default: } } wantActive := func(total, home int64) { t.Helper() dl := time.Now().Add(5 * time.Second) var gotTotal, gotHome int64 for time.Now().Before(dl) { gotTotal, gotHome = s.curClients.Value(), s.curHomeClients.Value() if gotTotal == total && gotHome == home { return } time.Sleep(10 * time.Millisecond) } t.Errorf("total/home=%v/%v; want %v/%v", gotTotal, gotHome, total, home) } wantClosedPeers := func(want int64) { t.Helper() var got int64 dl := time.Now().Add(5 * time.Second) for time.Now().Before(dl) { if got = peerGoneCountDisconnected.Value(); got == want { return } } t.Errorf("peer gone count = %v; want %v", got, want) } wantUnknownPeers := func(want int64) { t.Helper() var got int64 dl := time.Now().Add(5 * time.Second) for time.Now().Before(dl) { if got = peerGoneCountNotHere.Value(); got == want { return } } t.Errorf("peer gone count = %v; want %v", got, want) } msg1 := []byte("hello 0->1\n") if err := clients[0].Send(clientKeys[1], msg1); err != nil { t.Fatal(err) } recv(1, string(msg1)) recvNothing(0) recvNothing(2) msg2 := []byte("hello 1->2\n") if err := clients[1].Send(clientKeys[2], msg2); err != nil { t.Fatal(err) } recv(2, string(msg2)) recvNothing(0) recvNothing(1) // Send messages to a non-existent node neKey := key.NewNode().Public() msg4 := []byte("not a CallMeMaybe->unknown destination\n") if err := clients[1].Send(neKey, msg4); err != nil { t.Fatal(err) } wantUnknownPeers(0) callMe := neKey.AppendTo([]byte(disco.Magic)) callMeHeader := make([]byte, disco.NonceLen) callMe = append(callMe, callMeHeader...) if err := clients[1].Send(neKey, callMe); err != nil { t.Fatal(err) } wantUnknownPeers(1) // PeerGoneNotHere is rate-limited to 3 times a second for i := 0; i < 5; i++ { if err := clients[1].Send(neKey, callMe); err != nil { t.Fatal(err) } } wantUnknownPeers(3) wantActive(3, 0) clients[0].NotePreferred(true) wantActive(3, 1) clients[0].NotePreferred(true) wantActive(3, 1) clients[0].NotePreferred(false) wantActive(3, 0) clients[0].NotePreferred(false) wantActive(3, 0) clients[1].NotePreferred(true) wantActive(3, 1) connsOut[1].Close() wantActive(2, 0) wantClosedPeers(1) clients[2].NotePreferred(true) wantActive(2, 1) clients[2].NotePreferred(false) wantActive(2, 0) connsOut[2].Close() wantActive(1, 0) wantClosedPeers(1) t.Logf("passed") s.Close() } func TestSendFreeze(t *testing.T) { serverPrivateKey := key.NewNode() s := NewServer(serverPrivateKey, t.Logf) defer s.Close() s.WriteTimeout = 100 * time.Millisecond // We send two streams of messages: // // alice --> bob // alice --> cathy // // Then cathy stops processing messages. // That should not interfere with alice talking to bob. newClient := func(ctx context.Context, name string, k key.NodePrivate) (c *Client, clientConn memnet.Conn) { t.Helper() c1, c2 := memnet.NewConn(name, 1024) go s.Accept(ctx, c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name) brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2)) c, err := NewClient(k, c2, brw, t.Logf) if err != nil { t.Fatal(err) } waitConnect(t, c) return c, c2 } ctx, clientCtxCancel := context.WithCancel(context.Background()) defer clientCtxCancel() aliceKey := key.NewNode() aliceClient, aliceConn := newClient(ctx, "alice", aliceKey) bobKey := key.NewNode() bobClient, bobConn := newClient(ctx, "bob", bobKey) cathyKey := key.NewNode() cathyClient, cathyConn := newClient(ctx, "cathy", cathyKey) var ( aliceCh = make(chan struct{}, 32) bobCh = make(chan struct{}, 32) cathyCh = make(chan struct{}, 32) ) chs := func(name string) chan struct{} { switch name { case "alice": return aliceCh case "bob": return bobCh case "cathy": return cathyCh default: panic("unknown ch: " + name) } } errCh := make(chan error, 4) recv := func(name string, client *Client) { ch := chs(name) for { m, err := client.Recv() if err != nil { errCh <- fmt.Errorf("%s: %w", name, err) return } switch m := m.(type) { default: errCh <- fmt.Errorf("%s: unexpected message type %T", name, m) return case ReceivedPacket: if m.Source.IsZero() { errCh <- fmt.Errorf("%s: zero Source address in ReceivedPacket", name) return } select { case ch <- struct{}{}: default: } } } } go recv("alice", aliceClient) go recv("bob", bobClient) go recv("cathy", cathyClient) var cancel func() go func() { t := time.NewTicker(2 * time.Millisecond) defer t.Stop() var ctx context.Context ctx, cancel = context.WithCancel(context.Background()) for { select { case <-t.C: case <-ctx.Done(): errCh <- nil return } msg1 := []byte("hello alice->bob\n") if err := aliceClient.Send(bobKey.Public(), msg1); err != nil { errCh <- fmt.Errorf("alice send to bob: %w", err) return } msg2 := []byte("hello alice->cathy\n") // TODO: an error is expected here. // We ignore it, maybe we should log it somehow? aliceClient.Send(cathyKey.Public(), msg2) } }() drainAny := func(ch chan struct{}) { // We are draining potentially infinite sources, // so place some reasonable upper limit. // // The important thing here is to make sure that // if any tokens remain in the channel, they // must have been generated after drainAny was // called. for i := 0; i < cap(ch); i++ { select { case <-ch: default: return } } } drain := func(t *testing.T, name string) bool { t.Helper() timer := time.NewTimer(1 * time.Second) defer timer.Stop() // Ensure ch has at least one element. ch := chs(name) select { case <-ch: case <-timer.C: t.Errorf("no packet received by %s", name) return false } // Drain remaining. drainAny(ch) return true } isEmpty := func(t *testing.T, name string) { t.Helper() select { case <-chs(name): t.Errorf("packet received by %s, want none", name) default: } } t.Run("initial send", func(t *testing.T) { drain(t, "bob") drain(t, "cathy") isEmpty(t, "alice") }) t.Run("block cathy", func(t *testing.T) { // Block cathy. Now the cathyConn buffer will fill up quickly, // and the derp server will back up. cathyConn.SetReadBlock(true) time.Sleep(2 * s.WriteTimeout) drain(t, "bob") drainAny(chs("cathy")) isEmpty(t, "alice") // Now wait a little longer, and ensure packets still flow to bob if !drain(t, "bob") { t.Errorf("connection alice->bob frozen by alice->cathy") } }) // Cleanup, make sure we process all errors. t.Logf("TEST COMPLETE, cancelling sender") cancel() t.Logf("closing connections") // Close bob before alice. // Starting with alice can cause a PeerGoneMessage to reach // bob before bob is closed, causing a test flake (issue 2668). bobConn.Close() aliceConn.Close() cathyConn.Close() for i := 0; i < cap(errCh); i++ { err := <-errCh if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { continue } t.Error(err) } } } type testServer struct { s *Server ln net.Listener logf logger.Logf mu sync.Mutex pubName map[key.NodePublic]string clients map[*testClient]bool } func (ts *testServer) addTestClient(c *testClient) { ts.mu.Lock() defer ts.mu.Unlock() ts.clients[c] = true } func (ts *testServer) addKeyName(k key.NodePublic, name string) { ts.mu.Lock() defer ts.mu.Unlock() ts.pubName[k] = name ts.logf("test adding named key %q for %x", name, k) } func (ts *testServer) keyName(k key.NodePublic) string { ts.mu.Lock() defer ts.mu.Unlock() if name, ok := ts.pubName[k]; ok { return name } return k.ShortString() } func (ts *testServer) close(t *testing.T) error { ts.ln.Close() ts.s.Close() for c := range ts.clients { c.close(t) } return nil } func newTestServer(t *testing.T, ctx context.Context) *testServer { t.Helper() logf := logger.WithPrefix(t.Logf, "derp-server: ") s := NewServer(key.NewNode(), logf) s.SetMeshKey("mesh-key") ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } go func() { i := 0 for { i++ c, err := ln.Accept() if err != nil { return } // TODO: register c in ts so Close also closes it? go func(i int) { brwServer := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) go s.Accept(ctx, c, brwServer, fmt.Sprintf("test-client-%d", i)) }(i) } }() return &testServer{ s: s, ln: ln, logf: logf, clients: map[*testClient]bool{}, pubName: map[key.NodePublic]string{}, } } type testClient struct { name string c *Client nc net.Conn pub key.NodePublic ts *testServer closed bool } func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net.Conn, key.NodePrivate, logger.Logf) (*Client, error)) *testClient { t.Helper() nc, err := net.Dial("tcp", ts.ln.Addr().String()) if err != nil { t.Fatal(err) } k := key.NewNode() ts.addKeyName(k.Public(), name) c, err := newClient(nc, k, logger.WithPrefix(t.Logf, "client-"+name+": ")) if err != nil { t.Fatal(err) } tc := &testClient{ name: name, nc: nc, c: c, ts: ts, pub: k.Public(), } ts.addTestClient(tc) return tc } func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) c, err := NewClient(priv, nc, brw, logf) if err != nil { return nil, err } waitConnect(t, c) return c, nil }) } func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) c, err := NewClient(priv, nc, brw, logf, MeshKey("mesh-key")) if err != nil { return nil, err } waitConnect(t, c) if err := c.WatchConnectionChanges(); err != nil { return nil, err } return c, nil }) } func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) { t.Helper() want := map[key.NodePublic]bool{} for _, k := range peers { want[k] = true } for { m, err := tc.c.recvTimeout(time.Second) if err != nil { t.Fatal(err) } switch m := m.(type) { case PeerPresentMessage: got := key.NodePublic(m) if !want[got] { t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) { for _, pub := range peers { fmt.Fprintf(bw, "%s ", tc.ts.keyName(pub)) } })) } delete(want, got) if len(want) == 0 { return } default: t.Fatalf("unexpected message type %T", m) } } } func (tc *testClient) wantGone(t *testing.T, peer key.NodePublic) { t.Helper() m, err := tc.c.recvTimeout(time.Second) if err != nil { t.Fatal(err) } switch m := m.(type) { case PeerGoneMessage: got := key.NodePublic(m.Peer) if peer != got { t.Errorf("got gone message for %v; want gone for %v", tc.ts.keyName(got), tc.ts.keyName(peer)) } reason := m.Reason if reason != PeerGoneReasonDisconnected { t.Errorf("got gone message for reason %v; wanted %v", reason, PeerGoneReasonDisconnected) } default: t.Fatalf("unexpected message type %T", m) } } func (c *testClient) close(t *testing.T) { t.Helper() if c.closed { return } c.closed = true t.Logf("closing client %q (%x)", c.name, c.pub) c.nc.Close() } // TestWatch tests the connection watcher mechanism used by regional // DERP nodes to mesh up with each other. func TestWatch(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() ts := newTestServer(t, ctx) defer ts.close(t) w1 := newTestWatcher(t, ts, "w1") w1.wantPresent(t, w1.pub) c1 := newRegularClient(t, ts, "c1") w1.wantPresent(t, c1.pub) c2 := newRegularClient(t, ts, "c2") w1.wantPresent(t, c2.pub) w2 := newTestWatcher(t, ts, "w2") w1.wantPresent(t, w2.pub) w2.wantPresent(t, w1.pub, w2.pub, c1.pub, c2.pub) c3 := newRegularClient(t, ts, "c3") w1.wantPresent(t, c3.pub) w2.wantPresent(t, c3.pub) c2.close(t) w1.wantGone(t, c2.pub) w2.wantGone(t, c2.pub) w3 := newTestWatcher(t, ts, "w3") w1.wantPresent(t, w3.pub) w2.wantPresent(t, w3.pub) w3.wantPresent(t, c1.pub, c3.pub, w1.pub, w2.pub, w3.pub) c1.close(t) w1.wantGone(t, c1.pub) w2.wantGone(t, c1.pub) w3.wantGone(t, c1.pub) } type testFwd int func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error { panic("not called in tests") } func (testFwd) String() string { panic("not called in tests") } func pubAll(b byte) (ret key.NodePublic) { var bs [32]byte for i := range bs { bs[i] = b } return key.NodePublicFromRaw32(mem.B(bs[:])) } func TestForwarderRegistration(t *testing.T) { s := &Server{ clients: make(map[key.NodePublic]clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } want := func(want map[key.NodePublic]PacketForwarder) { t.Helper() if got := s.clientsMesh; !reflect.DeepEqual(got, want) { t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) } } wantCounter := func(c *expvar.Int, want int) { t.Helper() if got := c.Value(); got != int64(want) { t.Errorf("counter = %v; want %v", got, want) } } u1 := pubAll(1) u2 := pubAll(2) u3 := pubAll(3) s.AddPacketForwarder(u1, testFwd(1)) s.AddPacketForwarder(u2, testFwd(2)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered forwarder is no-op. s.RemovePacketForwarder(u2, testFwd(999)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered user is no-op. s.RemovePacketForwarder(u3, testFwd(1)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Actual removal. s.RemovePacketForwarder(u2, testFwd(2)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), }) // Adding a dup for a user. wantCounter(&s.multiForwarderCreated, 0) s.AddPacketForwarder(u1, testFwd(100)) s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path want(map[key.NodePublic]PacketForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)), }) wantCounter(&s.multiForwarderCreated, 1) // Removing a forwarder in a multi set that doesn't exist; does nothing. s.RemovePacketForwarder(u1, testFwd(55)) want(map[key.NodePublic]PacketForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)), }) // Removing a forwarder in a multi set that does exist should collapse it away // from being a multiForwarder. wantCounter(&s.multiForwarderDeleted, 0) s.RemovePacketForwarder(u1, testFwd(1)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(100), }) wantCounter(&s.multiForwarderDeleted, 1) // Removing an entry for a client that's still connected locally should result // in a nil forwarder. u1c := &sclient{ key: u1, logf: logger.Discard, } s.clients[u1] = singleClient{u1c} s.RemovePacketForwarder(u1, testFwd(100)) want(map[key.NodePublic]PacketForwarder{ u1: nil, }) // But once that client disconnects, it should go away. s.unregisterClient(u1c) want(map[key.NodePublic]PacketForwarder{}) // But if it already has a forwarder, it's not removed. s.AddPacketForwarder(u1, testFwd(2)) s.unregisterClient(u1c) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(2), }) // Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard // that they're also connected to a peer of ours. That shouldn't transition the forwarder // from nil to the new one, not a multiForwarder. s.clients[u1] = singleClient{u1c} s.clientsMesh[u1] = nil want(map[key.NodePublic]PacketForwarder{ u1: nil, }) s.AddPacketForwarder(u1, testFwd(3)) want(map[key.NodePublic]PacketForwarder{ u1: testFwd(3), }) } type channelFwd struct { // id is to ensure that different instances that reference the // same channel are not equal, as they are used as keys in the // multiForwarder map. id int c chan []byte } func (f channelFwd) String() string { return "" } func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { f.c <- packet return nil } func TestMultiForwarder(t *testing.T) { received := 0 var wg sync.WaitGroup ch := make(chan []byte) ctx, cancel := context.WithCancel(context.Background()) s := &Server{ clients: make(map[key.NodePublic]clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } u := pubAll(1) s.AddPacketForwarder(u, channelFwd{1, ch}) wg.Add(2) go func() { defer wg.Done() for { select { case <-ch: received += 1 case <-ctx.Done(): return } } }() go func() { defer wg.Done() for { s.AddPacketForwarder(u, channelFwd{2, ch}) s.AddPacketForwarder(u, channelFwd{3, ch}) s.RemovePacketForwarder(u, channelFwd{2, ch}) s.RemovePacketForwarder(u, channelFwd{1, ch}) s.AddPacketForwarder(u, channelFwd{1, ch}) s.RemovePacketForwarder(u, channelFwd{3, ch}) if ctx.Err() != nil { return } } }() // Number of messages is chosen arbitrarily, just for this loop to // run long enough concurrently with {Add,Remove}PacketForwarder loop above. numMsgs := 5000 var fwd PacketForwarder for i := 0; i < numMsgs; i++ { s.mu.Lock() fwd = s.clientsMesh[u] s.mu.Unlock() fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i))) } cancel() wg.Wait() if received != numMsgs { t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received) } } func TestMetaCert(t *testing.T) { priv := key.NewNode() pub := priv.Public() s := NewServer(priv, t.Logf) certBytes := s.MetaCert() cert, err := x509.ParseCertificate(certBytes) if err != nil { log.Fatal(err) } if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(ProtocolVersion) { t.Errorf("serial = %v; want %v", cert.SerialNumber, ProtocolVersion) } if g, w := cert.Subject.CommonName, fmt.Sprintf("derpkey%s", pub.UntypedHexString()); g != w { t.Errorf("CommonName = %q; want %q", g, w) } if n := len(cert.Extensions); n != 1 { t.Fatalf("got %d extensions; want 1", n) } // oidExtensionBasicConstraints is the Basic Constraints ID copied // from the x509 package. oidExtensionBasicConstraints := asn1.ObjectIdentifier{2, 5, 29, 19} if id := cert.Extensions[0].Id; !id.Equal(oidExtensionBasicConstraints) { t.Errorf("extension ID = %v; want %v", id, oidExtensionBasicConstraints) } } type dummyNetConn struct { net.Conn } func (dummyNetConn) SetReadDeadline(time.Time) error { return nil } func TestClientRecv(t *testing.T) { tests := []struct { name string input []byte want any }{ { name: "ping", input: []byte{ byte(framePing), 0, 0, 0, 8, 1, 2, 3, 4, 5, 6, 7, 8, }, want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8}, }, { name: "pong", input: []byte{ byte(framePong), 0, 0, 0, 8, 1, 2, 3, 4, 5, 6, 7, 8, }, want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8}, }, { name: "health_bad", input: []byte{ byte(frameHealth), 0, 0, 0, 3, byte('B'), byte('A'), byte('D'), }, want: HealthMessage{Problem: "BAD"}, }, { name: "health_ok", input: []byte{ byte(frameHealth), 0, 0, 0, 0, }, want: HealthMessage{}, }, { name: "server_restarting", input: []byte{ byte(frameRestarting), 0, 0, 0, 8, 0, 0, 0, 1, 0, 0, 0, 2, }, want: ServerRestartingMessage{ ReconnectIn: 1 * time.Millisecond, TryFor: 2 * time.Millisecond, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Client{ nc: dummyNetConn{}, br: bufio.NewReader(bytes.NewReader(tt.input)), logf: t.Logf, clock: &tstest.Clock{}, } got, err := c.Recv() if err != nil { t.Fatal(err) } if !reflect.DeepEqual(got, tt.want) { t.Errorf("got %#v; want %#v", got, tt.want) } }) } } func TestClientSendPing(t *testing.T) { var buf bytes.Buffer c := &Client{ bw: bufio.NewWriter(&buf), } if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { t.Fatal(err) } want := []byte{ byte(framePing), 0, 0, 0, 8, 1, 2, 3, 4, 5, 6, 7, 8, } if !bytes.Equal(buf.Bytes(), want) { t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) } } func TestClientSendPong(t *testing.T) { var buf bytes.Buffer c := &Client{ bw: bufio.NewWriter(&buf), } if err := c.SendPong([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { t.Fatal(err) } want := []byte{ byte(framePong), 0, 0, 0, 8, 1, 2, 3, 4, 5, 6, 7, 8, } if !bytes.Equal(buf.Bytes(), want) { t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) } } func TestServerDupClients(t *testing.T) { serverPriv := key.NewNode() var s *Server clientPriv := key.NewNode() clientPub := clientPriv.Public() var c1, c2, c3 *sclient var clientName map[*sclient]string // run starts a new test case and resets clients back to their zero values. run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { s = NewServer(serverPriv, t.Logf) s.dupPolicy = dupPolicy c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")} clientName = map[*sclient]string{ c1: "c1", c2: "c2", c3: "c3", } t.Run(name, f) } runBothWays := func(name string, f func(t *testing.T)) { run(name+"_disablefighters", disableFighters, f) run(name+"_lastwriteractive", lastWriterIsActive, f) } wantSingleClient := func(t *testing.T, want *sclient) { t.Helper() switch s := s.clients[want.key].(type) { case singleClient: if s.c != want { t.Error("wrong single client") return } if want.isDup.Load() { t.Errorf("unexpected isDup on singleClient") } if want.isDisabled.Load() { t.Errorf("unexpected isDisabled on singleClient") } case nil: t.Error("no clients for key") case *dupClientSet: t.Error("unexpected multiple clients for key") } } wantNoClient := func(t *testing.T) { t.Helper() switch s := s.clients[clientPub].(type) { case nil: // Good. return default: t.Errorf("got %T; want empty", s) } } wantDupSet := func(t *testing.T) *dupClientSet { t.Helper() switch s := s.clients[clientPub].(type) { case *dupClientSet: return s default: t.Fatalf("wanted dup set; got %T", s) return nil } } wantActive := func(t *testing.T, want *sclient) { t.Helper() set, ok := s.clients[clientPub] if !ok { t.Error("no set for key") return } got := set.ActiveClient() if got != want { t.Errorf("active client = %q; want %q", clientName[got], clientName[want]) } } checkDup := func(t *testing.T, c *sclient, want bool) { t.Helper() if got := c.isDup.Load(); got != want { t.Errorf("client %q isDup = %v; want %v", clientName[c], got, want) } } checkDisabled := func(t *testing.T, c *sclient, want bool) { t.Helper() if got := c.isDisabled.Load(); got != want { t.Errorf("client %q isDisabled = %v; want %v", clientName[c], got, want) } } wantDupConns := func(t *testing.T, want int) { t.Helper() if got := s.dupClientConns.Value(); got != int64(want) { t.Errorf("dupClientConns = %v; want %v", got, want) } } wantDupKeys := func(t *testing.T, want int) { t.Helper() if got := s.dupClientKeys.Value(); got != int64(want) { t.Errorf("dupClientKeys = %v; want %v", got, want) } } // Common case: a single client comes and goes, with no dups. runBothWays("one_comes_and_goes", func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantSingleClient(t, c1) s.unregisterClient(c1) wantNoClient(t) }) // A still somewhat common case: a single client was // connected and then their wifi dies or laptop closes // or they switch networks and connect from a // different network. They have two connections but // it's not very bad. Only their new one is // active. The last one, being dead, doesn't send and // thus the new one doesn't get disabled. runBothWays("small_overlap_replacement", func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantSingleClient(t, c1) wantActive(t, c1) wantDupKeys(t, 0) wantDupKeys(t, 0) s.registerClient(c2) // wifi dies; c2 replacement connects wantDupSet(t) wantDupConns(t, 2) wantDupKeys(t, 1) checkDup(t, c1, true) checkDup(t, c2, true) checkDisabled(t, c1, false) checkDisabled(t, c2, false) wantActive(t, c2) // sends go to the replacement s.unregisterClient(c1) // c1 finally times out wantSingleClient(t, c2) checkDup(t, c2, false) // c2 is longer a dup wantActive(t, c2) wantDupConns(t, 0) wantDupKeys(t, 0) }) // Key cloning situation with concurrent clients, both trying // to write. run("concurrent_dups_get_disabled", disableFighters, func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantSingleClient(t, c1) wantActive(t, c1) s.registerClient(c2) wantDupSet(t) wantDupKeys(t, 1) wantDupConns(t, 2) wantActive(t, c2) checkDup(t, c1, true) checkDup(t, c2, true) checkDisabled(t, c1, false) checkDisabled(t, c2, false) s.noteClientActivity(c2) checkDisabled(t, c1, false) checkDisabled(t, c2, false) s.noteClientActivity(c1) checkDisabled(t, c1, true) checkDisabled(t, c2, true) wantActive(t, nil) s.registerClient(c3) wantActive(t, c3) checkDisabled(t, c3, false) wantDupKeys(t, 1) wantDupConns(t, 3) s.unregisterClient(c3) wantActive(t, nil) wantDupKeys(t, 1) wantDupConns(t, 2) s.unregisterClient(c2) wantSingleClient(t, c1) wantDupKeys(t, 0) wantDupConns(t, 0) }) // Key cloning with an A->B->C->A series instead. run("concurrent_dups_three_parties", disableFighters, func(t *testing.T) { wantNoClient(t) s.registerClient(c1) s.registerClient(c2) s.registerClient(c3) s.noteClientActivity(c1) checkDisabled(t, c1, true) checkDisabled(t, c2, true) checkDisabled(t, c3, true) wantActive(t, nil) }) run("activity_promotes_primary_when_nil", disableFighters, func(t *testing.T) { wantNoClient(t) // Last registered client is the active one... s.registerClient(c1) wantActive(t, c1) s.registerClient(c2) wantActive(t, c2) s.registerClient(c3) s.noteClientActivity(c2) wantActive(t, c3) // But if the last one goes away, the one with the // most recent activity wins. s.unregisterClient(c3) wantActive(t, c2) }) run("concurrent_dups_three_parties_last_writer", lastWriterIsActive, func(t *testing.T) { wantNoClient(t) s.registerClient(c1) wantActive(t, c1) s.registerClient(c2) wantActive(t, c2) s.noteClientActivity(c1) checkDisabled(t, c1, false) checkDisabled(t, c2, false) wantActive(t, c1) s.noteClientActivity(c2) checkDisabled(t, c1, false) checkDisabled(t, c2, false) wantActive(t, c2) s.unregisterClient(c2) checkDisabled(t, c1, false) wantActive(t, c1) }) } func TestLimiter(t *testing.T) { rl := rate.NewLimiter(rate.Every(time.Minute), 100) for i := 0; i < 200; i++ { r := rl.Reserve() d := r.Delay() t.Logf("i=%d, allow=%v, d=%v", i, r.OK(), d) } } func BenchmarkSendRecv(b *testing.B) { for _, size := range []int{10, 100, 1000, 10000} { b.Run(fmt.Sprintf("msgsize=%d", size), func(b *testing.B) { benchmarkSendRecvSize(b, size) }) } } func benchmarkSendRecvSize(b *testing.B, packetSize int) { serverPrivateKey := key.NewNode() s := NewServer(serverPrivateKey, logger.Discard) defer s.Close() k := key.NewNode() clientKey := k.Public() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { b.Fatal(err) } defer ln.Close() connOut, err := net.Dial("tcp", ln.Addr().String()) if err != nil { b.Fatal(err) } defer connOut.Close() connIn, err := ln.Accept() if err != nil { b.Fatal(err) } defer connIn.Close() brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go s.Accept(ctx, connIn, brwServer, "test-client") brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) client, err := NewClient(k, connOut, brw, logger.Discard) if err != nil { b.Fatalf("client: %v", err) } go func() { for { _, err := client.Recv() if err != nil { return } } }() msg := make([]byte, packetSize) b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { if err := client.Send(clientKey, msg); err != nil { b.Fatal(err) } } } func BenchmarkWriteUint32(b *testing.B) { w := bufio.NewWriter(io.Discard) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { writeUint32(w, 0x0ba3a) } } type nopRead struct{} func (r nopRead) Read(p []byte) (int, error) { return len(p), nil } var sinkU32 uint32 func BenchmarkReadUint32(b *testing.B) { r := bufio.NewReader(nopRead{}) var err error b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { sinkU32, err = readUint32(r) if err != nil { b.Fatal(err) } } } func waitConnect(t testing.TB, c *Client) { t.Helper() if m, err := c.Recv(); err != nil { t.Fatalf("client first Recv: %v", err) } else if v, ok := m.(ServerInfoMessage); !ok { t.Fatalf("client first Recv was unexpected type %T", v) } } func TestParseSSOutput(t *testing.T) { contents, err := os.ReadFile("testdata/example_ss.txt") if err != nil { t.Errorf("os.ReadFile(example_ss.txt) failed: %v", err) } seen := parseSSOutput(string(contents)) if len(seen) == 0 { t.Errorf("parseSSOutput expected non-empty map") } } type countWriter struct { mu sync.Mutex writes int bytes int64 } func (w *countWriter) Write(p []byte) (n int, err error) { w.mu.Lock() defer w.mu.Unlock() w.writes++ w.bytes += int64(len(p)) return len(p), nil } func (w *countWriter) Stats() (writes int, bytes int64) { w.mu.Lock() defer w.mu.Unlock() return w.writes, w.bytes } func (w *countWriter) ResetStats() { w.mu.Lock() defer w.mu.Unlock() w.writes, w.bytes = 0, 0 } func TestClientSendRateLimiting(t *testing.T) { cw := new(countWriter) c := &Client{ bw: bufio.NewWriter(cw), clock: &tstest.Clock{}, } c.setSendRateLimiter(ServerInfoMessage{}) pkt := make([]byte, 1000) if err := c.send(key.NodePublic{}, pkt); err != nil { t.Fatal(err) } writes1, bytes1 := cw.Stats() if writes1 != 1 { t.Errorf("writes = %v, want 1", writes1) } // Flood should all succeed. cw.ResetStats() for i := 0; i < 1000; i++ { if err := c.send(key.NodePublic{}, pkt); err != nil { t.Fatal(err) } } writes1K, bytes1K := cw.Stats() if writes1K != 1000 { t.Logf("writes = %v; want 1000", writes1K) } if got, want := bytes1K, bytes1*1000; got != want { t.Logf("bytes = %v; want %v", got, want) } // Set a rate limiter cw.ResetStats() c.setSendRateLimiter(ServerInfoMessage{ TokenBucketBytesPerSecond: 1, TokenBucketBytesBurst: int(bytes1 * 2), }) for i := 0; i < 1000; i++ { if err := c.send(key.NodePublic{}, pkt); err != nil { t.Fatal(err) } } writesLimited, bytesLimited := cw.Stats() if writesLimited == 0 || writesLimited == writes1K { t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited) } if bytesLimited < bytes1*2 || bytesLimited >= bytes1K { t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K) } } func TestServerRepliesToPing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() ts := newTestServer(t, ctx) defer ts.close(t) tc := newRegularClient(t, ts, "alice") data := [8]byte{1, 2, 3, 4, 5, 6, 7, 42} if err := tc.c.SendPing(data); err != nil { t.Fatal(err) } for { m, err := tc.c.recvTimeout(time.Second) if err != nil { t.Fatal(err) } switch m := m.(type) { case PongMessage: if ([8]byte(m)) != data { t.Fatalf("got pong %2x; want %2x", [8]byte(m), data) } return } } }