diff --git a/safeweb/http.go b/safeweb/http.go index 8abd169d6..b41a1855d 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -72,6 +72,7 @@ package safeweb import ( crand "crypto/rand" "fmt" + "log" "net" "net/http" "net/url" @@ -148,48 +149,12 @@ func (c *Config) setDefaults() error { return nil } -func (c Config) newHandler() http.Handler { - // only set Secure flag on CSRF cookies if we are in a secure context - // as otherwise the browser will reject the cookie - csrfProtect := csrf.Protect(c.CSRFSecret, csrf.Secure(c.SecureContext)) - - var csp string - if c.CSPAllowInlineStyles { - csp = defaultCSP + `; style-src 'self' 'unsafe-inline'` - } else { - // if no style-src is provided the browser will fallback to the - // default-src directive which disallows inline styles. - csp = defaultCSP - } - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, p := c.BrowserMux.Handler(r); p == "" { - // disallow x-www-form-urlencoded requests to the API - if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" { - http.Error(w, "invalid content type", http.StatusBadRequest) - return - } - - // set CORS headers for pre-flight OPTIONS requests if any were configured - if r.Method == "OPTIONS" && len(c.AccessControlAllowOrigin) > 0 { - w.Header().Set("Access-Control-Allow-Origin", strings.Join(c.AccessControlAllowOrigin, ", ")) - w.Header().Set("Access-Control-Allow-Methods", strings.Join(c.AccessControlAllowMethods, ", ")) - } - c.APIMux.ServeHTTP(w, r) - return - } - - w.Header().Set("Content-Security-Policy", csp) - w.Header().Set("X-Content-Type-Options", "nosniff") - w.Header().Set("Referer-Policy", "same-origin") - csrfProtect(c.BrowserMux).ServeHTTP(w, r) - }) -} - // Server is a safeweb server. type Server struct { Config - h *http.Server + h *http.Server + csp string + csrfProtect func(http.Handler) http.Handler } // NewServer creates a safeweb server with the provided configuration. It will @@ -208,10 +173,73 @@ func NewServer(config Config) (*Server, error) { return nil, fmt.Errorf("failed to set defaults: %w", err) } - return &Server{ - config, - &http.Server{Handler: config.newHandler()}, - }, nil + s := &Server{ + Config: config, + csp: defaultCSP, + // only set Secure flag on CSRF cookies if we are in a secure context + // as otherwise the browser will reject the cookie + csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext)), + } + if config.CSPAllowInlineStyles { + s.csp = defaultCSP + `; style-src 'self' 'unsafe-inline'` + } + s.h = &http.Server{Handler: s} + return s, nil +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + _, bp := s.BrowserMux.Handler(r) + _, ap := s.APIMux.Handler(r) + switch { + case bp == "" && ap != "": // APIMux match + s.serveAPI(w, r) + case bp != "" && ap == "": // BrowserMux match + s.serveBrowser(w, r) + case bp == "" && ap == "": // neither match + http.NotFound(w, r) + case bp != "" && ap != "": + // Both muxes match the path. This can be because: + // * one of them registers a wildcard "/" handler + // * there are overlapping specific handlers + // + // If it's the former, route to the more-specific handler. If it's the + // latter - that's a bug so return an error to avoid mis-routing the + // request. + // + // TODO(awly): match the longest path instead of only special-casing + // "/". + switch { + case bp == "/": + s.serveAPI(w, r) + case ap == "/": + s.serveBrowser(w, r) + default: + log.Printf("conflicting mux paths in safeweb: request %q matches browser mux pattern %q and API mux patter %q; returning 500", r.URL.Path, bp, ap) + http.Error(w, "multiple handlers match this request", http.StatusInternalServerError) + } + } +} + +func (s *Server) serveAPI(w http.ResponseWriter, r *http.Request) { + // disallow x-www-form-urlencoded requests to the API + if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" { + http.Error(w, "invalid content type", http.StatusBadRequest) + return + } + + // set CORS headers for pre-flight OPTIONS requests if any were configured + if r.Method == "OPTIONS" && len(s.AccessControlAllowOrigin) > 0 { + w.Header().Set("Access-Control-Allow-Origin", strings.Join(s.AccessControlAllowOrigin, ", ")) + w.Header().Set("Access-Control-Allow-Methods", strings.Join(s.AccessControlAllowMethods, ", ")) + } + s.APIMux.ServeHTTP(w, r) +} + +func (s *Server) serveBrowser(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Security-Policy", s.csp) + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Referer-Policy", "same-origin") + s.csrfProtect(s.BrowserMux).ServeHTTP(w, r) } // RedirectHTTP returns a handler that redirects all incoming HTTP requests to diff --git a/safeweb/http_test.go b/safeweb/http_test.go index 8131c2a97..e179cc4e9 100644 --- a/safeweb/http_test.go +++ b/safeweb/http_test.go @@ -4,6 +4,7 @@ package safeweb import ( + "io" "net/http" "net/http/httptest" "strconv" @@ -392,3 +393,98 @@ func TestCSPAllowInlineStyles(t *testing.T) { }) } } + +func TestRouting(t *testing.T) { + for _, tt := range []struct { + desc string + browserPatterns []string + apiPatterns []string + requestPath string + want string + }{ + { + desc: "only browser mux", + browserPatterns: []string{"/"}, + requestPath: "/index.html", + want: "browser", + }, + { + desc: "only API mux", + apiPatterns: []string{"/api/"}, + requestPath: "/api/foo", + want: "api", + }, + { + desc: "browser mux match", + browserPatterns: []string{"/content/"}, + apiPatterns: []string{"/api/"}, + requestPath: "/content/index.html", + want: "browser", + }, + { + desc: "API mux match", + browserPatterns: []string{"/content/"}, + apiPatterns: []string{"/api/"}, + requestPath: "/api/foo", + want: "api", + }, + { + desc: "browser wildcard match", + browserPatterns: []string{"/"}, + apiPatterns: []string{"/api/"}, + requestPath: "/index.html", + want: "browser", + }, + { + desc: "API wildcard match", + browserPatterns: []string{"/content/"}, + apiPatterns: []string{"/"}, + requestPath: "/api/foo", + want: "api", + }, + { + desc: "path conflict", + browserPatterns: []string{"/foo/"}, + apiPatterns: []string{"/foo/bar/"}, + requestPath: "/foo/bar/baz", + want: "multiple handlers match this request", + }, + { + desc: "no match", + browserPatterns: []string{"/foo/"}, + apiPatterns: []string{"/bar/"}, + requestPath: "/baz", + want: "404 page not found", + }, + } { + t.Run(tt.desc, func(t *testing.T) { + bm := &http.ServeMux{} + for _, p := range tt.browserPatterns { + bm.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("browser")) + }) + } + am := &http.ServeMux{} + for _, p := range tt.apiPatterns { + am.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("api")) + }) + } + s, err := NewServer(Config{BrowserMux: bm, APIMux: am}) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("GET", tt.requestPath, nil) + w := httptest.NewRecorder() + s.h.Handler.ServeHTTP(w, req) + resp, err := io.ReadAll(w.Result().Body) + if err != nil { + t.Fatal(err) + } + if got := strings.TrimSpace(string(resp)); got != tt.want { + t.Errorf("got response %q, want %q", got, tt.want) + } + }) + } +}