From 4cb1bfee44e019dbdd196e0d95f0fff3bdc59097 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Fri, 24 Mar 2023 11:25:32 -0400 Subject: [PATCH] net/netcheck: improve determinism in hairpinning test If multiple Go channels have a value (or are closed), a select that receives from all of them will nondeterministically pick one of the ready arms. In this case, it's possible that the hairpin check timer has already expired by the time we check for a result, even though the hairpin packet has already been received. In such cases, we'd nondeterministically set report.HairPinning. Instead, check if we have a value in our results channel first, and only then select on both the results and timeout channels. Also, add a test that catches this particular failure. Fixes #1795 Change-Id: I842ab0bd38d66fabc6cabf2c2c1bb9bd32febf35 Signed-off-by: Andrew Dunham --- net/netcheck/netcheck.go | 19 +++--- net/netcheck/netcheck_test.go | 107 ++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 1897fa372..d53eb7ee2 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -638,20 +638,23 @@ func (rs *reportState) waitHairCheck(ctx context.Context) { return } + // First, check whether we have a value before we check for timeouts. + select { + case <-rs.gotHairSTUN: + ret.HairPinning.Set(true) + return + default: + } + + // Now, wait for a response or a timeout. 
select { case <-rs.gotHairSTUN: ret.HairPinning.Set(true) case <-rs.hairTimeout: rs.c.vlogf("hairCheck timeout") ret.HairPinning.Set(false) - default: - select { - case <-rs.gotHairSTUN: - ret.HairPinning.Set(true) - case <-rs.hairTimeout: - ret.HairPinning.Set(false) - case <-ctx.Done(): - } + case <-ctx.Done(): + rs.c.vlogf("hairCheck context timeout") } } diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 797889926..4d4bc4a2f 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -47,6 +47,113 @@ func TestHairpinSTUN(t *testing.T) { } } +func TestHairpinWait(t *testing.T) { + makeClient := func(t *testing.T) (*Client, *reportState) { + tx := stun.NewTxID() + c := &Client{} + req := stun.Request(tx) + if !stun.Is(req) { + t.Fatal("expected STUN message") + } + + var err error + rs := &reportState{ + c: c, + hairTX: tx, + gotHairSTUN: make(chan netip.AddrPort, 1), + hairTimeout: make(chan struct{}), + report: newReport(), + } + rs.pc4Hair, err = net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }) + if err != nil { + t.Fatal(err) + } + + c.curState = rs + return c, rs + } + + ll, err := net.ListenPacket("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ll.Close() + dstAddr := netip.MustParseAddrPort(ll.LocalAddr().String()) + + t.Run("Success", func(t *testing.T) { + c, rs := makeClient(t) + req := stun.Request(rs.hairTX) + + // Start a hairpin check to ourselves. + rs.startHairCheckLocked(dstAddr) + + // Fake receiving the stun check from ourselves after some period of time. 
+ src := netip.MustParseAddrPort(rs.pc4Hair.LocalAddr().String()) + c.handleHairSTUNLocked(req, src) + + rs.waitHairCheck(context.Background()) + + // Verify that we set HairPinning + if got := rs.report.HairPinning; !got.EqualBool(true) { + t.Errorf("wanted HairPinning=true, got %v", got) + } + }) + + t.Run("LateReply", func(t *testing.T) { + c, rs := makeClient(t) + req := stun.Request(rs.hairTX) + + // Start a hairpin check to ourselves. + rs.startHairCheckLocked(dstAddr) + + // Wait until we've timed out, to mimic the race in #1795. + <-rs.hairTimeout + + // Fake receiving the stun check from ourselves after some period of time. + src := netip.MustParseAddrPort(rs.pc4Hair.LocalAddr().String()) + c.handleHairSTUNLocked(req, src) + + // Wait for a hairpin response + rs.waitHairCheck(context.Background()) + + // Verify that we set HairPinning + if got := rs.report.HairPinning; !got.EqualBool(true) { + t.Errorf("wanted HairPinning=true, got %v", got) + } + }) + + t.Run("Timeout", func(t *testing.T) { + _, rs := makeClient(t) + + // Start a hairpin check to ourselves. + rs.startHairCheckLocked(dstAddr) + + ctx, cancel := context.WithTimeout(context.Background(), hairpinCheckTimeout*50) + defer cancel() + + // Wait in the background + waitDone := make(chan struct{}) + go func() { + rs.waitHairCheck(ctx) + close(waitDone) + }() + + // If we do nothing, then we time out; confirm that we set + // HairPinning to false in this case. + select { + case <-waitDone: + if got := rs.report.HairPinning; !got.EqualBool(false) { + t.Errorf("wanted HairPinning=false, got %v", got) + } + case <-ctx.Done(): + t.Fatalf("timed out waiting for hairpin channel") + } + }) +} + func TestBasic(t *testing.T) { stunAddr, cleanup := stuntest.Serve(t) defer cleanup()