diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index c5fcdf5c3..00433d1be 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -26,6 +26,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tstest" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -172,8 +173,12 @@ func testControlHTTP(t *testing.T, param httpTestParam) { } var httpHandler http.Handler = handler + const fallbackDelay = 50 * time.Millisecond + clock := tstest.NewClock(tstest.ClockOpts{Step: 2 * fallbackDelay}) + // Advance once to init the clock. + clock.Now() if param.makeHTTPHangAfterUpgrade { - httpHandler = http.HandlerFunc(brokenMITMHandler) + httpHandler = brokenMITMHandler(clock) } httpServer := &http.Server{Handler: httpHandler} go httpServer.Serve(httpLn) @@ -204,8 +209,8 @@ func testControlHTTP(t *testing.T, param httpTestParam) { Dialer: new(tsdial.Dialer).SystemDial, Logf: t.Logf, omitCertErrorLogging: true, - testFallbackDelay: 50 * time.Millisecond, - Clock: &tstest.Clock{}, + testFallbackDelay: fallbackDelay, + Clock: clock, } if proxy != nil { @@ -471,12 +476,16 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== } } -func brokenMITMHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Upgrade", upgradeHeaderValue) - w.Header().Set("Connection", "upgrade") - w.WriteHeader(http.StatusSwitchingProtocols) - w.(http.Flusher).Flush() - <-r.Context().Done() +func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Connection", "upgrade") + w.WriteHeader(http.StatusSwitchingProtocols) + w.(http.Flusher).Flush() + // Advance the clock to trigger HTTPs fallback. + clock.Now() + <-r.Context().Done() + } } func TestDialPlan(t *testing.T) { @@ -621,12 +630,15 @@ func TestDialPlan(t *testing.T) { } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { + // TODO(awly): replace this with tstest.NewClock and update the + // test to advance the clock correctly. + clock := tstime.StdClock{} makeHandler(t, "fallback", fallbackAddr, nil) makeHandler(t, "good", goodAddr, nil) makeHandler(t, "other", otherAddr, nil) makeHandler(t, "other2", other2Addr, nil) makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler { - return http.HandlerFunc(brokenMITMHandler) + return brokenMITMHandler(clock) }) dialer := closeTrackDialer{ @@ -662,7 +674,7 @@ func TestDialPlan(t *testing.T) { drainFinished: drained, omitCertErrorLogging: true, testFallbackDelay: 50 * time.Millisecond, - Clock: &tstest.Clock{}, + Clock: clock, } conn, err := a.dial(ctx)