diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go
index d637ba1c4..6ae5ed530 100644
--- a/ipn/ipnlocal/peerapi.go
+++ b/ipn/ipnlocal/peerapi.go
@@ -649,34 +649,46 @@ func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, taildrop.ErrInvalidFileName.Error(), http.StatusBadRequest)
 		return
 	}
+	enc := json.NewEncoder(w)
 	switch r.Method {
 	case "GET":
-		var resp any
-		var err error
 		id := taildrop.ClientID(h.peerNode.StableID())
-
 		if prefix == "" {
-			resp, err = h.ps.taildrop.PartialFiles(id)
+			// List all the partial files.
+			files, err := h.ps.taildrop.PartialFiles(id)
+			if err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				return
+			}
+			if err := enc.Encode(files); err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				h.logf("json.Encoder.Encode error: %v", err)
+				return
+			}
 		} else {
-			ranges, ok := httphdr.ParseRange(r.Header.Get("Range"))
-			if !ok || len(ranges) != 1 || ranges[0].Length < 0 {
-				http.Error(w, "invalid Range header", http.StatusBadRequest)
+			// Stream all the block hashes for the specified file.
+			next, close, err := h.ps.taildrop.HashPartialFile(id, baseName)
+			if err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
 				return
 			}
-			offset := ranges[0].Start
-			length := ranges[0].Length
-			if length == 0 {
-				length = -1 // httphdr.Range.Length == 0 implies reading the rest of file
+			defer close()
+			for {
+				switch cs, err := next(); {
+				case err == io.EOF:
+					return
+				case err != nil:
+					http.Error(w, err.Error(), http.StatusInternalServerError)
+					h.logf("HashPartialFile.next error: %v", err)
+					return
+				default:
+					if err := enc.Encode(cs); err != nil {
+						http.Error(w, err.Error(), http.StatusInternalServerError)
+						h.logf("json.Encoder.Encode error: %v", err)
+						return
+					}
+				}
 			}
-			resp, err = h.ps.taildrop.HashPartialFile(id, baseName, offset, length)
-		}
-		if err != nil {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
-		}
-		if err := json.NewEncoder(w).Encode(resp); err != nil {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
 		}
 	case "PUT":
 		t0 := h.ps.b.clock.Now()
diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go
index 3e621862b..f9b040e8f 100644
--- a/ipn/localapi/localapi.go
+++ b/ipn/localapi/localapi.go
@@ -1320,38 +1320,39 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 	// Before we PUT a file we check to see if there are any existing partial file and if so,
 	// we resume the upload from where we left off by sending the remaining file instead of
 	// the full file.
-	offset, remainingBody, err := taildrop.ResumeReader(r.Body, func(offset, length int64) (taildrop.FileChecksums, error) {
-		client := &http.Client{
-			Transport: h.b.Dialer().PeerAPITransport(),
-			Timeout:   10 * time.Second,
-		}
-		req, err := http.NewRequestWithContext(r.Context(), "GET", dstURL.String()+"/v0/put/"+filenameEscaped, nil)
+	var offset int64
+	var resumeDuration time.Duration
+	remainingBody := io.Reader(r.Body)
+	client := &http.Client{
+		Transport: h.b.Dialer().PeerAPITransport(),
+		Timeout:   10 * time.Second,
+	}
+	req, err := http.NewRequestWithContext(r.Context(), "GET", dstURL.String()+"/v0/put/"+filenameEscaped, nil)
+	if err != nil {
+		http.Error(w, "bogus peer URL", http.StatusInternalServerError)
+		return
+	}
+	switch resp, err := client.Do(req); {
+	case err != nil:
+		h.logf("could not fetch remote hashes: %v", err)
+	case resp.StatusCode == http.StatusMethodNotAllowed || resp.StatusCode == http.StatusNotFound:
+		// implies older peerapi without resume support; nothing to resume
+		resp.Body.Close()
+	case resp.StatusCode != http.StatusOK:
+		resp.Body.Close()
+		h.logf("fetch remote hashes status code: %d", resp.StatusCode)
+	default:
+		resumeStart := time.Now()
+		dec := json.NewDecoder(resp.Body)
+		offset, remainingBody, err = taildrop.ResumeReader(r.Body, func() (out taildrop.BlockChecksum, err error) {
+			err = dec.Decode(&out)
+			return out, err
+		})
+		resp.Body.Close() // remainingBody only wraps r.Body, so the hash stream is no longer needed
 		if err != nil {
-			return taildrop.FileChecksums{}, err
+			h.logf("reader could not be fully resumed: %v", err)
 		}
-
-		rangeHdr, ok := httphdr.FormatRange([]httphdr.Range{{Start: offset, Length: length}})
-		if !ok {
-			return taildrop.FileChecksums{}, fmt.Errorf("invalid offset and length")
-		}
-		req.Header.Set("Range", rangeHdr)
-		switch resp, err := client.Do(req); {
-		case err != nil:
-			return taildrop.FileChecksums{}, err
-		case resp.StatusCode == http.StatusMethodNotAllowed || resp.StatusCode == http.StatusNotFound:
-			return taildrop.FileChecksums{}, nil // implies remote peer on older version
-		case resp.StatusCode != http.StatusOK:
-			return taildrop.FileChecksums{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
-		default:
-			var checksums taildrop.FileChecksums
-			err = json.NewDecoder(resp.Body).Decode(&checksums)
-			return checksums, err
-		}
-	})
-	if err != nil {
-		// ResumeReader ensures that the returned offset and reader are consistent
-		// even if an error is encountered. Thus, we can still proceed.
-		h.logf("reader could not be fully resumed: %v", err)
+		resumeDuration = time.Since(resumeStart).Round(time.Millisecond)
 	}
 
 	outReq, err := http.NewRequestWithContext(r.Context(), "PUT", "http://peer/v0/put/"+filenameEscaped, remainingBody)
@@ -1361,6 +1362,7 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 	}
 	outReq.ContentLength = r.ContentLength
 	if offset > 0 {
+		h.logf("resuming put at offset %d after %v", offset, resumeDuration)
 		rangeHdr, _ := httphdr.FormatRange([]httphdr.Range{{offset, 0}})
 		outReq.Header.Set("Range", rangeHdr)
 		if outReq.ContentLength >= 0 {
diff --git a/taildrop/resume.go b/taildrop/resume.go
index 1388ac793..5a6ba4e93 100644
--- a/taildrop/resume.go
+++ b/taildrop/resume.go
@@ -20,23 +20,11 @@ var (
 	hashAlgorithm = "sha256"
 )
 
-// FileChecksums represents checksums into partially received file.
-type FileChecksums struct {
-	// Offset is the offset into the file.
-	Offset int64 `json:"offset"`
-	// Length is the length of content being hashed in the file.
-	Length int64 `json:"length"`
-	// Checksums is a list of checksums of BlockSize-sized blocks
-	// starting from Offset. The number of checksums is the Length
-	// divided by BlockSize rounded up to the nearest integer.
-	// All blocks except for the last one are guaranteed to be checksums
-	// over BlockSize-sized blocks.
-	Checksums []Checksum `json:"checksums"`
-	// Algorithm is the hashing algorithm used to compute checksums.
-	Algorithm string `json:"algorithm"` // always "sha256" for now
-	// BlockSize is the size of each block.
-	// The last block may be smaller than this, but never zero.
-	BlockSize int64 `json:"blockSize"` // always (64<<10) for now
+// BlockChecksum represents the checksum for a single block.
+type BlockChecksum struct {
+	Checksum  Checksum `json:"checksum"`
+	Algorithm string   `json:"algo"` // always "sha256" for now
+	Size      int64    `json:"size"` // at most blockSize (64<<10 for now); the final block may be smaller
 }
 
 // Checksum is an opaque checksum that is comparable.
@@ -92,113 +80,89 @@ func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) {
 	return ret, nil
 }
 
-// HashPartialFile hashes the contents of a partial file sent by id,
-// starting at the specified offset and for the specified length.
-// If length is negative, it hashes the entire file.
-// If the length exceeds the remaining file length, then it hashes until EOF.
-// If [FileHashes.Length] is less than length and no error occurred,
-// then it implies that all remaining content in the file has been hashed.
-func (m *Manager) HashPartialFile(id ClientID, baseName string, offset, length int64) (FileChecksums, error) {
+// HashPartialFile returns a function that hashes the next block in the file,
+// starting from the beginning of the file.
+// It returns (BlockChecksum{}, io.EOF) when the stream is complete.
+// It is the caller's responsibility to call close.
+func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (BlockChecksum, error), close func() error, err error) {
 	if m == nil || m.opts.Dir == "" {
-		return FileChecksums{}, ErrNoTaildrop
+		return nil, nil, ErrNoTaildrop
 	}
+	noopNext := func() (BlockChecksum, error) { return BlockChecksum{}, io.EOF }
+	noopClose := func() error { return nil }
 	if m.opts.DirectFileMode && m.opts.AvoidFinalRename {
-		return FileChecksums{}, nil // resuming is not supported for users that peek at our file structure
+		return noopNext, noopClose, nil // resuming is not supported for users that peek at our file structure
 	}
 
 	dstFile, err := joinDir(m.opts.Dir, baseName)
 	if err != nil {
-		return FileChecksums{}, err
+		return nil, nil, err
 	}
 	f, err := os.Open(dstFile + id.partialSuffix())
 	if err != nil {
 		if os.IsNotExist(err) {
-			return FileChecksums{}, nil
+			return noopNext, noopClose, nil
 		}
-		return FileChecksums{}, redactError(err)
+		return nil, nil, redactError(err)
 	}
-	defer f.Close()
 
-	if _, err := f.Seek(offset, io.SeekStart); err != nil {
-		return FileChecksums{}, redactError(err)
-	}
-	checksums := FileChecksums{
-		Offset:    offset,
-		Algorithm: hashAlgorithm,
-		BlockSize: blockSize,
-	}
 	b := make([]byte, blockSize) // TODO: Pool this?
-	r := io.Reader(f)
-	if length >= 0 {
-		r = io.LimitReader(f, length)
-	}
-	for {
-		switch n, err := io.ReadFull(r, b); {
+	next = func() (BlockChecksum, error) {
+		switch n, err := io.ReadFull(f, b); {
 		case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF:
-			return checksums, redactError(err)
+			return BlockChecksum{}, redactError(err)
 		case n == 0:
-			return checksums, nil
+			return BlockChecksum{}, io.EOF
 		default:
-			checksums.Checksums = append(checksums.Checksums, hash(b[:n]))
-			checksums.Length += int64(n)
+			return BlockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil
 		}
 	}
+	close = f.Close
+	return next, close, nil
 }
 
 // ResumeReader reads and discards the leading content of r
 // that matches the content based on the checksums that exist.
 // It returns the number of bytes consumed,
 // and returns an [io.Reader] representing the remaining content.
-func ResumeReader(r io.Reader, hashFile func(offset, length int64) (FileChecksums, error)) (int64, io.Reader, error) {
-	if hashFile == nil {
+func ResumeReader(r io.Reader, hashNext func() (BlockChecksum, error)) (int64, io.Reader, error) {
+	if hashNext == nil {
 		return 0, r, nil
 	}
 
-	// Ask for checksums of a particular content length,
-	// where the amount of memory needed to represent the checksums themselves
-	// is exactly equal to the blockSize.
-	numBlocks := blockSize / sha256.Size
-	hashLength := blockSize * numBlocks
-
 	var offset int64
 	b := make([]byte, 0, blockSize)
 	for {
-		// Request a list of checksums for the partial file starting at offset.
-		checksums, err := hashFile(offset, hashLength)
-		if len(checksums.Checksums) == 0 || err != nil {
+		// Obtain the next block checksum from the remote peer.
+		cs, err := hashNext()
+		switch {
+		case err == io.EOF:
+			return offset, io.MultiReader(bytes.NewReader(b), r), nil
+		case err != nil:
 			return offset, io.MultiReader(bytes.NewReader(b), r), err
-		} else if checksums.BlockSize != blockSize || checksums.Algorithm != hashAlgorithm {
+		case cs.Algorithm != hashAlgorithm || cs.Size < 0 || cs.Size > blockSize:
 			return offset, io.MultiReader(bytes.NewReader(b), r), fmt.Errorf("invalid block size or hashing algorithm")
 		}
 
-		// Read from r, comparing each block with the provided checksums.
-		for _, want := range checksums.Checksums {
-			// Read a block from r.
-			n, err := io.ReadFull(r, b[:blockSize])
-			b = b[:n]
-			if err == io.EOF || err == io.ErrUnexpectedEOF {
-				err = nil
-			}
-			if len(b) == 0 || err != nil {
-				// This should not occur in practice.
-				// It implies that an error occurred reading r,
-				// or that the partial file on the remote side is fully complete.
-				return offset, io.MultiReader(bytes.NewReader(b), r), err
-			}
-
-			// Compare the local and remote block checksums.
-			// If it mismatches, then resume from this point.
-			got := hash(b)
-			if got != want {
-				return offset, io.MultiReader(bytes.NewReader(b), r), nil
-			}
-			offset += int64(len(b))
-			b = b[:0]
+		// Read exactly cs.Size bytes (validated above) so a short trailing block can still match.
+		n, err := io.ReadFull(r, b[:cs.Size])
+		b = b[:n]
+		if err == io.EOF || err == io.ErrUnexpectedEOF {
+			err = nil
+		}
+		if len(b) == 0 || err != nil {
+			// This should not occur in practice.
+			// It implies that an error occurred reading r,
+			// or that the partial file on the remote side is fully complete.
+			return offset, io.MultiReader(bytes.NewReader(b), r), err
 		}
 
-		// We hashed the remainder of the partial file, so stop.
-		if checksums.Length < hashLength {
+		// Compare the local and remote block checksums.
+		// If it mismatches, then resume from this point.
+		if cs.Checksum != hash(b) {
 			return offset, io.MultiReader(bytes.NewReader(b), r), nil
 		}
+		offset += int64(len(b))
+		b = b[:0]
 	}
 }
diff --git a/taildrop/resume_test.go b/taildrop/resume_test.go
index 0deaf6869..55502289f 100644
--- a/taildrop/resume_test.go
+++ b/taildrop/resume_test.go
@@ -26,11 +26,12 @@ func TestResume(t *testing.T) {
 	want := make([]byte, 12345)
 	must.Get(io.ReadFull(rn, want))
 
-	t.Run("resume-noop", func(t *testing.T) {
+	t.Run("resume-noexist", func(t *testing.T) {
 		r := io.Reader(bytes.NewReader(want))
-		offset, r, err := ResumeReader(r, func(offset, length int64) (FileChecksums, error) {
-			return m.HashPartialFile("", "foo", offset, length)
-		})
+		next, close, err := m.HashPartialFile("", "foo")
+		must.Do(err)
+		defer close()
+		offset, r, err := ResumeReader(r, next)
 		must.Do(err)
 		must.Get(m.PutFile("", "foo", r, offset, -1))
 		got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo"))))
@@ -43,9 +44,10 @@ func TestResume(t *testing.T) {
 		rn := rand.New(rand.NewSource(0))
 		for {
 			r := io.Reader(bytes.NewReader(want))
-			offset, r, err := ResumeReader(r, func(offset, length int64) (FileChecksums, error) {
-				return m.HashPartialFile("", "foo", offset, length)
-			})
+			next, close, err := m.HashPartialFile("", "foo")
+			must.Do(err)
+			defer close()
+			offset, r, err := ResumeReader(r, next)
 			must.Do(err)
 			numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1)
 			if offset < int64(len(want)) {