diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index 65a834f20..19da6ab39 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -19,9 +19,12 @@ import ( "path" "path/filepath" "strings" + "sync" + "sync/atomic" "time" "unicode/utf8" + "github.com/mattn/go-isatty" "github.com/peterbourgon/ff/v3/ffcli" "golang.org/x/time/rate" "tailscale.com/client/tailscale/apitype" @@ -49,6 +52,17 @@ var fileCmd = &ffcli.Command{ }, } +type countingReader struct { + io.Reader + n atomic.Uint64 +} + +func (c *countingReader) Read(buf []byte) (int, error) { + n, err := c.Reader.Read(buf) + c.n.Add(uint64(n)) + return n, err +} + var fileCpCmd = &ffcli.Command{ Name: "cp", ShortUsage: "file cp :", @@ -116,11 +130,11 @@ func runCp(ctx context.Context, args []string) error { } for _, fileArg := range files { - var fileContents io.Reader + var fileContents *countingReader var name = cpArgs.name var contentLength int64 = -1 if fileArg == "-" { - fileContents = os.Stdin + fileContents = &countingReader{Reader: os.Stdin} if name == "" { name, fileContents, err = pickStdinFilename() if err != nil { @@ -144,19 +158,29 @@ func runCp(ctx context.Context, args []string) error { return errors.New("directories not supported") } contentLength = fi.Size() - fileContents = io.LimitReader(f, contentLength) + fileContents = &countingReader{Reader: io.LimitReader(f, contentLength)} if name == "" { name = filepath.Base(fileArg) } if envknob.Bool("TS_DEBUG_SLOW_PUSH") { - fileContents = &slowReader{r: fileContents} + fileContents = &countingReader{Reader: &slowReader{r: fileContents}} } } if cpArgs.verbose { log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID) } + + var ( + done = make(chan struct{}, 1) + wg sync.WaitGroup + ) + if isatty.IsTerminal(os.Stderr.Fd()) { + go printProgress(&wg, done, fileContents, name, contentLength) + wg.Add(1) + } + err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents) if err != nil { return err @@ -164,10 +188,61 @@ func runCp(ctx context.Context, args []string) error { if cpArgs.verbose { log.Printf("sent %q", name) } + done <- struct{}{} + wg.Wait() } return nil } +const vtRestartLine = "\r\x1b[K" + +func printProgress(wg *sync.WaitGroup, done <-chan struct{}, r *countingReader, name string, contentLength int64) { + defer wg.Done() + var lastBytesRead uint64 + + for { + select { + case <-done: + fmt.Fprintln(os.Stderr) + return + case <-time.After(time.Second): + n := r.n.Load() + contentLengthStr := "???" + if contentLength > 0 { + contentLengthStr = fmt.Sprint(contentLength / 1024) + } + + fmt.Fprintf(os.Stderr, "%s%s\t\t%s", vtRestartLine, padTruncateString(name, 36), padTruncateString(fmt.Sprintf("%d/%s kb", n/1024, contentLengthStr), 16)) + if contentLength > 0 { + fmt.Fprintf(os.Stderr, "\t%.02f%%", float64(n)/float64(contentLength)*100) + } else { + fmt.Fprintf(os.Stderr, "\t-------%%") + } + if lastBytesRead > 0 { + fmt.Fprintf(os.Stderr, "\t%d kb/s", (n-lastBytesRead)/1024) + } else { + fmt.Fprintf(os.Stderr, "\t-------") + } + lastBytesRead = n + } + } +} + +func padTruncateString(str string, truncateAt int) string { + if len(str) <= truncateAt { + return str + strings.Repeat(" ", truncateAt-len(str)) + } + + // Truncate the string, but respect unicode codepoint boundaries. + // As of RFC3629 utf-8 codepoints can be at most 4 bytes wide. + for i := 1; i <= 4 && i < len(str)-truncateAt; i++ { + if utf8.ValidString(str[:truncateAt-i]) { + return str[:truncateAt-i] + "…" + } + } + return "" // Should be unreachable +} + func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNodeID, isOffline bool, err error) { ip, err := netip.ParseAddr(ipStr) if err != nil { @@ -230,12 +305,12 @@ func ext(b []byte) string { // pickStdinFilename reads a bit of stdin to return a good filename // for its contents. The returned Reader is the concatenation of the // read and unread bits. -func pickStdinFilename() (name string, r io.Reader, err error) { +func pickStdinFilename() (name string, r *countingReader, err error) { sniff, err := io.ReadAll(io.LimitReader(os.Stdin, maxSniff)) if err != nil { return "", nil, err } - return "stdin" + ext(sniff), io.MultiReader(bytes.NewReader(sniff), os.Stdin), nil + return "stdin" + ext(sniff), &countingReader{Reader: io.MultiReader(bytes.NewReader(sniff), os.Stdin)}, nil } type slowReader struct {