// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package taildrop import ( "bytes" "crypto/sha256" "encoding/hex" "fmt" "io" "io/fs" "os" "strings" ) var ( blockSize = int64(64 << 10) hashAlgorithm = "sha256" ) // BlockChecksum represents the checksum for a single block. type BlockChecksum struct { Checksum Checksum `json:"checksum"` Algorithm string `json:"algo"` // always "sha256" for now Size int64 `json:"size"` // always (64<<10) for now } // Checksum is an opaque checksum that is comparable. type Checksum struct{ cs [sha256.Size]byte } func hash(b []byte) Checksum { return Checksum{sha256.Sum256(b)} } func (cs Checksum) String() string { return hex.EncodeToString(cs.cs[:]) } func (cs Checksum) AppendText(b []byte) ([]byte, error) { return hex.AppendEncode(b, cs.cs[:]), nil } func (cs Checksum) MarshalText() ([]byte, error) { return hex.AppendEncode(nil, cs.cs[:]), nil } func (cs *Checksum) UnmarshalText(b []byte) error { if len(b) != 2*len(cs.cs) { return fmt.Errorf("invalid hex length: %d", len(b)) } _, err := hex.Decode(cs.cs[:], b) return err } // PartialFiles returns a list of partial files in [Handler.Dir] // that were sent (or is actively being sent) by the provided id. func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) { if m == nil || m.opts.Dir == "" { return nil, ErrNoTaildrop } suffix := id.partialSuffix() if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { if name := de.Name(); strings.HasSuffix(name, suffix) { ret = append(ret, name) } return true }); err != nil { return ret, redactError(err) } return ret, nil } // HashPartialFile returns a function that hashes the next block in the file, // starting from the beginning of the file. // It returns (BlockChecksum{}, io.EOF) when the stream is complete. // It is the caller's responsibility to call close. func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (BlockChecksum, error), close func() error, err error) { if m == nil || m.opts.Dir == "" { return nil, nil, ErrNoTaildrop } noopNext := func() (BlockChecksum, error) { return BlockChecksum{}, io.EOF } noopClose := func() error { return nil } dstFile, err := joinDir(m.opts.Dir, baseName) if err != nil { return nil, nil, err } f, err := os.Open(dstFile + id.partialSuffix()) if err != nil { if os.IsNotExist(err) { return noopNext, noopClose, nil } return nil, nil, redactError(err) } b := make([]byte, blockSize) // TODO: Pool this? next = func() (BlockChecksum, error) { switch n, err := io.ReadFull(f, b); { case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF: return BlockChecksum{}, redactError(err) case n == 0: return BlockChecksum{}, io.EOF default: return BlockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil } } close = f.Close return next, close, nil } // ResumeReader reads and discards the leading content of r // that matches the content based on the checksums that exist. // It returns the number of bytes consumed, // and returns an [io.Reader] representing the remaining content. func ResumeReader(r io.Reader, hashNext func() (BlockChecksum, error)) (int64, io.Reader, error) { if hashNext == nil { return 0, r, nil } var offset int64 b := make([]byte, 0, blockSize) for { // Obtain the next block checksum from the remote peer. cs, err := hashNext() switch { case err == io.EOF: return offset, io.MultiReader(bytes.NewReader(b), r), nil case err != nil: return offset, io.MultiReader(bytes.NewReader(b), r), err case cs.Algorithm != hashAlgorithm || cs.Size < 0 || cs.Size > blockSize: return offset, io.MultiReader(bytes.NewReader(b), r), fmt.Errorf("invalid block size or hashing algorithm") } // Read the contents of the next block. n, err := io.ReadFull(r, b[:cs.Size]) b = b[:n] if err == io.EOF || err == io.ErrUnexpectedEOF { err = nil } if len(b) == 0 || err != nil { // This should not occur in practice. // It implies that an error occurred reading r, // or that the partial file on the remote side is fully complete. return offset, io.MultiReader(bytes.NewReader(b), r), err } // Compare the local and remote block checksums. // If it mismatches, then resume from this point. if cs.Checksum != hash(b) { return offset, io.MultiReader(bytes.NewReader(b), r), nil } offset += int64(len(b)) b = b[:0] } }