diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index 84f7da8c2..a340275ae 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -15,6 +15,7 @@ import ( "mime" "net/http" "os" + "path" "path/filepath" "strings" "time" @@ -286,22 +287,116 @@ func runCpTargets(ctx context.Context, args []string) error { return nil } +// onConflict is a flag.Value for the --conflict flag's three string options. +type onConflict string + +const ( + skipOnExist onConflict = "skip" + overwriteExisting onConflict = "overwrite" // Overwrite any existing file at the target location + createNumberedFiles onConflict = "rename" // Create an alternately named file in the style of Chrome Downloads +) + +func (v *onConflict) String() string { return string(*v) } + +func (v *onConflict) Set(s string) error { + if s == "" { + *v = skipOnExist + return nil + } + *v = onConflict(strings.ToLower(s)) + if *v != skipOnExist && *v != overwriteExisting && *v != createNumberedFiles { + return fmt.Errorf("%q is not one of (skip|overwrite|rename)", s) + } + return nil +} + var fileGetCmd = &ffcli.Command{ Name: "get", - ShortUsage: "file get [--wait] [--verbose] ", + ShortUsage: "file get [--wait] [--verbose] [--conflict=(skip|overwrite|rename)] ", ShortHelp: "Move files out of the Tailscale file inbox", Exec: runFileGet, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("get") fs.BoolVar(&getArgs.wait, "wait", false, "wait for a file to arrive if inbox is empty") fs.BoolVar(&getArgs.verbose, "verbose", false, "verbose output") + fs.Var(&getArgs.conflict, "conflict", `behavior when a conflicting (same-named) file already exists in the target directory. + skip: skip conflicting files: leave them in the taildrop inbox and print an error. get any non-conflicting files + overwrite: overwrite existing file + rename: write to a new number-suffixed filename`) return fs })(), } -var getArgs struct { - wait bool - verbose bool +var getArgs = struct { + wait bool + verbose bool + conflict onConflict +}{conflict: skipOnExist} + +func numberedFileName(dir, name string, i int) string { + ext := path.Ext(name) + return filepath.Join(dir, fmt.Sprintf("%s (%d)%s", + strings.TrimSuffix(name, ext), + i, ext)) +} + +func openFileOrSubstitute(dir, base string, action onConflict) (*os.File, error) { + targetFile := filepath.Join(dir, base) + f, err := os.OpenFile(targetFile, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644) + if err == nil { + return f, nil + } + // Something went wrong trying to open targetFile as a new file for writing. + switch action { + default: + // This should not happen. + return nil, fmt.Errorf("file issue. how to resolve this conflict? no one knows.") + case skipOnExist: + if _, statErr := os.Stat(targetFile); statErr == nil { + // we can stat a file at that path: so it already exists. + return nil, fmt.Errorf("refusing to overwrite file: %w", err) + } + return nil, fmt.Errorf("failed to write; %w", err) + case overwriteExisting: + // remove the target file and create it anew so we don't fall for an + // attacker who symlinks a known target name to a file he wants changed. + if err = os.Remove(targetFile); err != nil { + return nil, fmt.Errorf("unable to remove target file: %w", err) + } + if f, err = os.OpenFile(targetFile, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644); err != nil { + return nil, fmt.Errorf("unable to overwrite: %w", err) + } + return f, nil + case createNumberedFiles: + // It's possible the target directory or filesystem isn't writable by us, + // not just that the target file(s) already exists. For now, give up after + // a limited number of attempts. In future, maybe distinguish this case + // and follow in the style of https://tinyurl.com/chromium100 + maxAttempts := 100 + for i := 1; i < maxAttempts; i++ { + if f, err = os.OpenFile(numberedFileName(dir, base, i), os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644); err == nil { + return f, nil + } + } + return nil, fmt.Errorf("unable to find a name for writing %v, final attempt: %w", targetFile, err) + } +} + +func receiveFile(ctx context.Context, wf apitype.WaitingFile, dir string) (targetFile string, size int64, err error) { + rc, size, err := tailscale.GetWaitingFile(ctx, wf.Name) + if err != nil { + return "", 0, fmt.Errorf("opening inbox file %q: %w", wf.Name, err) + } + f, err := openFileOrSubstitute(dir, wf.Name, getArgs.conflict) + if err != nil { + return "", 0, err + } + _, err = io.Copy(f, rc) + rc.Close() + if err != nil { + return "", 0, fmt.Errorf("failed to write %v: %v", f.Name(), err) + } + return f.Name(), size, f.Close() } func runFileGet(ctx context.Context, args []string) error { @@ -330,47 +425,40 @@ func runFileGet(ctx context.Context, args []string) error { break } if getArgs.verbose { - log.Printf("waiting for file...") + printf("waiting for file...") } if err := waitForFile(ctx); err != nil { return err } } + var errs []error deleted := 0 for _, wf := range wfs { - rc, size, err := tailscale.GetWaitingFile(ctx, wf.Name) - if err != nil { - return fmt.Errorf("opening inbox file %q: %v", wf.Name, err) - } - targetFile := filepath.Join(dir, wf.Name) - of, err := os.OpenFile(targetFile, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644) - if err != nil { - if _, err := os.Stat(targetFile); err == nil { - return fmt.Errorf("refusing to overwrite %v", targetFile) - } - return err - } - _, err = io.Copy(of, rc) - rc.Close() + writtenFile, size, err := receiveFile(ctx, wf, dir) if err != nil { - return fmt.Errorf("failed to write %v: %v", targetFile, err) - } - if err := of.Close(); err != nil { - return err + errs = append(errs, err) + continue } if getArgs.verbose { - log.Printf("wrote %v (%d bytes)", wf.Name, size) + printf("wrote %v as %v (%d bytes)\n", wf.Name, writtenFile, size) } - if err := tailscale.DeleteWaitingFile(ctx, wf.Name); err != nil { - return fmt.Errorf("deleting %q from inbox: %v", wf.Name, err) + if err = tailscale.DeleteWaitingFile(ctx, wf.Name); err != nil { + errs = append(errs, fmt.Errorf("deleting %q from inbox: %v", wf.Name, err)) + continue } deleted++ } if getArgs.verbose { - log.Printf("moved %d files", deleted) + printf("moved %d/%d files\n", deleted, len(wfs)) } - return nil + if len(errs) == 0 { + return nil + } + for _, err := range errs[:len(errs)-1] { + outln(err) + } + return errs[len(errs)-1] } func wipeInbox(ctx context.Context) error {