diff --git a/derp/derp_test.go b/derp/derp_test.go index 5585c61c1..f981dc797 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -6,15 +6,22 @@ package derp import ( "bufio" + "context" crand "crypto/rand" + "errors" + "expvar" + "fmt" + "io" "net" "testing" "time" + "tailscale.com/net/nettest" "tailscale.com/types/key" ) func newPrivateKey(t *testing.T) (k key.Private) { + t.Helper() if _, err := crand.Read(k[:]); err != nil { t.Fatal(err) } @@ -42,7 +49,7 @@ func TestSendRecv(t *testing.T) { defer ln.Close() var clients []*Client - var connsOut []net.Conn + var connsOut []Conn var recvChs []chan []byte errCh := make(chan error, 3) @@ -171,3 +178,168 @@ func TestSendRecv(t *testing.T) { t.Logf("passed") s.Close() } + +func TestSendFreeze(t *testing.T) { + serverPrivateKey := newPrivateKey(t) + s := NewServer(serverPrivateKey, t.Logf) + defer s.Close() + s.WriteTimeout = 100 * time.Millisecond + + // We send two streams of messages: + // + // alice --> bob + // alice --> cathy + // + // Then cathy stops processing messsages. + // That should not interfere with alice talking to bob. + + newClient := func(name string, k key.Private) (c *Client, clientConn nettest.Conn) { + t.Helper() + c1, c2 := nettest.NewConn(name, 1024) + go s.Accept(c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name) + + brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2)) + c, err := NewClient(k, c2, brw, t.Logf) + if err != nil { + t.Fatal(err) + } + return c, c2 + } + + aliceKey := newPrivateKey(t) + aliceClient, aliceConn := newClient("alice", aliceKey) + + bobKey := newPrivateKey(t) + bobClient, bobConn := newClient("bob", bobKey) + + cathyKey := newPrivateKey(t) + cathyClient, cathyConn := newClient("cathy", cathyKey) + + var aliceCount, bobCount, cathyCount expvar.Int + + errCh := make(chan error, 4) + recvAndCount := func(count *expvar.Int, name string, client *Client) { + for { + b := make([]byte, 1<<9) + m, err := client.Recv(b) + if err != nil { + errCh <- fmt.Errorf("%s: %w", name, err) + return + } + switch m := m.(type) { + default: + errCh <- fmt.Errorf("%s: unexpected message type %T", name, m) + return + case ReceivedPacket: + if m.Source.IsZero() { + errCh <- fmt.Errorf("%s: zero Source address in ReceivedPacket", name) + return + } + count.Add(1) + } + } + } + go recvAndCount(&aliceCount, "alice", aliceClient) + go recvAndCount(&bobCount, "bob", bobClient) + go recvAndCount(&cathyCount, "cathy", cathyClient) + + var cancel func() + go func() { + t := time.NewTicker(2 * time.Millisecond) + defer t.Stop() + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + for { + select { + case <-t.C: + case <-ctx.Done(): + errCh <- nil + return + } + + msg1 := []byte("hello alice->bob\n") + if err := aliceClient.Send(bobKey.Public(), msg1); err != nil { + errCh <- fmt.Errorf("alice send to bob: %w", err) + return + } + msg2 := []byte("hello alice->cathy\n") + + // TODO: an error is expected here. + // We ignore it, maybe we should log it somehow? + aliceClient.Send(cathyKey.Public(), msg2) + } + }() + + var countSnapshot [3]int64 + loadCounts := func() (adiff, bdiff, cdiff int64) { + t.Helper() + + atotal := aliceCount.Value() + btotal := bobCount.Value() + ctotal := cathyCount.Value() + + adiff = atotal - countSnapshot[0] + bdiff = btotal - countSnapshot[1] + cdiff = ctotal - countSnapshot[2] + + countSnapshot[0] = atotal + countSnapshot[1] = btotal + countSnapshot[2] = ctotal + + t.Logf("count diffs: alice=%d, bob=%d, cathy=%d", adiff, bdiff, cdiff) + return adiff, bdiff, cdiff + } + + t.Run("initial send", func(t *testing.T) { + time.Sleep(10 * time.Millisecond) + a, b, c := loadCounts() + if a != 0 { + t.Errorf("alice diff=%d, want 0", a) + } + if b == 0 { + t.Errorf("no bob diff, want positive value") + } + if c == 0 { + t.Errorf("no cathy diff, want positive value") + } + }) + + t.Run("block cathy", func(t *testing.T) { + // Block cathy. Now the cathyConn buffer will fill up quickly, + // and the derp server will back up. + cathyConn.SetReadBlock(true) + time.Sleep(2 * s.WriteTimeout) + + a, b, _ := loadCounts() + if a != 0 { + t.Errorf("alice diff=%d, want 0", a) + } + if b == 0 { + t.Errorf("no bob diff, want positive value") + } + + // Now wait a little longer, and ensure packets still flow to bob + time.Sleep(10 * time.Millisecond) + if _, b, _ := loadCounts(); b == 0 { + t.Errorf("connection alice->bob frozen by alice->cathy") + } + }) + + // Cleanup, make sure we process all errors. + t.Logf("TEST COMPLETE, cancelling sender") + cancel() + t.Logf("closing connections") + aliceConn.Close() + bobConn.Close() + cathyConn.Close() + + for i := 0; i < cap(errCh); i++ { + err := <-errCh + if err != nil { + if errors.Is(err, io.EOF) { + continue + } + t.Error(err) + } + } +}