diff --git a/safeweb/http.go b/safeweb/http.go index 4181f9d0c..c2787611e 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -300,3 +300,8 @@ func (s *Server) ServeRedirectHTTP(ln net.Listener, fqdn string) error { func (s *Server) Serve(ln net.Listener) error { return s.h.Serve(ln) } + +// Close closes all client connections and stops accepting new ones. +func (s *Server) Close() error { + return s.h.Close() +} diff --git a/safeweb/http_test.go b/safeweb/http_test.go index c5e2f9cbd..f48aa64a7 100644 --- a/safeweb/http_test.go +++ b/safeweb/http_test.go @@ -80,6 +80,7 @@ func TestPostRequestContentTypeValidation(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() req := httptest.NewRequest("POST", "/", nil) req.Header.Set("Content-Type", tt.contentType) @@ -137,6 +138,7 @@ func TestAPIMuxCrossOriginResourceSharingHeaders(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() req := httptest.NewRequest(tt.httpMethod, "/", nil) w := httptest.NewRecorder() @@ -192,6 +194,7 @@ func TestCSRFProtection(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() // construct the test request req := httptest.NewRequest("POST", "/", nil) @@ -267,6 +270,7 @@ func TestContentSecurityPolicyHeader(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -307,6 +311,7 @@ func TestCSRFCookieSecureMode(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -355,6 +360,7 @@ func TestRefererPolicy(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -379,6 +385,7 @@ func TestCSPAllowInlineStyles(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -474,6 +481,7 @@ func TestRouting(t *testing.T) { if err != nil { t.Fatal(err) } + defer s.Close() req := httptest.NewRequest("GET", tt.requestPath, nil) w := httptest.NewRecorder()