diff --git a/net/tcpinfo/tcpinfo.go b/net/tcpinfo/tcpinfo.go new file mode 100644 index 000000000..a757add9f --- /dev/null +++ b/net/tcpinfo/tcpinfo.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tcpinfo provides platform-agnostic accessors to information about a +// TCP connection (e.g. RTT, MSS, etc.). +package tcpinfo + +import ( + "errors" + "net" + "time" +) + +var ( + ErrNotTCP = errors.New("tcpinfo: not a TCP conn") + ErrUnimplemented = errors.New("tcpinfo: unimplemented") +) + +// RTT returns the RTT for the given net.Conn. +// +// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then +// ErrNotTCP will be returned. If retrieving the RTT is not supported on the +// current platform, ErrUnimplemented will be returned. +func RTT(conn net.Conn) (time.Duration, error) { + tcpConn, err := unwrap(conn) + if err != nil { + return 0, err + } + + return rttImpl(tcpConn) +} + +// netConner is implemented by crypto/tls.Conn to unwrap into an underlying +// net.Conn. +type netConner interface { + NetConn() net.Conn +} + +// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn +func unwrap(nc net.Conn) (*net.TCPConn, error) { + for { + switch v := nc.(type) { + case *net.TCPConn: + return v, nil + case netConner: + nc = v.NetConn() + default: + return nil, ErrNotTCP + } + } +} diff --git a/net/tcpinfo/tcpinfo_darwin.go b/net/tcpinfo/tcpinfo_darwin.go new file mode 100644 index 000000000..53fa22fbf --- /dev/null +++ b/net/tcpinfo/tcpinfo_darwin.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPConnectionInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil +} diff --git a/net/tcpinfo/tcpinfo_linux.go b/net/tcpinfo/tcpinfo_linux.go new file mode 100644 index 000000000..885d462c9 --- /dev/null +++ b/net/tcpinfo/tcpinfo_linux.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil +} diff --git a/net/tcpinfo/tcpinfo_other.go b/net/tcpinfo/tcpinfo_other.go new file mode 100644 index 000000000..be45523ae --- /dev/null +++ b/net/tcpinfo/tcpinfo_other.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin + +package tcpinfo + +import ( + "net" + "time" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + return 0, ErrUnimplemented +} diff --git a/net/tcpinfo/tcpinfo_test.go b/net/tcpinfo/tcpinfo_test.go new file mode 100644 index 000000000..a117eb59a --- /dev/null +++ b/net/tcpinfo/tcpinfo_test.go @@ -0,0 +1,64 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "bytes" + "io" + "net" + "runtime" + "testing" +) + +func TestRTT(t *testing.T) { + switch runtime.GOOS { + case "linux", "darwin": + default: + t.Skipf("not currently supported on %s", runtime.GOOS) + } + + ln, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + t.Cleanup(func() { c.Close() }) + + // Copy from the client to nowhere + go io.Copy(io.Discard, c) + } + }() + + conn, err := net.Dial("tcp4", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + // Write a bunch of data to the conn to force TCP session establishment + // and a few packets. + junkData := bytes.Repeat([]byte("hello world\n"), 1024*1024) + for i := 0; i < 10; i++ { + if _, err := conn.Write(junkData); err != nil { + t.Fatalf("error writing junk data [%d]: %v", i, err) + } + } + + // Get the RTT now + rtt, err := RTT(conn) + if err != nil { + t.Fatalf("error getting RTT: %v", err) + } + if rtt == 0 { + t.Errorf("expected RTT > 0") + } + + t.Logf("TCP rtt: %v", rtt) +}