diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 1897fa372..d53eb7ee2 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -638,20 +638,23 @@ func (rs *reportState) waitHairCheck(ctx context.Context) { return } + // First, check whether we have a value before we check for timeouts. + select { + case <-rs.gotHairSTUN: + ret.HairPinning.Set(true) + return + default: + } + + // Now, wait for a response or a timeout. select { case <-rs.gotHairSTUN: ret.HairPinning.Set(true) case <-rs.hairTimeout: rs.c.vlogf("hairCheck timeout") ret.HairPinning.Set(false) - default: - select { - case <-rs.gotHairSTUN: - ret.HairPinning.Set(true) - case <-rs.hairTimeout: - ret.HairPinning.Set(false) - case <-ctx.Done(): - } + case <-ctx.Done(): + rs.c.vlogf("hairCheck context timeout") } } diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 797889926..4d4bc4a2f 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -47,6 +47,113 @@ func TestHairpinSTUN(t *testing.T) { } } +func TestHairpinWait(t *testing.T) { + makeClient := func(t *testing.T) (*Client, *reportState) { + tx := stun.NewTxID() + c := &Client{} + req := stun.Request(tx) + if !stun.Is(req) { + t.Fatal("expected STUN message") + } + + var err error + rs := &reportState{ + c: c, + hairTX: tx, + gotHairSTUN: make(chan netip.AddrPort, 1), + hairTimeout: make(chan struct{}), + report: newReport(), + } + rs.pc4Hair, err = net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }) + if err != nil { + t.Fatal(err) + } + + c.curState = rs + return c, rs + } + + ll, err := net.ListenPacket("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ll.Close() + dstAddr := netip.MustParseAddrPort(ll.LocalAddr().String()) + + t.Run("Success", func(t *testing.T) { + c, rs := makeClient(t) + req := stun.Request(rs.hairTX) + + // Start a hairpin check to ourselves. + rs.startHairCheckLocked(dstAddr) + + // Fake receiving the stun check from ourselves after some period of time. + src := netip.MustParseAddrPort(rs.pc4Hair.LocalAddr().String()) + c.handleHairSTUNLocked(req, src) + + rs.waitHairCheck(context.Background()) + + // Verify that we set HairPinning + if got := rs.report.HairPinning; !got.EqualBool(true) { + t.Errorf("wanted HairPinning=true, got %v", got) + } + }) + + t.Run("LateReply", func(t *testing.T) { + c, rs := makeClient(t) + req := stun.Request(rs.hairTX) + + // Start a hairpin check to ourselves. + rs.startHairCheckLocked(dstAddr) + + // Wait until we've timed out, to mimic the race in #1795. + <-rs.hairTimeout + + // Fake receiving the stun check from ourselves after some period of time. + src := netip.MustParseAddrPort(rs.pc4Hair.LocalAddr().String()) + c.handleHairSTUNLocked(req, src) + + // Wait for a hairpin response + rs.waitHairCheck(context.Background()) + + // Verify that we set HairPinning + if got := rs.report.HairPinning; !got.EqualBool(true) { + t.Errorf("wanted HairPinning=true, got %v", got) + } + }) + + t.Run("Timeout", func(t *testing.T) { + _, rs := makeClient(t) + + // Start a hairpin check to ourselves. + rs.startHairCheckLocked(dstAddr) + + ctx, cancel := context.WithTimeout(context.Background(), hairpinCheckTimeout*50) + defer cancel() + + // Wait in the background + waitDone := make(chan struct{}) + go func() { + rs.waitHairCheck(ctx) + close(waitDone) + }() + + // If we do nothing, then we time out; confirm that we set + // HairPinning to false in this case. + select { + case <-waitDone: + if got := rs.report.HairPinning; !got.EqualBool(false) { + t.Errorf("wanted HairPinning=false, got %v", got) + } + case <-ctx.Done(): + t.Fatalf("timed out waiting for hairpin channel") + } + }) +} + func TestBasic(t *testing.T) { stunAddr, cleanup := stuntest.Serve(t) defer cleanup()