diff --git a/drive/driveimpl/compositedav/propfind.go b/drive/driveimpl/compositedav/propfind.go index 60b90edb2..5e6ccfa0b 100644 --- a/drive/driveimpl/compositedav/propfind.go +++ b/drive/driveimpl/compositedav/propfind.go @@ -24,33 +24,26 @@ func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request) { // Delegate to a Child. depth := getDepth(r) - cached := h.StatCache.get(r.URL.Path, depth) - if cached != nil { - w.Header().Del("Content-Length") - w.WriteHeader(http.StatusMultiStatus) - w.Write(cached) - return - } - - // Use a buffering ResponseWriter so that we can manipulate the result. - // The only thing we use from the original ResponseWriter is Header(). - bw := &bufferingResponseWriter{ResponseWriter: w} + status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) { + // Use a buffering ResponseWriter so that we can manipulate the result. + // The only thing we use from the original ResponseWriter is Header(). + bw := &bufferingResponseWriter{ResponseWriter: w} - mpl := h.maxPathLength(r) - h.delegate(mpl, pathComponents[mpl-1:], bw, r) + mpl := h.maxPathLength(r) + h.delegate(mpl, pathComponents[mpl-1:], bw, r) - // Fixup paths to add the requested path as a prefix. - pathPrefix := shared.Join(pathComponents[0:mpl]...) - b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("%s/$1", pathPrefix))) + // Fixup paths to add the requested path as a prefix. + pathPrefix := shared.Join(pathComponents[0:mpl]...) + b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("%s/$1", pathPrefix))) - if h.StatCache != nil && bw.status == http.StatusMultiStatus && b != nil { - h.StatCache.set(r.URL.Path, depth, b) - } + return bw.status, b + }) w.Header().Del("Content-Length") - w.WriteHeader(bw.status) - w.Write(b) - + w.WriteHeader(status) + if result != nil { + w.Write(result) + } return } diff --git a/drive/driveimpl/compositedav/stat_cache.go b/drive/driveimpl/compositedav/stat_cache.go index 8be5888a8..4d5971bd0 100644 --- a/drive/driveimpl/compositedav/stat_cache.go +++ b/drive/driveimpl/compositedav/stat_cache.go @@ -4,6 +4,7 @@ package compositedav import ( + "net/http" "sync" "time" @@ -25,6 +26,23 @@ type StatCache struct { cachesByDepthAndPath map[int]*ttlcache.Cache[string, []byte] } +// getOr checks the cache for the named value at the given depth. If a cached +// value was found, it returns http.StatusMultiStatus along with the cached +// value. Otherwise, it executes the given function and returns the resulting +// status and value. If the function returned http.StatusMultiStatus, getOr +// caches the resulting value at the given name and depth before returning. +func (c *StatCache) getOr(name string, depth int, or func() (int, []byte)) (int, []byte) { + cached := c.get(name, depth) + if cached != nil { + return http.StatusMultiStatus, cached + } + status, next := or() + if c != nil && status == http.StatusMultiStatus && next != nil { + c.set(name, depth, next) + } + return status, next +} + func (c *StatCache) get(name string, depth int) []byte { if c == nil { return nil diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go index e245ac791..8e9d1a557 100644 --- a/drive/driveimpl/drive_test.go +++ b/drive/driveimpl/drive_test.go @@ -184,7 +184,7 @@ func newSystem(t *testing.T) *system { // Make sure we don't leak goroutines tstest.ResourceCheck(t) - fs := NewFileSystemForLocal(log.Printf) + fs := newFileSystemForLocal(log.Printf, nil) l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to Listen: %s", err) diff --git a/drive/driveimpl/local_impl.go b/drive/driveimpl/local_impl.go index 22936ceaa..8cdf60179 100644 --- a/drive/driveimpl/local_impl.go +++ b/drive/driveimpl/local_impl.go @@ -27,6 +27,10 @@ const ( // NewFileSystemForLocal starts serving a filesystem for local clients. // Inbound connections must be handed to HandleConn. func NewFileSystemForLocal(logf logger.Logf) *FileSystemForLocal { + return newFileSystemForLocal(logf, &compositedav.StatCache{TTL: statCacheTTL}) +} + +func newFileSystemForLocal(logf logger.Logf, statCache *compositedav.StatCache) *FileSystemForLocal { if logf == nil { logf = log.Printf } @@ -34,7 +38,7 @@ func NewFileSystemForLocal(logf logger.Logf) *FileSystemForLocal { logf: logf, h: &compositedav.Handler{ Logf: logf, - StatCache: &compositedav.StatCache{TTL: statCacheTTL}, + StatCache: statCache, }, listener: newConnListener(), }