diff --git a/scripts/speculator/main.go b/scripts/speculator/main.go index 1c2780e38..2e6b48b3f 100644 --- a/scripts/speculator/main.go +++ b/scripts/speculator/main.go @@ -21,6 +21,7 @@ import ( "os" "os/exec" "path" + "regexp" "strconv" "strings" "sync" @@ -71,6 +72,8 @@ const ( permissionsOwnerFull = 0700 ) +var numericRegex = regexp.MustCompile(`^\d+$`) + func gitClone(url string, shared bool) (string, error) { directory := path.Join("/tmp/matrix-doc", strconv.FormatInt(rand.Int63(), 10)) if err := os.MkdirAll(directory, permissionsOwnerFull); err != nil { @@ -101,17 +104,16 @@ func runGitCommand(path string, args []string) error { return nil } -func lookupPullRequest(url url.URL, pathPrefix string) (*PullRequest, error) { - if !strings.HasPrefix(url.Path, pathPrefix+"/") { - return nil, fmt.Errorf("invalid path passed: %s expect %s/123", url.Path, pathPrefix) - } - prNumber := strings.Split(url.Path[len(pathPrefix)+1:], "/")[0] - +func lookupPullRequest(prNumber string) (*PullRequest, error) { resp, err := http.Get(fmt.Sprintf("%s/%s", pullsPrefix, prNumber)) defer resp.Body.Close() if err != nil { return nil, fmt.Errorf("error getting pulls: %v", err) } + if resp.StatusCode != 200 { + body, _ := ioutil.ReadAll(resp.Body) + return nil, fmt.Errorf("error getting pull request %s: %v", prNumber, string(body)) + } dec := json.NewDecoder(resp.Body) var pr PullRequest if err := dec.Decode(&pr); err != nil { @@ -120,6 +122,26 @@ func lookupPullRequest(url url.URL, pathPrefix string) (*PullRequest, error) { return &pr, nil } +func (s *server) lookupBranch(branch string) (string, error) { + err := s.updateBase() + if err != nil { + log.Printf("Error fetching: %v, will use cached branches") + } + + if strings.ToLower(branch) == "head" { + branch = "master" + } + branch = "origin/" + branch + sha, err := s.getSHAOf(branch) + if err != nil { + return "", fmt.Errorf("error getting branch %s: %v", branch, err) + } + if sha == "" { + return "", fmt.Errorf("Unable to get sha for %s", branch) + } + return sha, nil +} + func generate(dir string) error { cmd := exec.Command("python", "gendoc.py", "--nodelete") cmd.Dir = path.Join(dir, "scripts") @@ -193,6 +215,15 @@ func (s *server) getSHAOf(ref string) (string, error) { return strings.TrimSpace(b.String()), nil } +// extractPRNumber checks that the path begins with the given base, and returns +// the following component. +func extractPRNumber(path, base string) (string, error) { + if !strings.HasPrefix(path, base+"/") { + return "", fmt.Errorf("invalid path passed: %q expect %s/123", path, base) + } + return strings.Split(path[len(base)+1:], "/")[0], nil +} + // extractPath extracts the file path within the gen directory which should be served for the request. // Returns one of (file to serve, path to redirect to). // path is the actual path being requested, e.g. "/spec/head/client_server.html". @@ -230,17 +261,45 @@ func (s *server) serveSpec(w http.ResponseWriter, req *http.Request) { return } - if strings.HasPrefix(strings.ToLower(req.URL.Path), "/spec/head") { - // err may be non-nil here but if headSha is non-empty we will serve a possibly-stale result in favour of erroring. - // This is to deal with cases like where github is down but we still want to serve the spec. - if headSha, err := s.lookupHeadSHA(); headSha == "" { - writeError(w, 500, err) - return + // we use URL.EscapedPath() to get hold of the %-encoded version of the + // path, so that we can handle branch names with slashes in. + urlPath := req.URL.EscapedPath() + + if urlPath == "/spec" { + // special treatment for /spec - redirect to /spec/HEAD/ + s.redirectTo(w, req, "/spec/HEAD/") + return + } + + if !strings.HasPrefix(urlPath, "/spec/") { + writeError(w, 500, fmt.Errorf("invalid path passed: %q expect /spec/...", urlPath)) + } + + splits := strings.SplitN(urlPath[6:], "/", 2) + + if len(splits) == 1 { + // "/spec/foo" - redirect to "/spec/foo/" (so that relative links from the index work) + if splits[0] == "" { + s.redirectTo(w, req, "/spec/HEAD/") } else { - sha = headSha + s.redirectTo(w, req, urlPath+"/") } - } else { - pr, err := lookupPullRequest(*req.URL, "/spec") + return + } + + // now we have: + // splits[0] is a PR#, or a branch name + // splits[1] is the file to serve + + branchName, _ := url.QueryUnescape(splits[0]) + requestedPath, _ := url.QueryUnescape(splits[1]) + if requestedPath == "" { + requestedPath = "index.html" + } + + if numericRegex.MatchString(branchName) { + // PR number + pr, err := lookupPullRequest(branchName) if err != nil { writeError(w, 400, err) return @@ -253,6 +312,20 @@ func (s *server) serveSpec(w http.ResponseWriter, req *http.Request) { return } sha = pr.Head.SHA + log.Printf("Serving pr %s (%s)\n", branchName, sha) + } else if strings.ToLower(branchName) == "head" || + branchName == "master" || + strings.HasPrefix(branchName, "drafts/") { + branchSHA, err := s.lookupBranch(branchName) + if err != nil { + writeError(w, 400, err) + return + } + sha = branchSHA + log.Printf("Serving branch %s (%s)\n", branchName, sha) + } else { + writeError(w, 404, fmt.Errorf("invalid branch name")) + return } var cache = specCache @@ -299,11 +372,6 @@ func (s *server) serveSpec(w http.ResponseWriter, req *http.Request) { cache.Add(sha, pathToContent) } - requestedPath, redirect := extractPath(req.URL.Path, "/spec/") - if redirect != "" { - s.redirectTo(w, req, redirect) - return - } if b, ok := pathToContent[requestedPath]; ok { w.Write(b) return @@ -319,31 +387,11 @@ func (s *server) serveSpec(w http.ResponseWriter, req *http.Request) { w.Write([]byte("Not found")) } -func (s *server) redirectTo(w http.ResponseWriter, req *http.Request, path string) { - req.URL.Path = path - w.Header().Set("Location", req.URL.String()) +func (s *server) redirectTo(w http.ResponseWriter, _ *http.Request, path string) { + w.Header().Set("Location", path) w.WriteHeader(302) } -// lookupHeadSHA looks up what origin/master's HEAD SHA is. -// It attempts to `git fetch` before doing so. -// If this fails, it may still return a stale sha, but will also return an error. -func (s *server) lookupHeadSHA() (sha string, retErr error) { - retErr = s.updateBase() - if retErr != nil { - log.Printf("Error fetching: %v, attempting to fall back to current known value", retErr) - } - originHead, err := s.getSHAOf("origin/master") - if err != nil { - retErr = err - } - sha = originHead - if retErr != nil && originHead != "" { - log.Printf("Successfully fell back to possibly stale sha: %s", sha) - } - return -} - func checkAuth(pr *PullRequest) error { if !pr.User.IsTrusted() { return fmt.Errorf("%q is not a trusted pull requester", pr.User.Login) @@ -352,7 +400,12 @@ func checkAuth(pr *PullRequest) error { } func (s *server) serveRSTDiff(w http.ResponseWriter, req *http.Request) { - pr, err := lookupPullRequest(*req.URL, "/diff/rst") + prNumber, err := extractPRNumber(req.URL.Path, "/diff/rst") + if err != nil { + writeError(w, 400, err) + return + } + pr, err := lookupPullRequest(prNumber) if err != nil { writeError(w, 400, err) return @@ -390,7 +443,12 @@ func (s *server) serveRSTDiff(w http.ResponseWriter, req *http.Request) { } func (s *server) serveHTMLDiff(w http.ResponseWriter, req *http.Request) { - pr, err := lookupPullRequest(*req.URL, "/diff/html") + prNumber, err := extractPRNumber(req.URL.Path, "/diff/html") + if err != nil { + writeError(w, 400, err) + return + } + pr, err := lookupPullRequest(prNumber) if err != nil { writeError(w, 400, err) return @@ -450,21 +508,54 @@ func findHTMLDiffer() (string, error) { return "", fmt.Errorf("unable to find htmldiff.pl") } -func listPulls(w http.ResponseWriter, req *http.Request) { +func getPulls() ([]PullRequest, error) { resp, err := http.Get(pullsPrefix) if err != nil { - writeError(w, 500, err) - return + return nil, err } defer resp.Body.Close() + if resp.StatusCode != 200 { + body, _ := ioutil.ReadAll(resp.Body) + return nil, fmt.Errorf("error getting pull requests: %v", string(body)) + } dec := json.NewDecoder(resp.Body) var pulls []PullRequest - if err := dec.Decode(&pulls); err != nil { - writeError(w, 500, err) - return + err = dec.Decode(&pulls) + return pulls, err +} + +// getBranches returns a list of the upstream branch names. +// It attempts to `git fetch` before doing so. +func (s *server) getBranches() ([]string, error) { + err := s.updateBase() + if err != nil { + log.Printf("Error fetching: %v, will use cached branches") } - if len(pulls) == 0 { - io.WriteString(w, "No pull requests found") + + cmd := exec.Command("git", "branch", "-r") + cmd.Dir = path.Join(s.matrixDocCloneURL) + var b bytes.Buffer + cmd.Stdout = &b + s.mu.Lock() + err = cmd.Run() + s.mu.Unlock() + if err != nil { + return nil, fmt.Errorf("Error reading branch names: %v. Output from git:\n%v", err, b.String()) + } + branches := []string{} + for _, b := range strings.Split(b.String(), "\n") { + b = strings.TrimSpace(b) + if strings.HasPrefix(b, "origin/") { + branches = append(branches, b[7:]) + } + } + return branches, nil +} + +func (srv *server) makeIndex(w http.ResponseWriter, req *http.Request) { + pulls, err := getPulls() + if err != nil { + writeError(w, 500, err) return } s := "
` - if *includesDir != "" { - s += `