diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 7c117ff90..e6fbffbf5 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -162,7 +162,7 @@ func (h Port80Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } host := h.FQDN if host == "" { - host = r.URL.Hostname() + host = r.Host } target := "https://" + host + path http.Redirect(w, r, target, http.StatusFound) diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 24e75f8fe..8df574254 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -580,3 +580,45 @@ func TestAcceptsEncoding(t *testing.T) { } } } + +func TestPort80Handler(t *testing.T) { + tests := []struct { + name string + h *Port80Handler + req string + wantLoc string + }{ + { + name: "no_fqdn", + h: &Port80Handler{}, + req: "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n", + wantLoc: "https://foo.com/", + }, + { + name: "fqdn_and_path", + h: &Port80Handler{FQDN: "bar.com"}, + req: "GET /path HTTP/1.1\r\nHost: foo.com\r\n\r\n", + wantLoc: "https://bar.com/path", + }, + { + name: "path_and_query_string", + h: &Port80Handler{FQDN: "baz.com"}, + req: "GET /path?a=b HTTP/1.1\r\nHost: foo.com\r\n\r\n", + wantLoc: "https://baz.com/path?a=b", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.ReadRequest(bufio.NewReader(strings.NewReader(tt.req))) + rec := httptest.NewRecorder() + tt.h.ServeHTTP(rec, r) + got := rec.Result() + if got, want := got.StatusCode, 302; got != want { + t.Errorf("got status code %v; want %v", got, want) + } + if got, want := got.Header.Get("Location"), "https://foo.com/"; got != tt.wantLoc { + t.Errorf("Location = %q; want %q", got, want) + } + }) + } +}