diff --git a/net/captivedetection/captivedetection.go b/net/captivedetection/captivedetection.go
index c6e8bca3a..7d598d853 100644
--- a/net/captivedetection/captivedetection.go
+++ b/net/captivedetection/captivedetection.go
@@ -136,26 +136,34 @@ func interfaceNameDoesNotNeedCaptiveDetection(ifName string, goos string) bool {
 func (d *Detector) detectOnInterface(ctx context.Context, ifIndex int, endpoints []Endpoint) bool {
 	defer d.httpClient.CloseIdleConnections()
 
-	d.logf("[v2] %d available captive portal detection endpoints: %v", len(endpoints), endpoints)
+	// Use at most 5 endpoints. Log before truncating the slice so that the
+	// available count and the tried count are reported separately.
+	use := min(len(endpoints), 5)
+	d.logf("[v2] %d available captive portal detection endpoints; trying %v", len(endpoints), use)
+	endpoints = endpoints[:use]
 
 	// We try to detect the captive portal more quickly by making requests to multiple endpoints concurrently.
 	var wg sync.WaitGroup
 	resultCh := make(chan bool, len(endpoints))
 
-	for i, e := range endpoints {
-		if i >= 5 {
-			// Try a maximum of 5 endpoints, break out (returning false) if we run of attempts.
-			break
-		}
+	// Once any goroutine detects a captive portal, we shut down the others.
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	for _, e := range endpoints {
 		wg.Add(1)
 		go func(endpoint Endpoint) {
 			defer wg.Done()
 			found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, ifIndex)
 			if err != nil {
-				d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err)
+				// Don't log errors once ctx is done, e.g. after cancel() below.
+				if ctx.Err() == nil {
+					d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err)
+				}
 				return
 			}
 			if found {
+				cancel() // one match is good enough
 				resultCh <- true
 			}
 		}(e)
diff --git a/net/captivedetection/captivedetection_test.go b/net/captivedetection/captivedetection_test.go
index e74273afd..29a197d31 100644
--- a/net/captivedetection/captivedetection_test.go
+++ b/net/captivedetection/captivedetection_test.go
@@ -7,10 +7,12 @@ import (
 	"context"
 	"runtime"
 	"sync"
+	"sync/atomic"
 	"testing"
 
-	"tailscale.com/cmd/testwrapper/flakytest"
 	"tailscale.com/net/netmon"
+	"tailscale.com/syncs"
+	"tailscale.com/tstest/nettest"
 )
 
 func TestAvailableEndpointsAlwaysAtLeastTwo(t *testing.T) {
@@ -36,25 +38,55 @@ func TestDetectCaptivePortalReturnsFalse(t *testing.T) {
 	}
 }
 
-func TestAllEndpointsAreUpAndReturnExpectedResponse(t *testing.T) {
-	flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13019")
+// TestEndpointsAreUpAndReturnExpectedResponse passes if at least one endpoint
+// can be probed without error and does not report a captive portal.
+func TestEndpointsAreUpAndReturnExpectedResponse(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+
 	d := NewDetector(t.Logf)
 	endpoints := availableEndpoints(nil, 0, t.Logf, runtime.GOOS)
+	t.Logf("testing %d endpoints", len(endpoints))
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
+	var good atomic.Bool
 	var wg sync.WaitGroup
+	// sem must be acquired before probing; it is created with a size of 5.
+	sem := syncs.NewSemaphore(5)
 	for _, e := range endpoints {
 		wg.Add(1)
 		go func(endpoint Endpoint) {
 			defer wg.Done()
-			found, err := d.verifyCaptivePortalEndpoint(context.Background(), endpoint, 0)
-			if err != nil {
-				t.Errorf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err)
+
+			if !sem.AcquireContext(ctx) {
+				return
+			}
+			defer sem.Release()
+
+			found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, 0)
+			if err != nil {
+				// A failed probe says nothing about the endpoint, so it must
+				// not fall through and be counted as good below.
+				if ctx.Err() == nil {
+					t.Logf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err)
+				}
+				return
 			}
 			if found {
-				t.Errorf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint)
+				t.Logf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint)
+				return
 			}
+			good.Store(true)
+			t.Logf("endpoint good: %v", endpoint)
+			// One good endpoint is enough; cancel ctx for the remaining goroutines.
+			cancel()
 		}(e)
 	}
 
 	wg.Wait()
+
+	if !good.Load() {
+		t.Errorf("no good endpoints found")
+	}
 }