diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 34addb451..10db37e58 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -3552,7 +3552,7 @@ func (b *LocalBackend) initPeerAPIListener() { b: b, taildrop: &taildrop.Manager{ Logf: b.logf, - Clock: b.clock, + Clock: tstime.DefaultClock{b.clock}, Dir: fileRoot, DirectFileMode: b.directFileRoot != "", AvoidFinalRename: !b.directFileDoFinalRename, diff --git a/ipn/ipnlocal/peerapi_test.go b/ipn/ipnlocal/peerapi_test.go index 48a6a1d3b..7a76c8713 100644 --- a/ipn/ipnlocal/peerapi_test.go +++ b/ipn/ipnlocal/peerapi_test.go @@ -541,8 +541,7 @@ func TestHandlePeerAPI(t *testing.T) { rootDir = t.TempDir() if e.ph.ps.taildrop == nil { e.ph.ps.taildrop = &taildrop.Manager{ - Logf: e.logBuf.Logf, - Clock: &tstest.Clock{}, + Logf: e.logBuf.Logf, } } e.ph.ps.taildrop.Dir = rootDir @@ -585,9 +584,8 @@ func TestFileDeleteRace(t *testing.T) { clock: &tstest.Clock{}, }, taildrop: &taildrop.Manager{ - Logf: t.Logf, - Clock: &tstest.Clock{}, - Dir: dir, + Logf: t.Logf, + Dir: dir, }, } ph := &peerAPIHandler{ diff --git a/taildrop/resume.go b/taildrop/resume.go new file mode 100644 index 000000000..ed797abd1 --- /dev/null +++ b/taildrop/resume.go @@ -0,0 +1,211 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" + "slices" + "strings" +) + +var ( + blockSize = int64(64 << 10) + hashAlgorithm = "sha256" +) + +// FileChecksums represents checksums into partially received file. +type FileChecksums struct { + // Offset is the offset into the file. + Offset int64 `json:"offset"` + // Length is the length of content being hashed in the file. + Length int64 `json:"length"` + // Checksums is a list of checksums of BlockSize-sized blocks + // starting from Offset. The number of checksums is the Length + // divided by BlockSize rounded up to the nearest integer. + // All blocks except for the last one are guaranteed to be checksums + // over BlockSize-sized blocks. + Checksums []Checksum `json:"checksums"` + // Algorithm is the hashing algorithm used to compute checksums. + Algorithm string `json:"algorithm"` // always "sha256" for now + // BlockSize is the size of each block. + // The last block may be smaller than this, but never zero. + BlockSize int64 `json:"blockSize"` // always (64<<10) for now +} + +// Checksum is an opaque checksum that is comparable. +type Checksum struct{ cs [sha256.Size]byte } + +func hash(b []byte) Checksum { + return Checksum{sha256.Sum256(b)} +} +func (cs Checksum) String() string { + return hex.EncodeToString(cs.cs[:]) +} +func (cs Checksum) AppendText(b []byte) ([]byte, error) { + return hexAppendEncode(b, cs.cs[:]), nil +} +func (cs Checksum) MarshalText() ([]byte, error) { + return hexAppendEncode(nil, cs.cs[:]), nil +} +func (cs *Checksum) UnmarshalText(b []byte) error { + if len(b) != 2*len(cs.cs) { + return fmt.Errorf("invalid hex length: %d", len(b)) + } + _, err := hex.Decode(cs.cs[:], b) + return err +} + +// TODO(https://go.dev/issue/53693): Use hex.AppendEncode instead. +func hexAppendEncode(dst, src []byte) []byte { + n := hex.EncodedLen(len(src)) + dst = slices.Grow(dst, n) + hex.Encode(dst[len(dst):][:n], src) + return dst[:len(dst)+n] +} + +// PartialFiles returns a list of partial files in [Handler.Dir] +// that were sent (or is actively being sent) by the provided id. +func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) { + if m.Dir == "" { + return ret, ErrNoTaildrop + } + if m.DirectFileMode && m.AvoidFinalRename { + return nil, nil // resuming is not supported for users that peek at our file structure + } + + f, err := os.Open(m.Dir) + if err != nil { + return ret, err + } + defer f.Close() + + suffix := id.partialSuffix() + for { + des, err := f.ReadDir(10) + if err != nil { + return ret, err + } + for _, de := range des { + if name := de.Name(); strings.HasSuffix(name, suffix) { + ret = append(ret, name) + } + } + if err == io.EOF { + return ret, nil + } + } +} + +// HashPartialFile hashes the contents of a partial file sent by id, +// starting at the specified offset and for the specified length. +// If length is negative, it hashes the entire file. +// If the length exceeds the remaining file length, then it hashes until EOF. +// If [FileHashes.Length] is less than length and no error occurred, +// then it implies that all remaining content in the file has been hashed. +func (m *Manager) HashPartialFile(id ClientID, baseName string, offset, length int64) (FileChecksums, error) { + if m.Dir == "" { + return FileChecksums{}, ErrNoTaildrop + } + if m.DirectFileMode && m.AvoidFinalRename { + return FileChecksums{}, nil // resuming is not supported for users that peek at our file structure + } + + dstFile, err := m.joinDir(baseName) + if err != nil { + return FileChecksums{}, err + } + f, err := os.Open(dstFile + id.partialSuffix()) + if err != nil { + if os.IsNotExist(err) { + return FileChecksums{}, nil + } + return FileChecksums{}, err + } + defer f.Close() + + if _, err := f.Seek(offset, io.SeekStart); err != nil { + return FileChecksums{}, err + } + checksums := FileChecksums{ + Offset: offset, + Algorithm: hashAlgorithm, + BlockSize: blockSize, + } + b := make([]byte, blockSize) // TODO: Pool this? + r := io.LimitReader(f, length) + for { + switch n, err := io.ReadFull(r, b); { + case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF: + return checksums, err + case n == 0: + return checksums, nil + default: + checksums.Checksums = append(checksums.Checksums, hash(b[:n])) + checksums.Length += int64(n) + } + } +} + +// ResumeReader reads and discards the leading content of r +// that matches the content based on the checksums that exist. +// It returns the number of bytes consumed, +// and returns an [io.Reader] representing the remaining content. +func ResumeReader(r io.Reader, hashFile func(offset, length int64) (FileChecksums, error)) (int64, io.Reader, error) { + if hashFile == nil { + return 0, r, nil + } + + // Ask for checksums of a particular content length, + // where the amount of memory needed to represent the checksums themselves + // is exactly equal to the blockSize. + numBlocks := blockSize / sha256.Size + hashLength := blockSize * numBlocks + + var offset int64 + b := make([]byte, 0, blockSize) + for { + // Request a list of checksums for the partial file starting at offset. + checksums, err := hashFile(offset, hashLength) + if len(checksums.Checksums) == 0 || err != nil { + return offset, io.MultiReader(bytes.NewReader(b), r), err + } else if checksums.BlockSize != blockSize || checksums.Algorithm != hashAlgorithm { + return offset, io.MultiReader(bytes.NewReader(b), r), fmt.Errorf("invalid block size or hashing algorithm") + } + + // Read from r, comparing each block with the provided checksums. + for _, want := range checksums.Checksums { + // Read a block from r. + n, err := io.ReadFull(r, b[:blockSize]) + b = b[:n] + if err == io.EOF || err == io.ErrUnexpectedEOF { + err = nil + } + if len(b) == 0 || err != nil { + // This should not occur in practice. + // It implies that an error occurred reading r, + // or that the partial file on the remote side is fully complete. + return offset, io.MultiReader(bytes.NewReader(b), r), err + } + + // Compare the local and remote block checksums. + // If it mismatches, then resume from this point. + got := hash(b) + if got != want { + return offset, io.MultiReader(bytes.NewReader(b), r), nil + } + offset += int64(len(b)) + b = b[:0] + } + + // We hashed the remainder of the partial file, so stop. + if checksums.Length < hashLength { + return offset, io.MultiReader(bytes.NewReader(b), r), nil + } + } +} diff --git a/taildrop/resume_test.go b/taildrop/resume_test.go new file mode 100644 index 000000000..d79fb80dd --- /dev/null +++ b/taildrop/resume_test.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "io" + "math/rand" + "os" + "testing" + "testing/iotest" + + "tailscale.com/util/must" +) + +func TestResume(t *testing.T) { + oldBlockSize := blockSize + defer func() { blockSize = oldBlockSize }() + blockSize = 256 + + m := Manager{Logf: t.Logf, Dir: t.TempDir()} + + rn := rand.New(rand.NewSource(0)) + want := make([]byte, 12345) + must.Get(io.ReadFull(rn, want)) + + t.Run("resume-noop", func(t *testing.T) { + r := io.Reader(bytes.NewReader(want)) + offset, r, err := ResumeReader(r, func(offset, length int64) (FileChecksums, error) { + return m.HashPartialFile("", "foo", offset, length) + }) + must.Do(err) + must.Get(m.PutFile("", "foo", r, offset, -1)) + got := must.Get(os.ReadFile(must.Get(m.joinDir("foo")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) + + t.Run("resume-retry", func(t *testing.T) { + rn := rand.New(rand.NewSource(0)) + for { + r := io.Reader(bytes.NewReader(want)) + offset, r, err := ResumeReader(r, func(offset, length int64) (FileChecksums, error) { + return m.HashPartialFile("", "foo", offset, length) + }) + must.Do(err) + numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1) + if offset < int64(len(want)) { + r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe)) + } + if _, err := m.PutFile("", "foo", r, offset, -1); err == nil { + break + } + } + got := must.Get(os.ReadFile(must.Get(m.joinDir("foo")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) + +} diff --git a/taildrop/retrieve.go b/taildrop/retrieve.go index 01ab59468..76dd918ed 100644 --- a/taildrop/retrieve.go +++ b/taildrop/retrieve.go @@ -167,9 +167,9 @@ func (m *Manager) DeleteFile(baseName string) error { if m.DirectFileMode { return errors.New("deletes not allowed in direct mode") } - path, ok := m.joinDir(baseName) - if !ok { - return errors.New("bad filename") + path, err := m.joinDir(baseName) + if err != nil { + return err } var bo *backoff.Backoff logf := m.Logf @@ -224,9 +224,9 @@ func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err e if m.DirectFileMode { return nil, 0, errors.New("opens not allowed in direct mode") } - path, ok := m.joinDir(baseName) - if !ok { - return nil, 0, errors.New("bad filename") + path, err := m.joinDir(baseName) + if err != nil { + return nil, 0, err } if fi, err := os.Stat(path + deletedSuffix); err == nil && fi.Mode().IsRegular() { tryDeleteAgain(path) diff --git a/taildrop/send.go b/taildrop/send.go index 8bb2e8715..f97e2bfbe 100644 --- a/taildrop/send.go +++ b/taildrop/send.go @@ -22,7 +22,7 @@ type incomingFileKey struct { } type incomingFile struct { - clock tstime.Clock + clock tstime.DefaultClock started time.Time size int64 // or -1 if unknown; never 0 @@ -62,6 +62,7 @@ func (f *incomingFile) Write(p []byte) (n int, err error) { // The baseName must be a base filename without any slashes. // The length is the expected length of content to read from r, // it may be negative to indicate that it is unknown. +// It returns the length of the entire file. // // If there is a failure reading from r, then the partial file is not deleted // for some period of time. The [Manager.PartialFiles] and [Manager.HashPartialFile] @@ -78,9 +79,9 @@ func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, len case distro.Get() == distro.Unraid && !m.DirectFileMode: return 0, ErrNotAccessible } - dstPath, ok := m.joinDir(baseName) - if !ok { - return 0, ErrInvalidFileName + dstPath, err := m.joinDir(baseName) + if err != nil { + return 0, err } redactAndLogError := func(action string, err error) error { diff --git a/taildrop/taildrop.go b/taildrop/taildrop.go index bc2b3f6ff..5b0bda6c2 100644 --- a/taildrop/taildrop.go +++ b/taildrop/taildrop.go @@ -45,7 +45,7 @@ func (id ClientID) partialSuffix() string { // Manager manages the state for receiving and managing taildropped files. type Manager struct { Logf logger.Logf - Clock tstime.Clock + Clock tstime.DefaultClock // Dir is the directory to store received files. // This main either be the final location for the files @@ -131,15 +131,15 @@ func validFilenameRune(r rune) bool { return unicode.IsPrint(r) } -func (m *Manager) joinDir(baseName string) (fullPath string, ok bool) { +func (m *Manager) joinDir(baseName string) (fullPath string, err error) { if !utf8.ValidString(baseName) { - return "", false + return "", ErrInvalidFileName } if strings.TrimSpace(baseName) != baseName { - return "", false + return "", ErrInvalidFileName } if len(baseName) > 255 { - return "", false + return "", ErrInvalidFileName } // TODO: validate unicode normalization form too? Varies by platform. clean := path.Clean(baseName) @@ -147,17 +147,17 @@ func (m *Manager) joinDir(baseName string) (fullPath string, ok bool) { clean == "." || clean == ".." || strings.HasSuffix(clean, deletedSuffix) || strings.HasSuffix(clean, partialSuffix) { - return "", false + return "", ErrInvalidFileName } for _, r := range baseName { if !validFilenameRune(r) { - return "", false + return "", ErrInvalidFileName } } if !filepath.IsLocal(baseName) { - return "", false + return "", ErrInvalidFileName } - return filepath.Join(m.Dir, baseName), true + return filepath.Join(m.Dir, baseName), nil } // IncomingFiles returns a list of active incoming files.