diff --git a/drive/driveimpl/compositedav/propfind.go b/drive/driveimpl/compositedav/propfind.go
index 60b90edb2..5e6ccfa0b 100644
--- a/drive/driveimpl/compositedav/propfind.go
+++ b/drive/driveimpl/compositedav/propfind.go
@@ -24,33 +24,26 @@ func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request) {
// Delegate to a Child.
depth := getDepth(r)
- cached := h.StatCache.get(r.URL.Path, depth)
- if cached != nil {
- w.Header().Del("Content-Length")
- w.WriteHeader(http.StatusMultiStatus)
- w.Write(cached)
- return
- }
-
- // Use a buffering ResponseWriter so that we can manipulate the result.
- // The only thing we use from the original ResponseWriter is Header().
- bw := &bufferingResponseWriter{ResponseWriter: w}
+ status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) {
+ // Use a buffering ResponseWriter so that we can manipulate the result.
+ // The only thing we use from the original ResponseWriter is Header().
+ bw := &bufferingResponseWriter{ResponseWriter: w}
- mpl := h.maxPathLength(r)
- h.delegate(mpl, pathComponents[mpl-1:], bw, r)
+ mpl := h.maxPathLength(r)
+ h.delegate(mpl, pathComponents[mpl-1:], bw, r)
- // Fixup paths to add the requested path as a prefix.
- pathPrefix := shared.Join(pathComponents[0:mpl]...)
- b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("%s/$1", pathPrefix)))
+ // Fixup paths to add the requested path as a prefix.
+ pathPrefix := shared.Join(pathComponents[0:mpl]...)
+ b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("%s/$1", pathPrefix)))
- if h.StatCache != nil && bw.status == http.StatusMultiStatus && b != nil {
- h.StatCache.set(r.URL.Path, depth, b)
- }
+ return bw.status, b
+ })
w.Header().Del("Content-Length")
- w.WriteHeader(bw.status)
- w.Write(b)
-
+ w.WriteHeader(status)
+ if result != nil {
+ w.Write(result)
+ }
return
}
diff --git a/drive/driveimpl/compositedav/stat_cache.go b/drive/driveimpl/compositedav/stat_cache.go
index 8be5888a8..4d5971bd0 100644
--- a/drive/driveimpl/compositedav/stat_cache.go
+++ b/drive/driveimpl/compositedav/stat_cache.go
@@ -4,6 +4,7 @@
package compositedav
import (
+ "net/http"
"sync"
"time"
@@ -25,6 +26,23 @@ type StatCache struct {
cachesByDepthAndPath map[int]*ttlcache.Cache[string, []byte]
}
+// getOr checks the cache for the named value at the given depth. If a cached
+// value was found, it returns http.StatusMultiStatus along with the cached
+// value. Otherwise, it executes the given function and returns the resulting
+// status and value. If the function returned http.StatusMultiStatus, getOr
+// caches the resulting value at the given name and depth before returning.
+func (c *StatCache) getOr(name string, depth int, or func() (int, []byte)) (int, []byte) {
+ cached := c.get(name, depth)
+ if cached != nil {
+ return http.StatusMultiStatus, cached
+ }
+ status, next := or()
+ if c != nil && status == http.StatusMultiStatus && next != nil {
+ c.set(name, depth, next)
+ }
+ return status, next
+}
+
func (c *StatCache) get(name string, depth int) []byte {
if c == nil {
return nil
diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go
index e245ac791..8e9d1a557 100644
--- a/drive/driveimpl/drive_test.go
+++ b/drive/driveimpl/drive_test.go
@@ -184,7 +184,7 @@ func newSystem(t *testing.T) *system {
// Make sure we don't leak goroutines
tstest.ResourceCheck(t)
- fs := NewFileSystemForLocal(log.Printf)
+ fs := newFileSystemForLocal(log.Printf, nil)
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to Listen: %s", err)
diff --git a/drive/driveimpl/local_impl.go b/drive/driveimpl/local_impl.go
index 22936ceaa..8cdf60179 100644
--- a/drive/driveimpl/local_impl.go
+++ b/drive/driveimpl/local_impl.go
@@ -27,6 +27,10 @@ const (
// NewFileSystemForLocal starts serving a filesystem for local clients.
// Inbound connections must be handed to HandleConn.
func NewFileSystemForLocal(logf logger.Logf) *FileSystemForLocal {
+ return newFileSystemForLocal(logf, &compositedav.StatCache{TTL: statCacheTTL})
+}
+
+func newFileSystemForLocal(logf logger.Logf, statCache *compositedav.StatCache) *FileSystemForLocal {
if logf == nil {
logf = log.Printf
}
@@ -34,7 +38,7 @@ func NewFileSystemForLocal(logf logger.Logf) *FileSystemForLocal {
logf: logf,
h: &compositedav.Handler{
Logf: logf,
- StatCache: &compositedav.StatCache{TTL: statCacheTTL},
+ StatCache: statCache,
},
listener: newConnListener(),
}