cmd/stunstamp: implement HTTPS & TCP latency measurements (#13082)

HTTPS mirrors current netcheck behavior and TCP uses tcp_info->rtt.

Updates tailscale/corp#22114

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/13110/head
Jordan Whited 3 months ago committed by GitHub
parent bc2744da4b
commit 218110963d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,13 +1,14 @@
// Copyright (c) Tailscale Inc & AUTHORS // Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause // SPDX-License-Identifier: BSD-3-Clause
// The stunstamp binary measures STUN round-trip latency with DERPs. // The stunstamp binary measures round-trip latency with DERPs.
package main package main
import ( import (
"bytes" "bytes"
"cmp" "cmp"
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
@ -31,8 +32,10 @@ import (
"github.com/golang/snappy" "github.com/golang/snappy"
"github.com/prometheus/prometheus/prompb" "github.com/prometheus/prometheus/prompb"
"github.com/tcnksm/go-httpstat"
"tailscale.com/logtail/backoff" "tailscale.com/logtail/backoff"
"tailscale.com/net/stun" "tailscale.com/net/stun"
"tailscale.com/net/tcpinfo"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -44,6 +47,7 @@ var (
flagInstance = flag.String("instance", "", "instance label value; defaults to hostname if unspecified") flagInstance = flag.String("instance", "", "instance label value; defaults to hostname if unspecified")
flagSTUNDstPorts = flag.String("stun-dst-ports", "", "comma-separated list of STUN destination ports to monitor") flagSTUNDstPorts = flag.String("stun-dst-ports", "", "comma-separated list of STUN destination ports to monitor")
flagHTTPSDstPorts = flag.String("https-dst-ports", "", "comma-separated list of HTTPS destination ports to monitor") flagHTTPSDstPorts = flag.String("https-dst-ports", "", "comma-separated list of HTTPS destination ports to monitor")
flagTCPDstPorts = flag.String("tcp-dst-ports", "", "comma-separated list of TCP destination ports to monitor")
flagICMP = flag.Bool("icmp", false, "probe ICMP") flagICMP = flag.Bool("icmp", false, "probe ICMP")
) )
@ -97,6 +101,7 @@ const (
protocolSTUN protocol = "stun" protocolSTUN protocol = "stun"
protocolICMP protocol = "icmp" protocolICMP protocol = "icmp"
protocolHTTPS protocol = "https" protocolHTTPS protocol = "https"
protocolTCP protocol = "tcp"
) )
// resultKey contains the stable dimensions and their values for a given // resultKey contains the stable dimensions and their values for a given
@ -115,7 +120,188 @@ type result struct {
rtt *time.Duration // nil signifies failure, e.g. timeout rtt *time.Duration // nil signifies failure, e.g. timeout
} }
func measureSTUNRTT(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { type lportsPool struct {
sync.Mutex
ports []int
}
func (l *lportsPool) get() int {
l.Lock()
defer l.Unlock()
ret := l.ports[0]
l.ports = append(l.ports[:0], l.ports[1:]...)
return ret
}
func (l *lportsPool) put(i int) {
l.Lock()
defer l.Unlock()
l.ports = append(l.ports, int(i))
}
var (
lports *lportsPool
)
const (
lportPoolSize = 16000
lportBase = 2048
)
func init() {
lports = &lportsPool{
ports: make([]int, 0, lportPoolSize),
}
for i := lportBase; i < lportBase+lportPoolSize; i++ {
lports.ports = append(lports.ports, i)
}
}
// lportForTCPConn satisfies io.ReadWriteCloser, but is really just used to pass
// around a persistent laddr for stableConn purposes. The underlying TCP
// connection is not created until measurement time as in some cases we need to
// measure dial time.
type lportForTCPConn int
func (l *lportForTCPConn) Close() error {
if *l == 0 {
return nil
}
lports.put(int(*l))
return nil
}
func (l *lportForTCPConn) Write([]byte) (int, error) {
return 0, errors.New("unimplemented")
}
func (l *lportForTCPConn) Read([]byte) (int, error) {
return 0, errors.New("unimplemented")
}
func addrInUse(err error, lport *lportForTCPConn) bool {
if errors.Is(err, syscall.EADDRINUSE) {
old := int(*lport)
// abandon port, don't return it to pool
*lport = lportForTCPConn(lports.get()) // get a new port
log.Printf("EADDRINUSE: %v old: %d new: %d", err, old, *lport)
return true
}
return false
}
func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) {
for {
var opErr error
dialer := &net.Dialer{
Timeout: time.Second * 2,
LocalAddr: &net.TCPAddr{
Port: int(*lport),
},
Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
// we may restart faster than TIME_WAIT can clear
opErr = setSOReuseAddr(fd)
})
},
}
if opErr != nil {
panic(opErr)
}
tcpConn, err := dialer.Dial("tcp", dst.String())
if err != nil {
if addrInUse(err, lport) {
continue
}
return nil, err
}
return tcpConn, nil
}
}
type tempError struct {
error
}
func (t tempError) Temporary() bool {
return true
}
func measureTCPRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt time.Duration, err error) {
lport, ok := conn.(*lportForTCPConn)
if !ok {
return 0, fmt.Errorf("unexpected conn type: %T", conn)
}
tcpConn, err := tcpDial(lport, dst)
if err != nil {
return 0, tempError{err}
}
defer tcpConn.Close()
rtt, err = tcpinfo.RTT(tcpConn)
if err != nil {
return 0, tempError{err}
}
return rtt, nil
}
func measureHTTPSRTT(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) {
lport, ok := conn.(*lportForTCPConn)
if !ok {
return 0, fmt.Errorf("unexpected conn type: %T", conn)
}
var httpResult httpstat.Result
ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*3)
defer cancel()
reqURL := "https://" + dst.String() + "/derp/latency-check"
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return 0, err
}
client := &http.Client{}
tcpConn, err := tcpDial(lport, dst)
if err != nil {
return 0, tempError{err}
}
defer tcpConn.Close()
tlsConn := tls.Client(tcpConn, &tls.Config{
ServerName: hostname,
})
// Mirror client/netcheck behavior, which handshakes before handing the
// tlsConn over to the http.Client via http.Transport
err = tlsConn.Handshake()
if err != nil {
return 0, tempError{err}
}
tlsConnCh := make(chan net.Conn, 1)
tlsConnCh <- tlsConn
tr := &http.Transport{
DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
select {
case tlsConn := <-tlsConnCh:
return tlsConn, nil
default:
return nil, errors.New("unexpected second call of DialTLSContext")
}
},
}
client.Transport = tr
resp, err := client.Do(req)
if err != nil {
return 0, tempError{err}
}
if resp.StatusCode/100 != 2 {
return 0, tempError{fmt.Errorf("unexpected status code: %d", resp.StatusCode)}
}
defer resp.Body.Close()
_, err = io.Copy(io.Discard, io.LimitReader(resp.Body, 8<<10))
if err != nil {
return 0, tempError{err}
}
httpResult.End(time.Now())
return httpResult.ServerProcessing, nil
}
func measureSTUNRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt time.Duration, err error) {
uconn, ok := conn.(*net.UDPConn) uconn, ok := conn.(*net.UDPConn)
if !ok { if !ok {
return 0, fmt.Errorf("unexpected conn type: %T", conn) return 0, fmt.Errorf("unexpected conn type: %T", conn)
@ -167,7 +353,7 @@ type nodeMeta struct {
addr netip.Addr addr netip.Addr
} }
type measureFn func(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) type measureFn func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error)
// probe measures round trip time for the node described by meta over // probe measures round trip time for the node described by meta over
// conn against dstPort using fn. It may return a nil duration and nil error in // conn against dstPort using fn. It may return a nil duration and nil error in
@ -180,7 +366,7 @@ func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*
} }
time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx
rtt, err := fn(conn, netip.AddrPortFrom(meta.addr, uint16(dstPort))) rtt, err := fn(conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort)))
if err != nil { if err != nil {
if isTemporaryOrTimeoutErr(err) { if isTemporaryOrTimeoutErr(err) {
log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err)
@ -251,7 +437,7 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node
return stale, nil return stale, nil
} }
func newConn(source timestampSource, protocol protocol) (io.ReadWriteCloser, error) { func newConn(source timestampSource, protocol protocol, stable connStability) (io.ReadWriteCloser, error) {
switch protocol { switch protocol {
case protocolSTUN: case protocolSTUN:
if source == timestampSourceKernel { if source == timestampSourceKernel {
@ -263,8 +449,19 @@ func newConn(source timestampSource, protocol protocol) (io.ReadWriteCloser, err
// TODO(jwhited): implement // TODO(jwhited): implement
return nil, errors.New("unimplemented protocol") return nil, errors.New("unimplemented protocol")
case protocolHTTPS: case protocolHTTPS:
// TODO(jwhited): implement localPort := 0
return nil, errors.New("unimplemented protocol") if stable {
localPort = lports.get()
}
ret := lportForTCPConn(localPort)
return &ret, nil
case protocolTCP:
localPort := 0
if stable {
localPort = lports.get()
}
ret := lportForTCPConn(localPort)
return &ret, nil
} }
return nil, errors.New("unknown protocol") return nil, errors.New("unknown protocol")
} }
@ -279,19 +476,20 @@ func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr ne
if !protocolSupportsStableConn(protocol) { if !protocolSupportsStableConn(protocol) {
return [2]io.ReadWriteCloser{}, nil return [2]io.ReadWriteCloser{}, nil
} }
conns, ok := stableConns[stableConnKey{addr, protocol, dstPort}] key := stableConnKey{addr, protocol, dstPort}
conns, ok := stableConns[key]
if ok { if ok {
return conns, nil return conns, nil
} }
if protocolSupportsKernelTS(protocol) { if protocolSupportsKernelTS(protocol) {
kconn, err := newConn(timestampSourceKernel, protocol) kconn, err := newConn(timestampSourceKernel, protocol, stableConn)
if err != nil { if err != nil {
return conns, err return conns, err
} }
conns[timestampSourceKernel] = kconn conns[timestampSourceKernel] = kconn
} }
uconn, err := newConn(timestampSourceUserspace, protocol) uconn, err := newConn(timestampSourceUserspace, protocol, stableConn)
if err != nil { if err != nil {
if protocolSupportsKernelTS(protocol) { if protocolSupportsKernelTS(protocol) {
conns[timestampSourceKernel].Close() conns[timestampSourceKernel].Close()
@ -299,6 +497,7 @@ func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr ne
return conns, err return conns, err
} }
conns[timestampSourceUserspace] = uconn conns[timestampSourceUserspace] = uconn
stableConns[key] = conns
return conns, nil return conns, nil
} }
@ -338,7 +537,7 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo
} }
if conn == nil { if conn == nil {
var err error var err error
conn, err = newConn(source, protocol) conn, err = newConn(source, protocol, unstableConn)
if err != nil { if err != nil {
select { select {
case <-doneCh: case <-doneCh:
@ -361,7 +560,9 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo
case protocolICMP: case protocolICMP:
// TODO(jwhited): implement // TODO(jwhited): implement
case protocolHTTPS: case protocolHTTPS:
// TODO(jwhited): implement fn = measureHTTPSRTT
case protocolTCP:
fn = measureTCPRTT
} }
rtt, err := probe(meta, conn, fn, dstPort) rtt, err := probe(meta, conn, fn, dstPort)
if err != nil { if err != nil {
@ -725,6 +926,13 @@ func main() {
if len(httpsPorts) > 0 { if len(httpsPorts) > 0 {
portsByProtocol[protocolHTTPS] = httpsPorts portsByProtocol[protocolHTTPS] = httpsPorts
} }
tcpPorts, err := getPortsFromFlag(*flagTCPDstPorts)
if err != nil {
log.Fatalf("invalid tcp-dst-ports flag value: %v", err)
}
if len(tcpPorts) > 0 {
portsByProtocol[protocolTCP] = tcpPorts
}
if *flagICMP { if *flagICMP {
portsByProtocol[protocolICMP] = []int{0} portsByProtocol[protocolICMP] = []int{0}
} }
@ -734,8 +942,8 @@ func main() {
// TODO(jwhited): remove protocol restriction // TODO(jwhited): remove protocol restriction
for k := range portsByProtocol { for k := range portsByProtocol {
if k != protocolSTUN { if k != protocolSTUN && k != protocolHTTPS && k != protocolTCP {
log.Fatal("HTTPS & ICMP are not yet supported") log.Fatal("ICMP is not yet supported")
} }
} }
@ -883,7 +1091,7 @@ func main() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
updatedDM, err := getDERPMap(ctx, *flagDERPMap) updatedDM, err := getDERPMap(ctx, *flagDERPMap)
if err != nil { if err == nil {
dmCh <- updatedDM dmCh <- updatedDM
} }
}() }()

@ -16,10 +16,14 @@ func getUDPConnKernelTimestamp() (io.ReadWriteCloser, error) {
return nil, errors.New("unimplemented") return nil, errors.New("unimplemented")
} }
func measureSTUNRTTKernel(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) {
return 0, errors.New("unimplemented") return 0, errors.New("unimplemented")
} }
func protocolSupportsKernelTS(_ protocol) bool { func protocolSupportsKernelTS(_ protocol) bool {
return false return false
} }
func setSOReuseAddr(fd uintptr) error {
return nil
}

@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
"syscall"
"time" "time"
"github.com/mdlayher/socket" "github.com/mdlayher/socket"
@ -56,7 +57,7 @@ func parseTimestampFromCmsgs(oob []byte) (time.Time, error) {
return time.Time{}, errors.New("failed to parse timestamp from cmsgs") return time.Time{}, errors.New("failed to parse timestamp from cmsgs")
} }
func measureSTUNRTTKernel(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) {
sconn, ok := conn.(*socket.Conn) sconn, ok := conn.(*socket.Conn)
if !ok { if !ok {
return 0, fmt.Errorf("conn of unexpected type: %T", conn) return 0, fmt.Errorf("conn of unexpected type: %T", conn)
@ -144,3 +145,8 @@ func protocolSupportsKernelTS(p protocol) bool {
// TODO: jwhited support ICMP // TODO: jwhited support ICMP
return false return false
} }
func setSOReuseAddr(fd uintptr) error {
// we may restart faster than TIME_WAIT can clear
return syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
}

Loading…
Cancel
Save