From 7209c4f91e97eda632f86c9f0def600cf0a03949 Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Fri, 3 May 2024 17:20:13 -0500 Subject: [PATCH] drive: parse depth 1 PROPFIND results to include children in cache Clients often perform a PROPFIND for the parent directory before performing PROPFIND for specific children within that directory. The PROPFIND for the parent directory is usually done at depth 1, meaning that we already have information for all of the children. By immediately adding that to the cache, we save a roundtrip to the remote peer on the PROPFIND for the specific child. Updates tailscale/corp#19779 Signed-off-by: Percy Wegmann --- drive/driveimpl/compositedav/compositedav.go | 17 +- drive/driveimpl/compositedav/stat_cache.go | 197 ++++++++++++++++-- .../driveimpl/compositedav/stat_cache_test.go | 182 ++++++++++++++-- drive/driveimpl/shared/pathutil.go | 11 + 4 files changed, 368 insertions(+), 39 deletions(-) diff --git a/drive/driveimpl/compositedav/compositedav.go b/drive/driveimpl/compositedav/compositedav.go index 9e5a293ac..8b41871ad 100644 --- a/drive/driveimpl/compositedav/compositedav.go +++ b/drive/driveimpl/compositedav/compositedav.go @@ -81,6 +81,16 @@ type Handler struct { staticRoot string } +var cacheInvalidatingMethods = map[string]bool{ + "PUT": true, + "POST": true, + "COPY": true, + "MKCOL": true, + "MOVE": true, + "PROPPATCH": true, + "DELETE": true, +} + // ServeHTTP implements http.Handler. func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == "PROPFIND" { @@ -88,11 +98,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if r.Method != "GET" { - // If the user is performing a modification (e.g. PUT, MKDIR, etc), + _, shouldInvalidate := cacheInvalidatingMethods[r.Method] + if shouldInvalidate { + // If the user is performing a modification (e.g. PUT, MKDIR, etc.), // we need to invalidate the StatCache to make sure we're not knowingly // showing stale stats. - // TODO(oxtoacart): maybe be more selective about invalidating cache + // TODO(oxtoacart): maybe only invalidate specific paths h.StatCache.invalidate() } diff --git a/drive/driveimpl/compositedav/stat_cache.go b/drive/driveimpl/compositedav/stat_cache.go index 4d5971bd0..fc57ff064 100644 --- a/drive/driveimpl/compositedav/stat_cache.go +++ b/drive/driveimpl/compositedav/stat_cache.go @@ -4,11 +4,19 @@ package compositedav import ( + "bytes" + "encoding/xml" + "log" "net/http" "sync" "time" "github.com/jellydator/ttlcache/v3" + "tailscale.com/drive/driveimpl/shared" +) + +var ( + notFound = newCacheEntry(http.StatusNotFound, nil) ) // StatCache provides a cache for directory listings and file metadata. @@ -18,12 +26,38 @@ import ( // This is similar to the DirectoryCacheLifetime setting of Windows' built-in // SMB client, see // https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-7/ff686200(v=ws.10) +// +// StatCache is built specifically to cache the results of PROPFIND requests, +// which come back as MultiStatus XML responses. Typical clients will issue two +// kinds of PROPFIND: +// +// The first kind of PROPFIND is a directory listing performed to depth 1. At +// this depth, the resulting XML will contain stats for the requested folder as +// well as for all children of that folder. +// +// The second kind of PROPFIND is a file listing performed to depth 0. At this +// depth, the resulting XML will contain stats only for the requested file. +// +// In order to avoid round-trips, when a PROPFIND at depth 0 is attempted, and +// the requested file is not in the cache, StatCache will check to see if the +// parent folder of that file is cached. If so, StatCache infers the correct +// MultiStatus for the file according to the following logic: +// +// 1. If the parent folder is NotFound (404), treat the file itself as NotFound +// 2. If the parent folder's XML doesn't contain the file, treat it as +// NotFound. +// 3. If the parent folder's XML contains the file, build a MultiStatus for the +// file based on the parent's XML. +// +// To avoid inconsistencies from the perspective of the client, any operations +// that modify the filesystem (e.g. PUT, MKDIR, etc.) should call invalidate() +// to invalidate the cache. type StatCache struct { TTL time.Duration // mu guards the below values. mu sync.Mutex - cachesByDepthAndPath map[int]*ttlcache.Cache[string, []byte] + cachesByDepthAndPath map[int]*ttlcache.Cache[string, *cacheEntry] } // getOr checks the cache for the named value at the given depth. If a cached @@ -32,25 +66,57 @@ type StatCache struct { // status and value. If the function returned http.StatusMultiStatus, getOr // caches the resulting value at the given name and depth before returning. func (c *StatCache) getOr(name string, depth int, or func() (int, []byte)) (int, []byte) { - cached := c.get(name, depth) - if cached != nil { - return http.StatusMultiStatus, cached - } - status, next := or() - if c != nil && status == http.StatusMultiStatus && next != nil { - c.set(name, depth, next) + ce := c.get(name, depth) + if ce == nil { + // Not cached, fetch value. + status, raw := or() + ce = newCacheEntry(status, raw) + if status == http.StatusMultiStatus || status == http.StatusNotFound { + // Got a legit status, cache value + c.set(name, depth, ce) + } } - return status, next + return ce.Status, ce.Raw } -func (c *StatCache) get(name string, depth int) []byte { +// get retrieves the entry for the named file at the given depth. If no entry +// is found, and depth == 0, get will check to see if the parent path of name +// is present in the cache at depth 1. If so, it will infer that the child does +// not exist and return notFound (404). +func (c *StatCache) get(name string, depth int) *cacheEntry { if c == nil { return nil } + name = shared.Normalize(name) + c.mu.Lock() defer c.mu.Unlock() + ce := c.tryGetLocked(name, depth) + if ce != nil { + // Cache hit. + return ce + } + + if depth > 0 { + // Cache miss. + return nil + } + + // At depth 0, if child's parent is in the cache, and the child isn't + // cached, we can infer that the child is notFound. + p := c.tryGetLocked(shared.Parent(name), 1) + if p != nil { + return notFound + } + + // No parent in cache, cache miss. + return nil +} + +// tryGetLocked requires that c.mu be held. +func (c *StatCache) tryGetLocked(name string, depth int) *cacheEntry { if c.cachesByDepthAndPath == nil { return nil } @@ -65,28 +131,80 @@ func (c *StatCache) get(name string, depth int) []byte { return item.Value() } -func (c *StatCache) set(name string, depth int, value []byte) { +// set stores the given cacheEntry in the cache at the given name and depth. If +// the depth is 1, set also populates depth 0 entries in the cache for the bare +// name. If status is StatusMultiStatus, set will parse the PROPFIND result and +// store depth 0 entries for all children. If parsing the result fails, nothing +// is cached. +func (c *StatCache) set(name string, depth int, ce *cacheEntry) { if c == nil { return } + name = shared.Normalize(name) + + var self *cacheEntry + var children map[string]*cacheEntry + if depth == 1 { + switch ce.Status { + case http.StatusNotFound: + // Record notFound as the self entry. + self = ce + case http.StatusMultiStatus: + // Parse the raw MultiStatus and extract specific responses + // corresponding to the self entry (e.g. the directory, but at depth 0) + // and children (e.g. files within the directory) so that subsequent + // requests for these can be satisfied from the cache. + var ms multiStatus + err := xml.Unmarshal(ce.Raw, &ms) + if err != nil { + // unparseable MultiStatus response, don't cache + log.Printf("statcache.set error: %s", err) + return + } + children = make(map[string]*cacheEntry, len(ms.Responses)-1) + for i := 0; i < len(ms.Responses); i++ { + response := ms.Responses[i] + name := shared.Normalize(response.Href) + raw := marshalMultiStatus(response) + entry := newCacheEntry(ce.Status, raw) + if i == 0 { + self = entry + } else { + children[name] = entry + } + } + } + } + c.mu.Lock() defer c.mu.Unlock() + c.setLocked(name, depth, ce) + if self != nil { + c.setLocked(name, 0, self) + } + for childName, child := range children { + c.setLocked(childName, 0, child) + } +} +// setLocked requires that c.mu be held. +func (c *StatCache) setLocked(name string, depth int, ce *cacheEntry) { if c.cachesByDepthAndPath == nil { - c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, []byte]) + c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, *cacheEntry]) } cache := c.cachesByDepthAndPath[depth] if cache == nil { cache = ttlcache.New( - ttlcache.WithTTL[string, []byte](c.TTL), + ttlcache.WithTTL[string, *cacheEntry](c.TTL), ) go cache.Start() c.cachesByDepthAndPath[depth] = cache } - cache.Set(name, value, ttlcache.DefaultTTL) + cache.Set(name, ce, ttlcache.DefaultTTL) } +// invalidate invalidates the entire cache. func (c *StatCache) invalidate() { if c == nil { return @@ -108,3 +226,54 @@ func (c *StatCache) stop() { cache.Stop() } } + +type cacheEntry struct { + Status int + Raw []byte +} + +func newCacheEntry(status int, raw []byte) *cacheEntry { + return &cacheEntry{Status: status, Raw: raw} +} + +type propStat struct { + InnerXML []byte `xml:",innerxml"` +} + +type response struct { + XMLName xml.Name `xml:"response"` + Href string `xml:"href"` + PropStats []*propStat `xml:"propstat"` +} + +type multiStatus struct { + XMLName xml.Name `xml:"multistatus"` + Responses []*response `xml:"response"` +} + +// marshalMultiStatus performs custom marshalling of a MultiStatus to preserve +// the original formatting, namespacing, etc. Doing this with Go's XML encoder +// is somewhere between difficult and impossible, which is why we use this more +// manual approach. +func marshalMultiStatus(response *response) []byte { + // TODO(percy): maybe pool these buffers + var buf bytes.Buffer + buf.WriteString(multistatusTemplateStart) + buf.WriteString(response.Href) + buf.WriteString(hrefEnd) + for _, propStat := range response.PropStats { + buf.WriteString(propstatStart) + buf.Write(propStat.InnerXML) + buf.WriteString(propstatEnd) + } + buf.WriteString(multistatusTemplateEnd) + return buf.Bytes() +} + +const ( + multistatusTemplateStart = `` + hrefEnd = `` + propstatStart = `` + propstatEnd = `` + multistatusTemplateEnd = `` +) diff --git a/drive/driveimpl/compositedav/stat_cache_test.go b/drive/driveimpl/compositedav/stat_cache_test.go index c69832f26..fa63457a2 100644 --- a/drive/driveimpl/compositedav/stat_cache_test.go +++ b/drive/driveimpl/compositedav/stat_cache_test.go @@ -4,17 +4,65 @@ package compositedav import ( - "bytes" + "fmt" + "log" + "net/http" + "path" + "strings" "testing" "time" + "github.com/google/go-cmp/cmp" "tailscale.com/tstest" ) -var ( - val = []byte("1") - file = "file" -) +var parentPath = "/parent" + +var childPath = "/parent/child.txt" + +var parentResponse = ` +/parent/ + + +Mon, 29 Apr 2024 19:52:23 GMT +Fri, 19 Apr 2024 04:13:34 GMT + + + + +HTTP/1.1 200 OK + +` + +var childResponse = ` + +/parent/child.txt + + +Mon, 29 Apr 2024 19:52:23 GMT +Fri, 19 Apr 2024 04:13:34 GMT + + + + +HTTP/1.1 200 OK + +` + +var fullParent = []byte( + strings.ReplaceAll( + fmt.Sprintf(`%s%s`, parentResponse, childResponse), + "\n", "")) + +var partialParent = []byte( + strings.ReplaceAll( + fmt.Sprintf(`%s`, parentResponse), + "\n", "")) + +var fullChild = []byte( + strings.ReplaceAll( + fmt.Sprintf(`%s`, childResponse), + "\n", "")) func TestStatCacheNoTimeout(t *testing.T) { // Make sure we don't leak goroutines @@ -24,22 +72,23 @@ func TestStatCacheNoTimeout(t *testing.T) { defer c.stop() // check get before set - fetched := c.get(file, 1) + fetched := c.get(childPath, 0) if fetched != nil { - t.Errorf("got %q, want nil", fetched) + t.Errorf("got %v, want nil", fetched) } // set new stat - c.set(file, 1, val) - fetched = c.get(file, 1) - if !bytes.Equal(fetched, val) { - t.Errorf("got %q, want %q", fetched, val) + ce := newCacheEntry(http.StatusMultiStatus, fullChild) + c.set(childPath, 0, ce) + fetched = c.get(childPath, 0) + if diff := cmp.Diff(fetched, ce); diff != "" { + t.Errorf("should have gotten cached value; (-got+want):%v", diff) } // fetch stat again, should still be cached - fetched = c.get(file, 1) - if !bytes.Equal(fetched, val) { - t.Errorf("got %q, want %q", fetched, val) + fetched = c.get(childPath, 0) + if diff := cmp.Diff(fetched, ce); diff != "" { + t.Errorf("should still have gotten cached value; (-got+want):%v", diff) } } @@ -51,25 +100,114 @@ func TestStatCacheTimeout(t *testing.T) { defer c.stop() // set new stat - c.set(file, 1, val) - fetched := c.get(file, 1) - if !bytes.Equal(fetched, val) { - t.Errorf("got %q, want %q", fetched, val) + ce := newCacheEntry(http.StatusMultiStatus, fullChild) + c.set(childPath, 0, ce) + fetched := c.get(childPath, 0) + if diff := cmp.Diff(fetched, ce); diff != "" { + t.Errorf("should have gotten cached value; (-got+want):%v", diff) } // wait for cache to expire and refetch stat, should be empty now time.Sleep(c.TTL * 2) - fetched = c.get(file, 1) + fetched = c.get(childPath, 0) if fetched != nil { - t.Errorf("invalidate should have cleared cached value") + t.Errorf("cached value should have expired") } - c.set(file, 1, val) + c.set(childPath, 0, ce) // invalidate the cache and make sure nothing is returned c.invalidate() - fetched = c.get(file, 1) + fetched = c.get(childPath, 0) if fetched != nil { t.Errorf("invalidate should have cleared cached value") } } + +func TestParentChildRelationship(t *testing.T) { + // Make sure we don't leak goroutines + tstest.ResourceCheck(t) + + c := &StatCache{TTL: 24 * time.Hour} // don't expire + defer c.stop() + + missingParentPath := "/missingparent" + unparseableParentPath := "/unparseable" + + c.set(parentPath, 1, newCacheEntry(http.StatusMultiStatus, fullParent)) + c.set(missingParentPath, 1, newCacheEntry(http.StatusNotFound, nil)) + c.set(unparseableParentPath, 1, newCacheEntry(http.StatusMultiStatus, []byte("