diff --git a/drive/driveimpl/compositedav/compositedav.go b/drive/driveimpl/compositedav/compositedav.go
index 9e5a293ac..8b41871ad 100644
--- a/drive/driveimpl/compositedav/compositedav.go
+++ b/drive/driveimpl/compositedav/compositedav.go
@@ -81,6 +81,16 @@ type Handler struct {
staticRoot string
}
+var cacheInvalidatingMethods = map[string]bool{
+ "PUT": true,
+ "POST": true,
+ "COPY": true,
+ "MKCOL": true,
+ "MOVE": true,
+ "PROPPATCH": true,
+ "DELETE": true,
+}
+
// ServeHTTP implements http.Handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "PROPFIND" {
@@ -88,11 +98,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
- if r.Method != "GET" {
- // If the user is performing a modification (e.g. PUT, MKDIR, etc),
+ _, shouldInvalidate := cacheInvalidatingMethods[r.Method]
+ if shouldInvalidate {
+ // If the user is performing a modification (e.g. PUT, MKDIR, etc.),
// we need to invalidate the StatCache to make sure we're not knowingly
// showing stale stats.
- // TODO(oxtoacart): maybe be more selective about invalidating cache
+ // TODO(oxtoacart): maybe only invalidate specific paths
h.StatCache.invalidate()
}
diff --git a/drive/driveimpl/compositedav/stat_cache.go b/drive/driveimpl/compositedav/stat_cache.go
index 4d5971bd0..fc57ff064 100644
--- a/drive/driveimpl/compositedav/stat_cache.go
+++ b/drive/driveimpl/compositedav/stat_cache.go
@@ -4,11 +4,19 @@
package compositedav
import (
+ "bytes"
+ "encoding/xml"
+ "log"
"net/http"
"sync"
"time"
"github.com/jellydator/ttlcache/v3"
+ "tailscale.com/drive/driveimpl/shared"
+)
+
+var (
+ notFound = newCacheEntry(http.StatusNotFound, nil)
)
// StatCache provides a cache for directory listings and file metadata.
@@ -18,12 +26,38 @@ import (
// This is similar to the DirectoryCacheLifetime setting of Windows' built-in
// SMB client, see
// https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-7/ff686200(v=ws.10)
+//
+// StatCache is built specifically to cache the results of PROPFIND requests,
+// which come back as MultiStatus XML responses. Typical clients will issue two
+// kinds of PROPFIND:
+//
+// The first kind of PROPFIND is a directory listing performed to depth 1. At
+// this depth, the resulting XML will contain stats for the requested folder as
+// well as for all children of that folder.
+//
+// The second kind of PROPFIND is a file listing performed to depth 0. At this
+// depth, the resulting XML will contain stats only for the requested file.
+//
+// In order to avoid round-trips, when a PROPFIND at depth 0 is attempted, and
+// the requested file is not in the cache, StatCache will check to see if the
+// parent folder of that file is cached. If so, StatCache infers the correct
+// MultiStatus for the file according to the following logic:
+//
+// 1. If the parent folder is NotFound (404), treat the file itself as NotFound
+// 2. If the parent folder's XML doesn't contain the file, treat it as
+// NotFound.
+// 3. If the parent folder's XML contains the file, build a MultiStatus for the
+// file based on the parent's XML.
+//
+// To avoid inconsistencies from the perspective of the client, any operations
+// that modify the filesystem (e.g. PUT, MKDIR, etc.) should call invalidate()
+// to invalidate the cache.
type StatCache struct {
TTL time.Duration
// mu guards the below values.
mu sync.Mutex
- cachesByDepthAndPath map[int]*ttlcache.Cache[string, []byte]
+ cachesByDepthAndPath map[int]*ttlcache.Cache[string, *cacheEntry]
}
// getOr checks the cache for the named value at the given depth. If a cached
@@ -32,25 +66,57 @@ type StatCache struct {
// status and value. If the function returned http.StatusMultiStatus, getOr
// caches the resulting value at the given name and depth before returning.
func (c *StatCache) getOr(name string, depth int, or func() (int, []byte)) (int, []byte) {
- cached := c.get(name, depth)
- if cached != nil {
- return http.StatusMultiStatus, cached
- }
- status, next := or()
- if c != nil && status == http.StatusMultiStatus && next != nil {
- c.set(name, depth, next)
+ ce := c.get(name, depth)
+ if ce == nil {
+ // Not cached, fetch value.
+ status, raw := or()
+ ce = newCacheEntry(status, raw)
+ if status == http.StatusMultiStatus || status == http.StatusNotFound {
+ // Got a legit status, cache value
+ c.set(name, depth, ce)
+ }
}
- return status, next
+ return ce.Status, ce.Raw
}
-func (c *StatCache) get(name string, depth int) []byte {
+// get retrieves the entry for the named file at the given depth. If no entry
+// is found, and depth == 0, get will check to see if the parent path of name
+// is present in the cache at depth 1. If so, it will infer that the child does
+// not exist and return notFound (404).
+func (c *StatCache) get(name string, depth int) *cacheEntry {
if c == nil {
return nil
}
+ name = shared.Normalize(name)
+
c.mu.Lock()
defer c.mu.Unlock()
+ ce := c.tryGetLocked(name, depth)
+ if ce != nil {
+ // Cache hit.
+ return ce
+ }
+
+ if depth > 0 {
+ // Cache miss.
+ return nil
+ }
+
+ // At depth 0, if child's parent is in the cache, and the child isn't
+ // cached, we can infer that the child is notFound.
+ p := c.tryGetLocked(shared.Parent(name), 1)
+ if p != nil {
+ return notFound
+ }
+
+ // No parent in cache, cache miss.
+ return nil
+}
+
+// tryGetLocked requires that c.mu be held.
+func (c *StatCache) tryGetLocked(name string, depth int) *cacheEntry {
if c.cachesByDepthAndPath == nil {
return nil
}
@@ -65,28 +131,80 @@ func (c *StatCache) get(name string, depth int) []byte {
return item.Value()
}
-func (c *StatCache) set(name string, depth int, value []byte) {
+// set stores the given cacheEntry in the cache at the given name and depth. If
+// the depth is 1, set also populates depth 0 entries in the cache for the bare
+// name. If status is StatusMultiStatus, set will parse the PROPFIND result and
+// store depth 0 entries for all children. If parsing the result fails, nothing
+// is cached.
+func (c *StatCache) set(name string, depth int, ce *cacheEntry) {
if c == nil {
return
}
+ name = shared.Normalize(name)
+
+ var self *cacheEntry
+ var children map[string]*cacheEntry
+ if depth == 1 {
+ switch ce.Status {
+ case http.StatusNotFound:
+ // Record notFound as the self entry.
+ self = ce
+ case http.StatusMultiStatus:
+ // Parse the raw MultiStatus and extract specific responses
+ // corresponding to the self entry (e.g. the directory, but at depth 0)
+ // and children (e.g. files within the directory) so that subsequent
+ // requests for these can be satisfied from the cache.
+ var ms multiStatus
+ err := xml.Unmarshal(ce.Raw, &ms)
+ if err != nil {
+ // unparseable MultiStatus response, don't cache
+ log.Printf("statcache.set error: %s", err)
+ return
+ }
+ children = make(map[string]*cacheEntry, len(ms.Responses)-1)
+ for i := 0; i < len(ms.Responses); i++ {
+ response := ms.Responses[i]
+ name := shared.Normalize(response.Href)
+ raw := marshalMultiStatus(response)
+ entry := newCacheEntry(ce.Status, raw)
+ if i == 0 {
+ self = entry
+ } else {
+ children[name] = entry
+ }
+ }
+ }
+ }
+
c.mu.Lock()
defer c.mu.Unlock()
+ c.setLocked(name, depth, ce)
+ if self != nil {
+ c.setLocked(name, 0, self)
+ }
+ for childName, child := range children {
+ c.setLocked(childName, 0, child)
+ }
+}
+// setLocked requires that c.mu be held.
+func (c *StatCache) setLocked(name string, depth int, ce *cacheEntry) {
if c.cachesByDepthAndPath == nil {
- c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, []byte])
+ c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, *cacheEntry])
}
cache := c.cachesByDepthAndPath[depth]
if cache == nil {
cache = ttlcache.New(
- ttlcache.WithTTL[string, []byte](c.TTL),
+ ttlcache.WithTTL[string, *cacheEntry](c.TTL),
)
go cache.Start()
c.cachesByDepthAndPath[depth] = cache
}
- cache.Set(name, value, ttlcache.DefaultTTL)
+ cache.Set(name, ce, ttlcache.DefaultTTL)
}
+// invalidate invalidates the entire cache.
func (c *StatCache) invalidate() {
if c == nil {
return
@@ -108,3 +226,54 @@ func (c *StatCache) stop() {
cache.Stop()
}
}
+
+type cacheEntry struct {
+ Status int
+ Raw []byte
+}
+
+func newCacheEntry(status int, raw []byte) *cacheEntry {
+ return &cacheEntry{Status: status, Raw: raw}
+}
+
+type propStat struct {
+ InnerXML []byte `xml:",innerxml"`
+}
+
+type response struct {
+ XMLName xml.Name `xml:"response"`
+ Href string `xml:"href"`
+ PropStats []*propStat `xml:"propstat"`
+}
+
+type multiStatus struct {
+ XMLName xml.Name `xml:"multistatus"`
+ Responses []*response `xml:"response"`
+}
+
+// marshalMultiStatus performs custom marshalling of a MultiStatus to preserve
+// the original formatting, namespacing, etc. Doing this with Go's XML encoder
+// is somewhere between difficult and impossible, which is why we use this more
+// manual approach.
+func marshalMultiStatus(response *response) []byte {
+ // TODO(percy): maybe pool these buffers
+ var buf bytes.Buffer
+ buf.WriteString(multistatusTemplateStart)
+ buf.WriteString(response.Href)
+ buf.WriteString(hrefEnd)
+ for _, propStat := range response.PropStats {
+ buf.WriteString(propstatStart)
+ buf.Write(propStat.InnerXML)
+ buf.WriteString(propstatEnd)
+ }
+ buf.WriteString(multistatusTemplateEnd)
+ return buf.Bytes()
+}
+
+const (
+ multistatusTemplateStart = ``
+ hrefEnd = ``
+ propstatStart = ``
+ propstatEnd = ``
+ multistatusTemplateEnd = ``
+)
diff --git a/drive/driveimpl/compositedav/stat_cache_test.go b/drive/driveimpl/compositedav/stat_cache_test.go
index c69832f26..fa63457a2 100644
--- a/drive/driveimpl/compositedav/stat_cache_test.go
+++ b/drive/driveimpl/compositedav/stat_cache_test.go
@@ -4,17 +4,65 @@
package compositedav
import (
- "bytes"
+ "fmt"
+ "log"
+ "net/http"
+ "path"
+ "strings"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"tailscale.com/tstest"
)
-var (
- val = []byte("1")
- file = "file"
-)
+var parentPath = "/parent"
+
+var childPath = "/parent/child.txt"
+
+var parentResponse = `
+/parent/
+
+
+Mon, 29 Apr 2024 19:52:23 GMT
+Fri, 19 Apr 2024 04:13:34 GMT
+
+
+
+
+HTTP/1.1 200 OK
+
+`
+
+var childResponse = `
+
+/parent/child.txt
+
+
+Mon, 29 Apr 2024 19:52:23 GMT
+Fri, 19 Apr 2024 04:13:34 GMT
+
+
+
+
+HTTP/1.1 200 OK
+
+`
+
+var fullParent = []byte(
+ strings.ReplaceAll(
+ fmt.Sprintf(`%s%s`, parentResponse, childResponse),
+ "\n", ""))
+
+var partialParent = []byte(
+ strings.ReplaceAll(
+ fmt.Sprintf(`%s`, parentResponse),
+ "\n", ""))
+
+var fullChild = []byte(
+ strings.ReplaceAll(
+ fmt.Sprintf(`%s`, childResponse),
+ "\n", ""))
func TestStatCacheNoTimeout(t *testing.T) {
// Make sure we don't leak goroutines
@@ -24,22 +72,23 @@ func TestStatCacheNoTimeout(t *testing.T) {
defer c.stop()
// check get before set
- fetched := c.get(file, 1)
+ fetched := c.get(childPath, 0)
if fetched != nil {
- t.Errorf("got %q, want nil", fetched)
+ t.Errorf("got %v, want nil", fetched)
}
// set new stat
- c.set(file, 1, val)
- fetched = c.get(file, 1)
- if !bytes.Equal(fetched, val) {
- t.Errorf("got %q, want %q", fetched, val)
+ ce := newCacheEntry(http.StatusMultiStatus, fullChild)
+ c.set(childPath, 0, ce)
+ fetched = c.get(childPath, 0)
+ if diff := cmp.Diff(fetched, ce); diff != "" {
+ t.Errorf("should have gotten cached value; (-got+want):%v", diff)
}
// fetch stat again, should still be cached
- fetched = c.get(file, 1)
- if !bytes.Equal(fetched, val) {
- t.Errorf("got %q, want %q", fetched, val)
+ fetched = c.get(childPath, 0)
+ if diff := cmp.Diff(fetched, ce); diff != "" {
+ t.Errorf("should still have gotten cached value; (-got+want):%v", diff)
}
}
@@ -51,25 +100,114 @@ func TestStatCacheTimeout(t *testing.T) {
defer c.stop()
// set new stat
- c.set(file, 1, val)
- fetched := c.get(file, 1)
- if !bytes.Equal(fetched, val) {
- t.Errorf("got %q, want %q", fetched, val)
+ ce := newCacheEntry(http.StatusMultiStatus, fullChild)
+ c.set(childPath, 0, ce)
+ fetched := c.get(childPath, 0)
+ if diff := cmp.Diff(fetched, ce); diff != "" {
+ t.Errorf("should have gotten cached value; (-got+want):%v", diff)
}
// wait for cache to expire and refetch stat, should be empty now
time.Sleep(c.TTL * 2)
- fetched = c.get(file, 1)
+ fetched = c.get(childPath, 0)
if fetched != nil {
- t.Errorf("invalidate should have cleared cached value")
+ t.Errorf("cached value should have expired")
}
- c.set(file, 1, val)
+ c.set(childPath, 0, ce)
// invalidate the cache and make sure nothing is returned
c.invalidate()
- fetched = c.get(file, 1)
+ fetched = c.get(childPath, 0)
if fetched != nil {
t.Errorf("invalidate should have cleared cached value")
}
}
+
+func TestParentChildRelationship(t *testing.T) {
+ // Make sure we don't leak goroutines
+ tstest.ResourceCheck(t)
+
+ c := &StatCache{TTL: 24 * time.Hour} // don't expire
+ defer c.stop()
+
+ missingParentPath := "/missingparent"
+ unparseableParentPath := "/unparseable"
+
+ c.set(parentPath, 1, newCacheEntry(http.StatusMultiStatus, fullParent))
+ c.set(missingParentPath, 1, newCacheEntry(http.StatusNotFound, nil))
+ c.set(unparseableParentPath, 1, newCacheEntry(http.StatusMultiStatus, []byte("