cmd/stunstamp: refactor connection construction (#13110)

getConns() is now responsible for returning both stable and unstable
conns. conn and measureFn are now passed together via connAndMeasureFn.
newConnAndMeasureFn() is responsible for constructing them.

TCP measurement timeouts are adjusted to more closely match netcheck.

Updates tailscale/corp#22114

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/13111/head
Jordan Whited 1 month ago committed by GitHub
parent 218110963d
commit 7aec8d4e6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -23,6 +23,7 @@ import (
"net/url" "net/url"
"os" "os"
"os/signal" "os/signal"
"runtime"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@ -190,11 +191,10 @@ func addrInUse(err error, lport *lportForTCPConn) bool {
return false return false
} }
func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { func tcpDial(ctx context.Context, lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) {
for { for {
var opErr error var opErr error
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: time.Second * 2,
LocalAddr: &net.TCPAddr{ LocalAddr: &net.TCPAddr{
Port: int(*lport), Port: int(*lport),
}, },
@ -208,7 +208,7 @@ func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) {
if opErr != nil { if opErr != nil {
panic(opErr) panic(opErr)
} }
tcpConn, err := dialer.Dial("tcp", dst.String()) tcpConn, err := dialer.DialContext(ctx, "tcp", dst.String())
if err != nil { if err != nil {
if addrInUse(err, lport) { if addrInUse(err, lport) {
continue continue
@ -232,11 +232,23 @@ func measureTCPRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt t
if !ok { if !ok {
return 0, fmt.Errorf("unexpected conn type: %T", conn) return 0, fmt.Errorf("unexpected conn type: %T", conn)
} }
tcpConn, err := tcpDial(lport, dst) // Set a dial timeout < 1s (TCP_TIMEOUT_INIT on Linux) as a means to avoid
// SYN retries, which can contribute to tcpi->rtt below. This simply limits
// retries from the initiator, but SYN+ACK on the reverse path can also
// time out and be retransmitted.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*750)
defer cancel()
tcpConn, err := tcpDial(ctx, lport, dst)
if err != nil { if err != nil {
return 0, tempError{err} return 0, tempError{err}
} }
defer tcpConn.Close() defer tcpConn.Close()
// This is an unreliable method to measure TCP RTT. The Linux kernel
// describes it as such in tcp_rtt_estimator(). We take some care in how we
// hold tcp_info->rtt here, e.g. clamping dial timeout, but if we are to
// actually use this elsewhere as an input to some decision it warrants a
// deeper study and consideration for alternative methods. Its usefulness
// here is as a point of comparison against the other methods.
rtt, err = tcpinfo.RTT(tcpConn) rtt, err = tcpinfo.RTT(tcpConn)
if err != nil { if err != nil {
return 0, tempError{err} return 0, tempError{err}
@ -250,15 +262,19 @@ func measureHTTPSRTT(conn io.ReadWriteCloser, hostname string, dst netip.AddrPor
return 0, fmt.Errorf("unexpected conn type: %T", conn) return 0, fmt.Errorf("unexpected conn type: %T", conn)
} }
var httpResult httpstat.Result var httpResult httpstat.Result
ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*3) // 5s mirrors net/netcheck.overallProbeTimeout used in net/netcheck.Client.measureHTTPSLatency.
reqCtx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*5)
defer cancel() defer cancel()
reqURL := "https://" + dst.String() + "/derp/latency-check" reqURL := "https://" + dst.String() + "/derp/latency-check"
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) req, err := http.NewRequestWithContext(reqCtx, "GET", reqURL, nil)
if err != nil { if err != nil {
return 0, err return 0, err
} }
client := &http.Client{} client := &http.Client{}
tcpConn, err := tcpDial(lport, dst) // 1.5s mirrors derp/derphttp.dialnodeTimeout used in derp/derphttp.DialNode().
dialCtx, dialCancel := context.WithTimeout(reqCtx, time.Millisecond*1500)
defer dialCancel()
tcpConn, err := tcpDial(dialCtx, lport, dst)
if err != nil { if err != nil {
return 0, tempError{err} return 0, tempError{err}
} }
@ -355,18 +371,17 @@ type nodeMeta struct {
type measureFn func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) type measureFn func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error)
// probe measures round trip time for the node described by meta over // probe measures round trip time for the node described by meta over cf against
// conn against dstPort using fn. It may return a nil duration and nil error in // dstPort. It may return a nil duration and nil error in the event of a
// the event of a timeout. A non-nil error indicates an unrecoverable or // timeout. A non-nil error indicates an unrecoverable or non-temporary error.
// non-temporary error. func probe(meta nodeMeta, cf *connAndMeasureFn, dstPort int) (*time.Duration, error) {
func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) {
ua := &net.UDPAddr{ ua := &net.UDPAddr{
IP: net.IP(meta.addr.AsSlice()), IP: net.IP(meta.addr.AsSlice()),
Port: dstPort, Port: dstPort,
} }
time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx
rtt, err := fn(conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort))) rtt, err := cf.fn(cf.conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort)))
if err != nil { if err != nil {
if isTemporaryOrTimeoutErr(err) { if isTemporaryOrTimeoutErr(err) {
log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err)
@ -437,31 +452,69 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node
return stale, nil return stale, nil
} }
func newConn(source timestampSource, protocol protocol, stable connStability) (io.ReadWriteCloser, error) { type connAndMeasureFn struct {
conn io.ReadWriteCloser
fn measureFn
}
// newConnAndMeasureFn returns a connAndMeasureFn or an error. It may return
// nil for both if some combination of the supplied timestampSource, protocol,
// or connStability is unsupported.
func newConnAndMeasureFn(source timestampSource, protocol protocol, stable connStability) (*connAndMeasureFn, error) {
info := getProtocolSupportInfo(protocol)
if !info.stableConn && bool(stable) {
return nil, nil
}
if !info.userspaceTS && source == timestampSourceUserspace {
return nil, nil
}
if !info.kernelTS && source == timestampSourceKernel {
return nil, nil
}
switch protocol { switch protocol {
case protocolSTUN: case protocolSTUN:
if source == timestampSourceKernel { if source == timestampSourceKernel {
return getUDPConnKernelTimestamp() conn, err := getUDPConnKernelTimestamp()
if err != nil {
return nil, err
}
return &connAndMeasureFn{
conn: conn,
fn: measureSTUNRTTKernel,
}, nil
} else { } else {
return net.ListenUDP("udp", &net.UDPAddr{}) conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
return nil, err
}
return &connAndMeasureFn{
conn: conn,
fn: measureSTUNRTT,
}, nil
} }
case protocolICMP: case protocolICMP:
// TODO(jwhited): implement // TODO(jwhited): implement
return nil, errors.New("unimplemented protocol") return nil, nil
case protocolHTTPS: case protocolHTTPS:
localPort := 0 localPort := 0
if stable { if stable {
localPort = lports.get() localPort = lports.get()
} }
ret := lportForTCPConn(localPort) conn := lportForTCPConn(localPort)
return &ret, nil return &connAndMeasureFn{
conn: &conn,
fn: measureHTTPSRTT,
}, nil
case protocolTCP: case protocolTCP:
localPort := 0 localPort := 0
if stable { if stable {
localPort = lports.get() localPort = lports.get()
} }
ret := lportForTCPConn(localPort) conn := lportForTCPConn(localPort)
return &ret, nil return &connAndMeasureFn{
conn: &conn,
fn: measureTCPRTT,
}, nil
} }
return nil, errors.New("unknown protocol") return nil, errors.New("unknown protocol")
} }
@ -472,41 +525,57 @@ type stableConnKey struct {
port int port int
} }
func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr netip.Addr, protocol protocol, dstPort int) ([2]io.ReadWriteCloser, error) { type protocolSupportInfo struct {
if !protocolSupportsStableConn(protocol) { kernelTS bool
return [2]io.ReadWriteCloser{}, nil userspaceTS bool
} stableConn bool
key := stableConnKey{addr, protocol, dstPort}
conns, ok := stableConns[key]
if ok {
return conns, nil
} }
if protocolSupportsKernelTS(protocol) { func getConns(
kconn, err := newConn(timestampSourceKernel, protocol, stableConn) stableConns map[stableConnKey][2]*connAndMeasureFn,
addr netip.Addr,
protocol protocol,
dstPort int,
) (stable, unstable [2]*connAndMeasureFn, err error) {
key := stableConnKey{addr, protocol, dstPort}
defer func() {
if err != nil { if err != nil {
return conns, err for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
c := stable[source]
if c != nil {
c.conn.Close()
} }
conns[timestampSourceKernel] = kconn c = unstable[source]
if c != nil {
c.conn.Close()
} }
uconn, err := newConn(timestampSourceUserspace, protocol, stableConn) }
}
}()
var ok bool
stable, ok = stableConns[key]
if !ok {
for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
var cf *connAndMeasureFn
cf, err = newConnAndMeasureFn(source, protocol, stableConn)
if err != nil { if err != nil {
if protocolSupportsKernelTS(protocol) { return
conns[timestampSourceKernel].Close()
} }
return conns, err stable[source] = cf
} }
conns[timestampSourceUserspace] = uconn stableConns[key] = stable
stableConns[key] = conns
return conns, nil
} }
func protocolSupportsStableConn(p protocol) bool { for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
if p == protocolICMP { var cf *connAndMeasureFn
// no value for ICMP cf, err = newConnAndMeasureFn(source, protocol, unstableConn)
return false if err != nil {
return
} }
return true unstable[source] = cf
}
return stable, unstable, nil
} }
// probeNodes measures the round-trip time for the protocols and ports described // probeNodes measures the round-trip time for the protocols and ports described
@ -514,7 +583,7 @@ func protocolSupportsStableConn(p protocol) bool {
// stableConns are used to recycle connections across calls to probeNodes. // stableConns are used to recycle connections across calls to probeNodes.
// probeNodes is also responsible for trimming stableConns based on node // probeNodes is also responsible for trimming stableConns based on node
// lifetime in nodeMetaByAddr. It returns the results or an error if one occurs. // lifetime in nodeMetaByAddr. It returns the results or an error if one occurs.
func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]io.ReadWriteCloser, portsByProtocol map[protocol][]int) ([]result, error) { func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]*connAndMeasureFn, portsByProtocol map[protocol][]int) ([]result, error) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
results := make([]result, 0) results := make([]result, 0)
resultsCh := make(chan result) resultsCh := make(chan result)
@ -524,47 +593,19 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo
at := time.Now() at := time.Now()
addrsToProbe := make(map[netip.Addr]bool) addrsToProbe := make(map[netip.Addr]bool)
doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, protocol protocol, dstPort int) { doProbe := func(cf *connAndMeasureFn, meta nodeMeta, source timestampSource, stable connStability, protocol protocol, dstPort int) {
defer wg.Done() defer wg.Done()
r := result{ r := result{
key: resultKey{ key: resultKey{
meta: meta, meta: meta,
timestampSource: source, timestampSource: source,
connStability: stable,
dstPort: dstPort, dstPort: dstPort,
protocol: protocol, protocol: protocol,
}, },
at: at, at: at,
} }
if conn == nil { rtt, err := probe(meta, cf, dstPort)
var err error
conn, err = newConn(source, protocol, unstableConn)
if err != nil {
select {
case <-doneCh:
return
case errCh <- err:
return
}
}
defer conn.Close()
} else {
r.key.connStability = stableConn
}
var fn measureFn
switch protocol {
case protocolSTUN:
fn = measureSTUNRTT
if source == timestampSourceKernel {
fn = measureSTUNRTTKernel
}
case protocolICMP:
// TODO(jwhited): implement
case protocolHTTPS:
fn = measureHTTPSRTT
case protocolTCP:
fn = measureTCPRTT
}
rtt, err := probe(meta, conn, fn, dstPort)
if err != nil { if err != nil {
select { select {
case <-doneCh: case <-doneCh:
@ -584,44 +625,39 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo
addrsToProbe[meta.addr] = true addrsToProbe[meta.addr] = true
for p, ports := range portsByProtocol { for p, ports := range portsByProtocol {
for _, port := range ports { for _, port := range ports {
stable, err := getStableConns(stableConns, meta.addr, p, port) stable, unstable, err := getConns(stableConns, meta.addr, p, port)
if err != nil { if err != nil {
close(doneCh) close(doneCh)
wg.Wait() wg.Wait()
return nil, err return nil, err
} }
if protocolSupportsStableConn(p) { for i, cf := range stable {
if cf != nil {
wg.Add(1) wg.Add(1)
numProbes++ numProbes++
go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, p, port) go doProbe(cf, meta, timestampSource(i), stableConn, p, port)
} }
wg.Add(1)
numProbes++
go doProbe(nil, meta, timestampSourceUserspace, p, port)
if protocolSupportsKernelTS(p) {
if protocolSupportsStableConn(p) {
wg.Add(1)
numProbes++
go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, p, port)
} }
for i, cf := range unstable {
if cf != nil {
wg.Add(1) wg.Add(1)
numProbes++ numProbes++
go doProbe(nil, meta, timestampSourceKernel, p, port) go doProbe(cf, meta, timestampSource(i), unstableConn, p, port)
}
} }
} }
} }
} }
// cleanup conns we no longer need // cleanup conns we no longer need
for k, conns := range stableConns { for k, cf := range stableConns {
if !addrsToProbe[k.node] { if !addrsToProbe[k.node] {
if conns[timestampSourceKernel] != nil { if cf[timestampSourceKernel] != nil {
conns[timestampSourceKernel].Close() cf[timestampSourceKernel].conn.Close()
} }
conns[timestampSourceUserspace].Close() cf[timestampSourceUserspace].conn.Close()
delete(stableConns, k) delete(stableConns, k)
} }
} }
@ -728,42 +764,16 @@ func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, portsByProtocol
Value: math.Float64frombits(staleNaN), Value: math.Float64frombits(staleNaN),
}, },
} }
// We send stale markers for all combinations in the interest
// of simplicity.
for _, name := range []string{rttMetricName, timeoutsMetricName} {
for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
for _, stable := range []connStability{unstableConn, stableConn} {
staleMarkers = append(staleMarkers, prompb.TimeSeries{ staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port), Labels: timeSeriesLabels(name, s, instance, source, stable, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port),
Samples: samples,
})
if protocolSupportsStableConn(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, p, port),
Samples: samples, Samples: samples,
}) })
} }
if protocolSupportsKernelTS(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, p, port),
Samples: samples,
})
if protocolSupportsStableConn(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, p, port),
Samples: samples,
})
} }
} }
} }
@ -909,6 +919,9 @@ func getPortsFromFlag(f string) ([]int, error) {
} }
func main() { func main() {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
log.Fatal("unsupported platform")
}
flag.Parse() flag.Parse()
portsByProtocol := make(map[protocol][]int) portsByProtocol := make(map[protocol][]int)
@ -1035,7 +1048,7 @@ func main() {
// Comparison of stable and unstable 5-tuple results can shed light on // Comparison of stable and unstable 5-tuple results can shed light on
// differences between paths where hashing (multipathing/load balancing) // differences between paths where hashing (multipathing/load balancing)
// comes into play. The inner 2 element array index is timestampSource. // comes into play. The inner 2 element array index is timestampSource.
stableConns := make(map[stableConnKey][2]io.ReadWriteCloser) stableConns := make(map[stableConnKey][2]*connAndMeasureFn)
// timeouts holds counts of timeout events. Values are persisted for the // timeouts holds counts of timeout events. Values are persisted for the
// lifetime of the related node in the DERP map. // lifetime of the related node in the DERP map.

@ -20,8 +20,28 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad
return 0, errors.New("unimplemented") return 0, errors.New("unimplemented")
} }
func protocolSupportsKernelTS(_ protocol) bool { func getProtocolSupportInfo(p protocol) protocolSupportInfo {
return false switch p {
case protocolSTUN:
return protocolSupportInfo{
kernelTS: false,
userspaceTS: true,
stableConn: true,
}
case protocolHTTPS:
return protocolSupportInfo{
kernelTS: false,
userspaceTS: true,
stableConn: true,
}
case protocolTCP:
return protocolSupportInfo{
kernelTS: true,
userspaceTS: false,
stableConn: true,
}
}
return protocolSupportInfo{}
} }
func setSOReuseAddr(fd uintptr) error { func setSOReuseAddr(fd uintptr) error {

@ -138,12 +138,29 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad
} }
func protocolSupportsKernelTS(p protocol) bool { func getProtocolSupportInfo(p protocol) protocolSupportInfo {
if p == protocolSTUN { switch p {
return true case protocolSTUN:
} return protocolSupportInfo{
// TODO: jwhited support ICMP kernelTS: true,
return false userspaceTS: true,
stableConn: true,
}
case protocolHTTPS:
return protocolSupportInfo{
kernelTS: false,
userspaceTS: true,
stableConn: true,
}
case protocolTCP:
return protocolSupportInfo{
kernelTS: true,
userspaceTS: false,
stableConn: true,
}
// TODO(jwhited): add ICMP
}
return protocolSupportInfo{}
} }
func setSOReuseAddr(fd uintptr) error { func setSOReuseAddr(fd uintptr) error {

Loading…
Cancel
Save