diff --git a/cmd/speedtest/speedtest.go b/cmd/speedtest/speedtest.go new file mode 100644 index 000000000..5f318b63b --- /dev/null +++ b/cmd/speedtest/speedtest.go @@ -0,0 +1,121 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Program speedtest provides the speedtest command. The reason to keep it separate from +// the normal tailscale cli is because it is not yet ready to go in the tailscale binary. +// It will be included in the tailscale cli after it has been added to tailscaled. + +// Example usage for client command: go run cmd/speedtest -host 127.0.0.1:20333 -t 5s +// This will connect to the server on 127.0.0.1:20333 and start a 5 second download speedtest. +// Example usage for server command: go run cmd/speedtest -s -host :20333 +// This will start a speedtest server on port 20333. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "os" + "strconv" + "text/tabwriter" + "time" + + "github.com/peterbourgon/ff/v2/ffcli" + "tailscale.com/net/speedtest" +) + +// Runs the speedtest command as a commandline program +func main() { + args := os.Args[1:] + if err := speedtestCmd.Parse(args); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + err := speedtestCmd.Run(context.Background()) + if errors.Is(err, flag.ErrHelp) { + fmt.Fprintln(os.Stderr, speedtestCmd.ShortUsage) + os.Exit(2) + } + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +// speedtestCmd is the root command. It runs either the server or client depending on the +// flags passed to it. +var speedtestCmd = &ffcli.Command{ + Name: "speedtest", + ShortUsage: "speedtest [-host ] [-s] [-r] [-t ]", + ShortHelp: "Run a speed test", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("speedtest", flag.ExitOnError) + fs.StringVar(&speedtestArgs.host, "host", ":20333", "host:port pair to connect to or listen on") + fs.DurationVar(&speedtestArgs.testDuration, "t", speedtest.DefaultDuration, "duration of the speed test") + fs.BoolVar(&speedtestArgs.runServer, "s", false, "run a speedtest server") + fs.BoolVar(&speedtestArgs.reverse, "r", false, "run in reverse mode (server sends, client receives)") + return fs + })(), + Exec: runSpeedtest, +} + +var speedtestArgs struct { + host string + testDuration time.Duration + runServer bool + reverse bool +} + +func runSpeedtest(ctx context.Context, args []string) error { + + if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { + var addrErr *net.AddrError + if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + // if no port is provided, append the default port + speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) + } + } + + if speedtestArgs.runServer { + listener, err := net.Listen("tcp", speedtestArgs.host) + if err != nil { + return err + } + + fmt.Printf("listening on %v\n", listener.Addr()) + + return speedtest.Serve(listener) + } + + // Ensure the duration is within the allowed range + if speedtestArgs.testDuration < speedtest.MinDuration || speedtestArgs.testDuration > speedtest.MaxDuration { + return fmt.Errorf("test duration must be within %v and %v", speedtest.MinDuration, speedtest.MaxDuration) + } + + dir := speedtest.Download + if speedtestArgs.reverse { + dir = speedtest.Upload + } + + fmt.Printf("Starting a %s test with %s\n", dir, speedtestArgs.host) + results, err := speedtest.RunClient(dir, speedtestArgs.testDuration, speedtestArgs.host) + if err != nil { + return err + } + + w := tabwriter.NewWriter(os.Stdout, 12, 0, 0, ' ', tabwriter.TabIndent) + fmt.Println("Results:") + fmt.Fprintln(w, "Interval\t\tTransfer\t\tBandwidth\t\t") + for _, r := range results { + if r.Total { + fmt.Fprintln(w, "-------------------------------------------------------------------------") + } + fmt.Fprintf(w, "%.2f-%.2f\tsec\t%.4f\tMBits\t%.4f\tMbits/sec\t\n", r.IntervalStart.Seconds(), r.IntervalEnd.Seconds(), r.MegaBits(), r.MBitsPerSecond()) + } + w.Flush() + return nil +} diff --git a/net/speedtest/speedtest.go b/net/speedtest/speedtest.go new file mode 100644 index 000000000..adcb2b54d --- /dev/null +++ b/net/speedtest/speedtest.go @@ -0,0 +1,88 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package speedtest contains both server and client code for +// running speedtests between tailscale nodes. +package speedtest + +import ( + "time" +) + +const ( + blockSize = 32000 // size of the block of data to send + MinDuration = 5 * time.Second // minimum duration for a test + DefaultDuration = MinDuration // default duration for a test + MaxDuration = 30 * time.Second // maximum duration for a test + version = 1 // value used when comparing client and server versions + increment = time.Second // increment to display results for, in seconds + minInterval = 10 * time.Millisecond // minimum interval length for a result to be included + DefaultPort = 20333 +) + +// config is the initial message sent to the server, that contains information on how to +// conduct the test. +type config struct { + Version int `json:"version"` + TestDuration time.Duration `json:"time"` + Direction Direction `json:"direction"` +} + +// configResponse is the response to the testConfig message. If the server has an +// error with the config, the Error variable will hold that error value. +type configResponse struct { + Error string `json:"error,omitempty"` +} + +// This represents the Result of a speedtest within a specific interval +type Result struct { + Bytes int // number of bytes sent/received during the interval + IntervalStart time.Duration // duration between the start of the interval and the start of the test + IntervalEnd time.Duration // duration between the end of the interval and the start of the test + Total bool // if true, this result struct represents the entire test, rather than a segment of the test +} + +func (r Result) MBitsPerSecond() float64 { + return r.MegaBits() / (r.IntervalEnd - r.IntervalStart).Seconds() +} + +func (r Result) MegaBytes() float64 { + return float64(r.Bytes) / 1000000.0 +} + +func (r Result) MegaBits() float64 { + return r.MegaBytes() * 8.0 +} + +func (r Result) Interval() time.Duration { + return r.IntervalEnd - r.IntervalStart +} + +type Direction int + +const ( + Download Direction = iota + Upload +) + +func (d Direction) String() string { + switch d { + case Upload: + return "upload" + case Download: + return "download" + default: + return "" + } +} + +func (d *Direction) Reverse() { + switch *d { + case Upload: + *d = Download + case Download: + *d = Upload + default: + } +} diff --git a/net/speedtest/speedtest_client.go b/net/speedtest/speedtest_client.go new file mode 100644 index 000000000..c62712908 --- /dev/null +++ b/net/speedtest/speedtest_client.go @@ -0,0 +1,42 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package speedtest + +import ( + "encoding/json" + "errors" + "net" + "time" +) + +// RunClient dials the given address and starts a speedtest. +// It returns any errors that come up in the tests. +// If there are no errors in the test, it returns a slice of results. +func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + + conf := config{TestDuration: duration, Version: version, Direction: direction} + + defer conn.Close() + encoder := json.NewEncoder(conn) + + if err = encoder.Encode(conf); err != nil { + return nil, err + } + + var response configResponse + decoder := json.NewDecoder(conn) + if err = decoder.Decode(&response); err != nil { + return nil, err + } + if response.Error != "" { + return nil, errors.New(response.Error) + } + + return doTest(conn, conf) +} diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go new file mode 100644 index 000000000..7985095cb --- /dev/null +++ b/net/speedtest/speedtest_server.go @@ -0,0 +1,158 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package speedtest + +import ( + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Serve starts up the server on a given host and port pair. It starts to listen for +// connections and handles each one in a goroutine. Because it runs in an infinite loop, +// this function only returns if any of the speedtests return with errors, or if the +// listener is closed. +func Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if errors.Is(err, net.ErrClosed) { + return nil + } + if err != nil { + return err + } + err = handleConnection(conn) + if err != nil { + return err + } + } +} + +// handleConnection handles the initial exchange between the server and the client. +// It reads the testconfig message into a config struct. If any errors occur with +// the testconfig (specifically, if there is a version mismatch), it will return those +// errors to the client with a configResponse. After the exchange, it will start +// the speed test. +func handleConnection(conn net.Conn) error { + defer conn.Close() + var conf config + + decoder := json.NewDecoder(conn) + err := decoder.Decode(&conf) + encoder := json.NewEncoder(conn) + + // Both return and encode errors that occurred before the test started. + if err != nil { + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // The server should always be doing the opposite of what the client is doing. + conf.Direction.Reverse() + + if conf.Version != version { + err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // Start the test + encoder.Encode(configResponse{}) + _, err = doTest(conn, conf) + return err +} + +// TODO include code to detect whether the code is direct vs DERP + +// doTest contains the code to run both the upload and download speedtest. +// the direction value in the config parameter determines which test to run. +func doTest(conn net.Conn, conf config) ([]Result, error) { + bufferData := make([]byte, blockSize) + + intervalBytes := 0 + totalBytes := 0 + + var currentTime time.Time + var results []Result + + startTime := time.Now() + lastCalculated := startTime + + if conf.Direction == Download { + conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) + } else { + _, err := rand.Read(bufferData) + if err != nil { + return nil, err + } + + } + +SpeedTestLoop: + for { + var n int + var err error + + if conf.Direction == Download { + n, err = io.ReadFull(conn, bufferData) + switch err { + case io.EOF, io.ErrUnexpectedEOF: + break SpeedTestLoop + case nil: + // successful read + default: + return nil, fmt.Errorf("unexpected error has occured: %w", err) + } + } else { + // Need to change the data a little bit, to avoid any compression. + for i := range bufferData { + bufferData[i]++ + } + n, err = conn.Write(bufferData) + if err != nil { + // If the write failed, there is most likely something wrong with the connection. + return nil, fmt.Errorf("upload failed: %w", err) + } + } + currentTime = time.Now() + intervalBytes += n + + // checks if the current time is more or equal to the lastCalculated time plus the increment + if currentTime.After(lastCalculated.Add(increment)) { + intervalStart := lastCalculated.Sub(startTime) + intervalEnd := currentTime.Sub(startTime) + if (intervalEnd - intervalStart) > minInterval { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: intervalStart, IntervalEnd: intervalEnd, Total: false}) + } + lastCalculated = currentTime + totalBytes += intervalBytes + intervalBytes = 0 + } + + if conf.Direction == Upload && time.Since(startTime) > conf.TestDuration { + break SpeedTestLoop + } + } + + // get last segment + intervalStart := lastCalculated.Sub(startTime) + intervalEnd := currentTime.Sub(startTime) + if (intervalEnd - intervalStart) > minInterval { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: intervalStart, IntervalEnd: intervalEnd, Total: false}) + } + + // get total + totalBytes += intervalBytes + intervalEnd = currentTime.Sub(startTime) + if intervalEnd > minInterval { + results = append(results, Result{Bytes: totalBytes, IntervalStart: 0, IntervalEnd: intervalEnd, Total: true}) + } + + return results, nil +} diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go new file mode 100644 index 000000000..5bf089496 --- /dev/null +++ b/net/speedtest/speedtest_test.go @@ -0,0 +1,81 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package speedtest + +import ( + "net" + "testing" +) + +func TestDownload(t *testing.T) { + // start a listener and find the port where the server will be listening. + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { l.Close() }) + + serverIP := l.Addr().String() + t.Log("server IP found:", serverIP) + + type state struct { + err error + } + displayResult := func(t *testing.T, r Result) { + t.Helper() + t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Seconds(), r.IntervalEnd.Seconds(), r.Total) + } + stateChan := make(chan state, 1) + + go func() { + err := Serve(l) + stateChan <- state{err: err} + }() + + // ensure that the test returns an appropriate number of Result structs + expectedLen := int(DefaultDuration.Seconds()) + 1 + + t.Run("download test", func(t *testing.T) { + // conduct a download test + results, err := RunClient(Download, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("download test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + for _, result := range results { + displayResult(t, result) + } + }) + + t.Run("upload test", func(t *testing.T) { + // conduct an upload test + results, err := RunClient(Upload, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("upload test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + for _, result := range results { + displayResult(t, result) + } + }) + + // causes the server goroutine to finish + l.Close() + + testState := <-stateChan + if testState.err != nil { + t.Error("server error:", err) + } +}