diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index 7aa51fa76..b94c21bab 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -284,7 +284,7 @@ func (up *Updater) updateSynology() error { return nil } - up.cleanupOldDownloads(filepath.Join(os.TempDir(), "tailscale-update*")) + up.cleanupOldDownloads(filepath.Join(os.TempDir(), "tailscale-update*", "*.spk")) // Download the SPK into a temporary directory. spkDir, err := os.MkdirTemp("", "tailscale-update") if err != nil { @@ -833,6 +833,9 @@ func (up *Updater) installMSI(msi string) error { return err } +// cleanupOldDownloads removes all files matching glob (see filepath.Glob). +// Only regular files are removed, so the glob must match specific files and +// not directories. func (up *Updater) cleanupOldDownloads(glob string) { matches, err := filepath.Glob(glob) if err != nil { @@ -840,7 +843,15 @@ func (up *Updater) cleanupOldDownloads(glob string) { return } for _, m := range matches { - if err := os.RemoveAll(m); err != nil { + s, err := os.Lstat(m) + if err != nil { + up.Logf("cleaning up old downloads: %v", err) + continue + } + if !s.Mode().IsRegular() { + continue + } + if err := os.Remove(m); err != nil { up.Logf("cleaning up old downloads: %v", err) } } diff --git a/clientupdate/clientupdate_test.go b/clientupdate/clientupdate_test.go index dc300341f..36e4e18cc 100644 --- a/clientupdate/clientupdate_test.go +++ b/clientupdate/clientupdate_test.go @@ -11,6 +11,8 @@ import ( "maps" "os" "path/filepath" + "slices" + "sort" "strings" "testing" ) @@ -683,3 +685,113 @@ func TestWriteFileSymlink(t *testing.T) { } } } + +func TestCleanupOldDownloads(t *testing.T) { + tests := []struct { + desc string + before []string + symlinks map[string]string + glob string + after []string + }{ + { + desc: "MSIs", + before: []string{ + "MSICache/tailscale-1.0.0.msi", + "MSICache/tailscale-1.1.0.msi", + "MSICache/readme.txt", + }, + glob: "MSICache/*.msi", + after: []string{ + "MSICache/readme.txt", + }, + }, + { + desc: "SPKs", + before: []string{ + "tmp/tailscale-update-1/tailscale-1.0.0.spk", + "tmp/tailscale-update-2/tailscale-1.1.0.spk", + "tmp/readme.txt", + "tmp/tailscale-update-3", + "tmp/tailscale-update-4/tailscale-1.3.0", + }, + glob: "tmp/tailscale-update*/*.spk", + after: []string{ + "tmp/readme.txt", + "tmp/tailscale-update-3", + "tmp/tailscale-update-4/tailscale-1.3.0", + }, + }, + { + desc: "empty-target", + before: []string{}, + glob: "tmp/tailscale-update*/*.spk", + after: []string{}, + }, + { + desc: "keep-dirs", + before: []string{ + "tmp/tailscale-update-1/tailscale-1.0.0.spk", + }, + glob: "tmp/tailscale-update*", + after: []string{ + "tmp/tailscale-update-1/tailscale-1.0.0.spk", + }, + }, + { + desc: "no-follow-symlinks", + before: []string{ + "MSICache/tailscale-1.0.0.msi", + "MSICache/tailscale-1.1.0.msi", + "MSICache/readme.txt", + }, + symlinks: map[string]string{ + "MSICache/tailscale-1.3.0.msi": "MSICache/tailscale-1.0.0.msi", + "MSICache/tailscale-1.4.0.msi": "MSICache/readme.txt", + }, + glob: "MSICache/*.msi", + after: []string{ + "MSICache/tailscale-1.3.0.msi", + "MSICache/tailscale-1.4.0.msi", + "MSICache/readme.txt", + }, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + dir := t.TempDir() + for _, p := range tt.before { + if err := os.MkdirAll(filepath.Join(dir, filepath.Dir(p)), 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, p), []byte(tt.desc), 0600); err != nil { + t.Fatal(err) + } + } + for from, to := range tt.symlinks { + if err := os.Symlink(filepath.Join(dir, to), filepath.Join(dir, from)); err != nil { + t.Fatal(err) + } + } + + up := &Updater{Arguments: Arguments{Logf: t.Logf}} + up.cleanupOldDownloads(filepath.Join(dir, tt.glob)) + + var after []string + if err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if !d.IsDir() { + after = append(after, strings.TrimPrefix(filepath.ToSlash(path), filepath.ToSlash(dir)+"/")) + } + return nil + }); err != nil { + t.Fatal(err) + } + + sort.Strings(after) + sort.Strings(tt.after) + if !slices.Equal(after, tt.after) { + t.Errorf("got files after cleanup: %q, want: %q", after, tt.after) + } + }) + } +}