diff --git a/derp/derp_test.go b/derp/derp_test.go index f981dc797..de67e58c6 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -9,7 +9,6 @@ import ( "context" crand "crypto/rand" "errors" - "expvar" "fmt" "io" "net" @@ -215,10 +214,27 @@ func TestSendFreeze(t *testing.T) { cathyKey := newPrivateKey(t) cathyClient, cathyConn := newClient("cathy", cathyKey) - var aliceCount, bobCount, cathyCount expvar.Int + var ( + aliceCh = make(chan struct{}, 32) + bobCh = make(chan struct{}, 32) + cathyCh = make(chan struct{}, 32) + ) + chs := func(name string) chan struct{} { + switch name { + case "alice": + return aliceCh + case "bob": + return bobCh + case "cathy": + return cathyCh + default: + panic("unknown ch: " + name) + } + } errCh := make(chan error, 4) - recvAndCount := func(count *expvar.Int, name string, client *Client) { + recv := func(name string, client *Client) { + ch := chs(name) for { b := make([]byte, 1<<9) m, err := client.Recv(b) @@ -235,13 +251,16 @@ func TestSendFreeze(t *testing.T) { errCh <- fmt.Errorf("%s: zero Source address in ReceivedPacket", name) return } - count.Add(1) + select { + case ch <- struct{}{}: + default: + } } } } - go recvAndCount(&aliceCount, "alice", aliceClient) - go recvAndCount(&bobCount, "bob", bobClient) - go recvAndCount(&cathyCount, "cathy", cathyClient) + go recv("alice", aliceClient) + go recv("bob", bobClient) + go recv("cathy", cathyClient) var cancel func() go func() { @@ -270,38 +289,52 @@ func TestSendFreeze(t *testing.T) { } }() - var countSnapshot [3]int64 - loadCounts := func() (adiff, bdiff, cdiff int64) { + drainAny := func(ch chan struct{}) { + // We are draining potentially infinite sources, + // so place some reasonable upper limit. + // + // The important thing here is to make sure that + // if any tokens remain in the channel, they + // must have been generated after drainAny was + // called. + for i := 0; i < cap(ch); i++ { + select { + case <-ch: + default: + return + } + } + } + drain := func(t *testing.T, name string) bool { t.Helper() + timer := time.NewTimer(1 * time.Second) + defer timer.Stop() - atotal := aliceCount.Value() - btotal := bobCount.Value() - ctotal := cathyCount.Value() - - adiff = atotal - countSnapshot[0] - bdiff = btotal - countSnapshot[1] - cdiff = ctotal - countSnapshot[2] - - countSnapshot[0] = atotal - countSnapshot[1] = btotal - countSnapshot[2] = ctotal - - t.Logf("count diffs: alice=%d, bob=%d, cathy=%d", adiff, bdiff, cdiff) - return adiff, bdiff, cdiff + // Ensure ch has at least one element. + ch := chs(name) + select { + case <-ch: + case <-timer.C: + t.Errorf("no packet received by %s", name) + return false + } + // Drain remaining. + drainAny(ch) + return true + } + isEmpty := func(t *testing.T, name string) { + t.Helper() + select { + case <-chs(name): + t.Errorf("packet received by %s, want none", name) + default: + } } t.Run("initial send", func(t *testing.T) { - time.Sleep(10 * time.Millisecond) - a, b, c := loadCounts() - if a != 0 { - t.Errorf("alice diff=%d, want 0", a) - } - if b == 0 { - t.Errorf("no bob diff, want positive value") - } - if c == 0 { - t.Errorf("no cathy diff, want positive value") - } + drain(t, "bob") + drain(t, "cathy") + isEmpty(t, "alice") }) t.Run("block cathy", func(t *testing.T) { @@ -310,17 +343,12 @@ func TestSendFreeze(t *testing.T) { cathyConn.SetReadBlock(true) time.Sleep(2 * s.WriteTimeout) - a, b, _ := loadCounts() - if a != 0 { - t.Errorf("alice diff=%d, want 0", a) - } - if b == 0 { - t.Errorf("no bob diff, want positive value") - } + drain(t, "bob") + drainAny(chs("cathy")) + isEmpty(t, "alice") // Now wait a little longer, and ensure packets still flow to bob - time.Sleep(10 * time.Millisecond) - if _, b, _ := loadCounts(); b == 0 { + if !drain(t, "bob") { t.Errorf("connection alice->bob frozen by alice->cathy") } })