// cacheInvalidatingMethods is the set of HTTP/WebDAV request methods that
// modify the remote filesystem. ServeHTTP looks the request method up in this
// set to decide whether the StatCache has to be invalidated before the request
// is handled.
var cacheInvalidatingMethods = map[string]bool{
	"COPY":      true,
	"DELETE":    true,
	"MKCOL":     true,
	"MOVE":      true,
	"POST":      true,
	"PROPPATCH": true,
	"PUT":       true,
}
// StatCache provides a cache for directory listings and file metadata.
//
// This is similar to the DirectoryCacheLifetime setting of Windows' built-in
// SMB client, see
// https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-7/ff686200(v=ws.10)
//
// StatCache is built specifically to cache the results of PROPFIND requests,
// which come back as MultiStatus XML responses. Typical clients will issue two
// kinds of PROPFIND:
//
// The first kind of PROPFIND is a directory listing performed to depth 1. At
// this depth, the resulting XML will contain stats for the requested folder as
// well as for all children of that folder.
//
// The second kind of PROPFIND is a file listing performed to depth 0. At this
// depth, the resulting XML will contain stats only for the requested file.
//
// In order to avoid round-trips, when a PROPFIND at depth 0 is attempted, and
// the requested file is not in the cache, StatCache will check to see if the
// parent folder of that file is cached. If so, StatCache infers the correct
// MultiStatus for the file according to the following logic:
//
//  1. If the parent folder is NotFound (404), treat the file itself as
//     NotFound.
//  2. If the parent folder's XML doesn't contain the file, treat it as
//     NotFound.
//  3. If the parent folder's XML contains the file, build a MultiStatus for
//     the file based on the parent's XML.
//
// To avoid inconsistencies from the perspective of the client, any operations
// that modify the filesystem (e.g. PUT, MKCOL, etc.) should call invalidate()
// to invalidate the cache.
type StatCache struct {
	// TTL is the time-to-live configured on each per-depth cache when that
	// cache is first created.
	TTL time.Duration

	// mu guards the below values.
	mu sync.Mutex
	// cachesByDepthAndPath holds one TTL cache per PROPFIND depth; within each
	// cache, entries are keyed by path.
	cachesByDepthAndPath map[int]*ttlcache.Cache[string, *cacheEntry]
}
If the function returned http.StatusMultiStatus, getOr // caches the resulting value at the given name and depth before returning. func (c *StatCache) getOr(name string, depth int, or func() (int, []byte)) (int, []byte) { - cached := c.get(name, depth) - if cached != nil { - return http.StatusMultiStatus, cached - } - status, next := or() - if c != nil && status == http.StatusMultiStatus && next != nil { - c.set(name, depth, next) + ce := c.get(name, depth) + if ce == nil { + // Not cached, fetch value. + status, raw := or() + ce = newCacheEntry(status, raw) + if status == http.StatusMultiStatus || status == http.StatusNotFound { + // Got a legit status, cache value + c.set(name, depth, ce) + } } - return status, next + return ce.Status, ce.Raw } -func (c *StatCache) get(name string, depth int) []byte { +// get retrieves the entry for the named file at the given depth. If no entry +// is found, and depth == 0, get will check to see if the parent path of name +// is present in the cache at depth 1. If so, it will infer that the child does +// not exist and return notFound (404). +func (c *StatCache) get(name string, depth int) *cacheEntry { if c == nil { return nil } + name = shared.Normalize(name) + c.mu.Lock() defer c.mu.Unlock() + ce := c.tryGetLocked(name, depth) + if ce != nil { + // Cache hit. + return ce + } + + if depth > 0 { + // Cache miss. + return nil + } + + // At depth 0, if child's parent is in the cache, and the child isn't + // cached, we can infer that the child is notFound. + p := c.tryGetLocked(shared.Parent(name), 1) + if p != nil { + return notFound + } + + // No parent in cache, cache miss. + return nil +} + +// tryGetLocked requires that c.mu be held. 
+func (c *StatCache) tryGetLocked(name string, depth int) *cacheEntry { if c.cachesByDepthAndPath == nil { return nil } @@ -65,28 +131,80 @@ func (c *StatCache) get(name string, depth int) []byte { return item.Value() } -func (c *StatCache) set(name string, depth int, value []byte) { +// set stores the given cacheEntry in the cache at the given name and depth. If +// the depth is 1, set also populates depth 0 entries in the cache for the bare +// name. If status is StatusMultiStatus, set will parse the PROPFIND result and +// store depth 0 entries for all children. If parsing the result fails, nothing +// is cached. +func (c *StatCache) set(name string, depth int, ce *cacheEntry) { if c == nil { return } + name = shared.Normalize(name) + + var self *cacheEntry + var children map[string]*cacheEntry + if depth == 1 { + switch ce.Status { + case http.StatusNotFound: + // Record notFound as the self entry. + self = ce + case http.StatusMultiStatus: + // Parse the raw MultiStatus and extract specific responses + // corresponding to the self entry (e.g. the directory, but at depth 0) + // and children (e.g. files within the directory) so that subsequent + // requests for these can be satisfied from the cache. + var ms multiStatus + err := xml.Unmarshal(ce.Raw, &ms) + if err != nil { + // unparseable MultiStatus response, don't cache + log.Printf("statcache.set error: %s", err) + return + } + children = make(map[string]*cacheEntry, len(ms.Responses)-1) + for i := 0; i < len(ms.Responses); i++ { + response := ms.Responses[i] + name := shared.Normalize(response.Href) + raw := marshalMultiStatus(response) + entry := newCacheEntry(ce.Status, raw) + if i == 0 { + self = entry + } else { + children[name] = entry + } + } + } + } + c.mu.Lock() defer c.mu.Unlock() + c.setLocked(name, depth, ce) + if self != nil { + c.setLocked(name, 0, self) + } + for childName, child := range children { + c.setLocked(childName, 0, child) + } +} +// setLocked requires that c.mu be held. 
+func (c *StatCache) setLocked(name string, depth int, ce *cacheEntry) { if c.cachesByDepthAndPath == nil { - c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, []byte]) + c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, *cacheEntry]) } cache := c.cachesByDepthAndPath[depth] if cache == nil { cache = ttlcache.New( - ttlcache.WithTTL[string, []byte](c.TTL), + ttlcache.WithTTL[string, *cacheEntry](c.TTL), ) go cache.Start() c.cachesByDepthAndPath[depth] = cache } - cache.Set(name, value, ttlcache.DefaultTTL) + cache.Set(name, ce, ttlcache.DefaultTTL) } +// invalidate invalidates the entire cache. func (c *StatCache) invalidate() { if c == nil { return @@ -108,3 +226,54 @@ func (c *StatCache) stop() { cache.Stop() } } + +type cacheEntry struct { + Status int + Raw []byte +} + +func newCacheEntry(status int, raw []byte) *cacheEntry { + return &cacheEntry{Status: status, Raw: raw} +} + +type propStat struct { + InnerXML []byte `xml:",innerxml"` +} + +type response struct { + XMLName xml.Name `xml:"response"` + Href string `xml:"href"` + PropStats []*propStat `xml:"propstat"` +} + +type multiStatus struct { + XMLName xml.Name `xml:"multistatus"` + Responses []*response `xml:"response"` +} + +// marshalMultiStatus performs custom marshalling of a MultiStatus to preserve +// the original formatting, namespacing, etc. Doing this with Go's XML encoder +// is somewhere between difficult and impossible, which is why we use this more +// manual approach. 
+func marshalMultiStatus(response *response) []byte { + // TODO(percy): maybe pool these buffers + var buf bytes.Buffer + buf.WriteString(multistatusTemplateStart) + buf.WriteString(response.Href) + buf.WriteString(hrefEnd) + for _, propStat := range response.PropStats { + buf.WriteString(propstatStart) + buf.Write(propStat.InnerXML) + buf.WriteString(propstatEnd) + } + buf.WriteString(multistatusTemplateEnd) + return buf.Bytes() +} + +const ( + multistatusTemplateStart = `` + hrefEnd = `` + propstatStart = `` + propstatEnd = `` + multistatusTemplateEnd = `` +) diff --git a/drive/driveimpl/compositedav/stat_cache_test.go b/drive/driveimpl/compositedav/stat_cache_test.go index c69832f26..fa63457a2 100644 --- a/drive/driveimpl/compositedav/stat_cache_test.go +++ b/drive/driveimpl/compositedav/stat_cache_test.go @@ -4,17 +4,65 @@ package compositedav import ( - "bytes" + "fmt" + "log" + "net/http" + "path" + "strings" "testing" "time" + "github.com/google/go-cmp/cmp" "tailscale.com/tstest" ) -var ( - val = []byte("1") - file = "file" -) +var parentPath = "/parent" + +var childPath = "/parent/child.txt" + +var parentResponse = ` +/parent/ + + +Mon, 29 Apr 2024 19:52:23 GMT +Fri, 19 Apr 2024 04:13:34 GMT + + + + +HTTP/1.1 200 OK + +` + +var childResponse = ` + +/parent/child.txt + + +Mon, 29 Apr 2024 19:52:23 GMT +Fri, 19 Apr 2024 04:13:34 GMT + + + + +HTTP/1.1 200 OK + +` + +var fullParent = []byte( + strings.ReplaceAll( + fmt.Sprintf(`%s%s`, parentResponse, childResponse), + "\n", "")) + +var partialParent = []byte( + strings.ReplaceAll( + fmt.Sprintf(`%s`, parentResponse), + "\n", "")) + +var fullChild = []byte( + strings.ReplaceAll( + fmt.Sprintf(`%s`, childResponse), + "\n", "")) func TestStatCacheNoTimeout(t *testing.T) { // Make sure we don't leak goroutines @@ -24,22 +72,23 @@ func TestStatCacheNoTimeout(t *testing.T) { defer c.stop() // check get before set - fetched := c.get(file, 1) + fetched := c.get(childPath, 0) if fetched != nil { - t.Errorf("got %q, want 
nil", fetched) + t.Errorf("got %v, want nil", fetched) } // set new stat - c.set(file, 1, val) - fetched = c.get(file, 1) - if !bytes.Equal(fetched, val) { - t.Errorf("got %q, want %q", fetched, val) + ce := newCacheEntry(http.StatusMultiStatus, fullChild) + c.set(childPath, 0, ce) + fetched = c.get(childPath, 0) + if diff := cmp.Diff(fetched, ce); diff != "" { + t.Errorf("should have gotten cached value; (-got+want):%v", diff) } // fetch stat again, should still be cached - fetched = c.get(file, 1) - if !bytes.Equal(fetched, val) { - t.Errorf("got %q, want %q", fetched, val) + fetched = c.get(childPath, 0) + if diff := cmp.Diff(fetched, ce); diff != "" { + t.Errorf("should still have gotten cached value; (-got+want):%v", diff) } } @@ -51,25 +100,114 @@ func TestStatCacheTimeout(t *testing.T) { defer c.stop() // set new stat - c.set(file, 1, val) - fetched := c.get(file, 1) - if !bytes.Equal(fetched, val) { - t.Errorf("got %q, want %q", fetched, val) + ce := newCacheEntry(http.StatusMultiStatus, fullChild) + c.set(childPath, 0, ce) + fetched := c.get(childPath, 0) + if diff := cmp.Diff(fetched, ce); diff != "" { + t.Errorf("should have gotten cached value; (-got+want):%v", diff) } // wait for cache to expire and refetch stat, should be empty now time.Sleep(c.TTL * 2) - fetched = c.get(file, 1) + fetched = c.get(childPath, 0) if fetched != nil { - t.Errorf("invalidate should have cleared cached value") + t.Errorf("cached value should have expired") } - c.set(file, 1, val) + c.set(childPath, 0, ce) // invalidate the cache and make sure nothing is returned c.invalidate() - fetched = c.get(file, 1) + fetched = c.get(childPath, 0) if fetched != nil { t.Errorf("invalidate should have cleared cached value") } } + +func TestParentChildRelationship(t *testing.T) { + // Make sure we don't leak goroutines + tstest.ResourceCheck(t) + + c := &StatCache{TTL: 24 * time.Hour} // don't expire + defer c.stop() + + missingParentPath := "/missingparent" + unparseableParentPath := 
"/unparseable" + + c.set(parentPath, 1, newCacheEntry(http.StatusMultiStatus, fullParent)) + c.set(missingParentPath, 1, newCacheEntry(http.StatusNotFound, nil)) + c.set(unparseableParentPath, 1, newCacheEntry(http.StatusMultiStatus, []byte("