From 6e106712f654fe2086c9d66c9cbbb3e999c75489 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 6 Jun 2024 09:05:17 -0700 Subject: [PATCH] cmd/stunstamp: support probing multiple ports (#12356) Updates tailscale/corp#20344 Signed-off-by: Jordan Whited --- cmd/stunstamp/api.go | 5 +- cmd/stunstamp/stunstamp.go | 163 +++++++++++++++++++------------ cmd/stunstamp/stunstamp_linux.go | 4 +- 3 files changed, 107 insertions(+), 65 deletions(-) diff --git a/cmd/stunstamp/api.go b/cmd/stunstamp/api.go index 849d5a32c..8effda6e4 100644 --- a/cmd/stunstamp/api.go +++ b/cmd/stunstamp/api.go @@ -39,6 +39,7 @@ type apiResult struct { Addr string `json:"addr"` Source int `json:"source"` // timestampSourceUserspace (0) or timestampSourceKernel (1) StableConn bool `json:"stableConn"` + DstPort int `json:"dstPort"` RttNS *int `json:"rttNS"` } @@ -94,7 +95,7 @@ func (a *api) query(w http.ResponseWriter, r *http.Request) { return } - sb := sq.Select("at_unix", "region_id", "hostname", "af", "address", "timestamp_source", "stable_conn", "rtt_ns").From("rtt") + sb := sq.Select("at_unix", "region_id", "hostname", "af", "address", "timestamp_source", "stable_conn", "dst_port", "rtt_ns").From("rtt") sb = sb.Where(sq.And{ sq.GtOrEq{"at_unix": from.Unix()}, sq.LtOrEq{"at_unix": to.Unix()}, @@ -115,7 +116,7 @@ func (a *api) query(w http.ResponseWriter, r *http.Request) { result := apiResult{ RttNS: &rtt, } - err = rows.Scan(&result.At, &result.RegionID, &result.Hostname, &result.Af, &result.Addr, &result.Source, &result.StableConn, &result.RttNS) + err = rows.Scan(&result.At, &result.RegionID, &result.Hostname, &result.Af, &result.Addr, &result.Source, &result.StableConn, &result.DstPort, &result.RttNS) if err != nil { http.Error(w, err.Error(), 500) return diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index 0336ace32..605901775 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -23,6 +23,8 @@ import ( "os" "os/signal" "slices" + "strconv" + "strings" "sync" "syscall" "time" @@ -43,6 +45,7 @@ var ( flagRetention = flag.Duration("retention", time.Hour*24*7, "sqlite retention period in time.ParseDuration() format") flagRemoteWriteURL = flag.String("rw-url", "", "prometheus remote write URL") flagInstance = flag.String("instance", "", "instance label value; defaults to hostname if unspecified") + flagDstPorts = flag.String("dst-ports", "", "comma-separated list of destination ports to monitor") ) const ( @@ -91,6 +94,7 @@ type result struct { meta nodeMeta timestampSource timestampSource connStability connStability + dstPort int rtt *time.Duration // nil signifies failure, e.g. timeout } @@ -145,17 +149,17 @@ type nodeMeta struct { type measureFn func(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) -func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn) (*time.Duration, error) { +func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) { ua := &net.UDPAddr{ IP: net.IP(meta.addr.AsSlice()), - Port: 3478, + Port: dstPort, } time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx rtt, err := fn(conn, ua) if err != nil { if isTemporaryOrTimeoutErr(err) { - log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, meta.addr, err) + log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) return nil, nil } } @@ -218,10 +222,14 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node return stale, nil } -func getStableConns(stableConns map[netip.Addr][2]io.ReadWriteCloser, addr netip.Addr) ([2]io.ReadWriteCloser, error) { - conns, ok := stableConns[addr] +func getStableConns(stableConns map[netip.Addr]map[int][2]io.ReadWriteCloser, addr netip.Addr, dstPort int) ([2]io.ReadWriteCloser, error) { + conns := [2]io.ReadWriteCloser{} + byDstPort, ok := stableConns[addr] if ok { - return conns, nil + conns, ok = byDstPort[dstPort] + if ok { + return conns, nil + } } if supportsKernelTS() { kconn, err := getConnKernelTimestamp() @@ -232,10 +240,17 @@ func getStableConns(stableConns map[netip.Addr][2]io.ReadWriteCloser, addr netip } uconn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { + if supportsKernelTS() { + conns[timestampSourceKernel].Close() + } return conns, err } conns[timestampSourceUserspace] = uconn - stableConns[addr] = conns + if byDstPort == nil { + byDstPort = make(map[int][2]io.ReadWriteCloser) + } + byDstPort[dstPort] = conns + stableConns[addr] = byDstPort return conns, nil } @@ -243,7 +258,7 @@ func getStableConns(stableConns map[netip.Addr][2]io.ReadWriteCloser, addr netip // DERP nodes described by nodeMetaByAddr while using/updating stableConns for // UDP sockets that should be recycled across runs. It returns the results or // an error if one occurs. -func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Addr][2]io.ReadWriteCloser) ([]result, error) { +func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Addr]map[int][2]io.ReadWriteCloser, dstPorts []int) ([]result, error) { wg := sync.WaitGroup{} results := make([]result, 0) resultsCh := make(chan result) @@ -253,9 +268,14 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad at := time.Now() addrsToProbe := make(map[netip.Addr]bool) - doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource) { + doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, dstPort int) { defer wg.Done() - r := result{} + r := result{ + at: at, + meta: meta, + timestampSource: source, + dstPort: dstPort, + } if conn == nil { var err error if source == timestampSourceKernel { @@ -279,7 +299,7 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad if source == timestampSourceKernel { fn = measureRTTKernel } - rtt, err := probe(meta, conn, fn) + rtt, err := probe(meta, conn, fn, dstPort) if err != nil { select { case <-doneCh: @@ -288,9 +308,6 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad return } } - r.at = at - r.meta = meta - r.timestampSource = source r.rtt = rtt select { case <-doneCh: @@ -300,33 +317,37 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad for _, meta := range nodeMetaByAddr { addrsToProbe[meta.addr] = true - stable, err := getStableConns(stableConns, meta.addr) - if err != nil { - close(doneCh) - wg.Wait() - return nil, err - } + for _, port := range dstPorts { + stable, err := getStableConns(stableConns, meta.addr, port) + if err != nil { + close(doneCh) + wg.Wait() + return nil, err + } - wg.Add(2) - numProbes += 2 - go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace) - go doProbe(nil, meta, timestampSourceUserspace) - if supportsKernelTS() { wg.Add(2) numProbes += 2 - go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel) - go doProbe(nil, meta, timestampSourceKernel) + go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, port) + go doProbe(nil, meta, timestampSourceUserspace, port) + if supportsKernelTS() { + wg.Add(2) + numProbes += 2 + go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, port) + go doProbe(nil, meta, timestampSourceKernel, port) + } } } // cleanup conns we no longer need - for k, conns := range stableConns { + for k, byDstPort := range stableConns { if !addrsToProbe[k] { - if conns[timestampSourceKernel] != nil { - conns[timestampSourceKernel].Close() + for _, conns := range byDstPort { + if conns[timestampSourceKernel] != nil { + conns[timestampSourceKernel].Close() + } + conns[timestampSourceUserspace].Close() + delete(stableConns, k) } - conns[timestampSourceUserspace].Close() - delete(stableConns, k) } } @@ -352,7 +373,7 @@ const ( stableConn connStability = true ) -func timeSeriesLabels(meta nodeMeta, instance string, source timestampSource, stability connStability) []prompb.Label { +func timeSeriesLabels(meta nodeMeta, instance string, source timestampSource, stability connStability, dstPort int) []prompb.Label { addressFamily := "ipv4" if meta.addr.Is6() { addressFamily = "ipv6" @@ -382,6 +403,10 @@ func timeSeriesLabels(meta nodeMeta, instance string, source timestampSource, st Name: "hostname", Value: meta.hostname, }) + labels = append(labels, prompb.Label{ + Name: "dst_port", + Value: strconv.Itoa(dstPort), + }) labels = append(labels, prompb.Label{ Name: "__name__", Value: "stunstamp_derp_stun_rtt_ns", @@ -406,40 +431,42 @@ const ( staleNaN uint64 = 0x7ff0000000000002 ) -func staleMarkersFromNodeMeta(stale []nodeMeta, instance string) []prompb.TimeSeries { +func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, dstPorts []int) []prompb.TimeSeries { staleMarkers := make([]prompb.TimeSeries, 0) now := time.Now() for _, s := range stale { - samples := []prompb.Sample{ - { - Timestamp: now.UnixMilli(), - Value: math.Float64frombits(staleNaN), - }, - } - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(s, instance, timestampSourceUserspace, unstableConn), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(s, instance, timestampSourceUserspace, stableConn), - Samples: samples, - }) - if supportsKernelTS() { + for _, dstPort := range dstPorts { + samples := []prompb.Sample{ + { + Timestamp: now.UnixMilli(), + Value: math.Float64frombits(staleNaN), + }, + } staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(s, instance, timestampSourceKernel, unstableConn), + Labels: timeSeriesLabels(s, instance, timestampSourceUserspace, unstableConn, dstPort), Samples: samples, }) staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(s, instance, timestampSourceKernel, stableConn), + Labels: timeSeriesLabels(s, instance, timestampSourceUserspace, stableConn, dstPort), Samples: samples, }) + if supportsKernelTS() { + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(s, instance, timestampSourceKernel, unstableConn, dstPort), + Samples: samples, + }) + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(s, instance, timestampSourceKernel, stableConn, dstPort), + Samples: samples, + }) + } } } return staleMarkers } func resultToPromTimeSeries(r result, instance string) prompb.TimeSeries { - labels := timeSeriesLabels(r.meta, instance, r.timestampSource, r.connStability) + labels := timeSeriesLabels(r.meta, instance, r.timestampSource, r.connStability, r.dstPort) samples := make([]prompb.Sample, 1) samples[0].Timestamp = r.at.UnixMilli() if r.rtt != nil { @@ -535,6 +562,20 @@ func remoteWriteTimeSeries(client *remoteWriteClient, tsCh chan []prompb.TimeSer func main() { flag.Parse() + if len(*flagDstPorts) == 0 { + log.Fatal("dst-ports flag is unset") + } + dstPortsSplit := strings.Split(*flagDstPorts, ",") + slices.Sort(dstPortsSplit) + dstPortsSplit = slices.Compact(dstPortsSplit) + dstPorts := make([]int, 0, len(dstPortsSplit)) + for _, d := range dstPortsSplit { + i, err := strconv.ParseUint(d, 10, 16) + if err != nil { + log.Fatal("invalid dst-ports") + } + dstPorts = append(dstPorts, int(i)) + } if len(*flagDERPMap) < 1 { log.Fatal("derp-map flag is unset") } @@ -545,10 +586,10 @@ func main() { log.Fatalf("interval must be >= %s and <= %s", minInterval, maxBufferDuration) } if *flagRetention < *flagInterval { - log.Fatalf("retention must be >= interval") + log.Fatal("retention must be >= interval") } if len(*flagRemoteWriteURL) < 1 { - log.Fatalf("rw-url flag is unset") + log.Fatal("rw-url flag is unset") } _, err := url.Parse(*flagRemoteWriteURL) if err != nil { @@ -610,7 +651,7 @@ func main() { // ~300 data points per-interval w/o ipv6 w/kernel timestamping resulting // in ~2.6m rows in 24h w/a 10s probe interval. _, err = db.Exec(` -CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT, address TEXT, timestamp_source INT, stable_conn INT, rtt_ns INT) +CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT, address TEXT, timestamp_source INT, stable_conn INT, dst_port INT, rtt_ns INT) `) if err != nil { log.Fatalf("error initializing db: %v", err) @@ -658,7 +699,7 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT for _, v := range nodeMetaByAddr { staleMeta = append(staleMeta, v) } - staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance) + staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, dstPorts) if len(staleMarkers) > 0 { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) rwc.write(ctx, staleMarkers) @@ -676,7 +717,7 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT // Comparison of stable and unstable 5-tuple results can shed light on // differences between paths where hashing (multipathing/load balancing) // comes into play. - stableConns := make(map[netip.Addr][2]io.ReadWriteCloser) + stableConns := make(map[netip.Addr]map[int][2]io.ReadWriteCloser) derpMapTicker := time.NewTicker(time.Minute * 5) defer derpMapTicker.Stop() @@ -697,7 +738,7 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT return } case <-probeTicker.C: - results, err := probeNodes(nodeMetaByAddr, stableConns) + results, err := probeNodes(nodeMetaByAddr, stableConns, dstPorts) if err != nil { log.Printf("unrecoverable error while probing: %v", err) shutdown() @@ -728,8 +769,8 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT if result.meta.addr.Is6() { af = 6 } - _, err = tx.Exec("INSERT INTO rtt(at_unix, region_id, hostname, af, address, timestamp_source, stable_conn, rtt_ns) VALUES(?, ?, ?, ?, ?, ?, ?, ?)", - result.at.Unix(), result.meta.regionID, result.meta.hostname, af, result.meta.addr.String(), result.timestampSource, result.connStability, result.rtt) + _, err = tx.Exec("INSERT INTO rtt(at_unix, region_id, hostname, af, address, timestamp_source, stable_conn, dst_port, rtt_ns) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)", + result.at.Unix(), result.meta.regionID, result.meta.hostname, af, result.meta.addr.String(), result.timestampSource, result.connStability, result.dstPort, result.rtt) if err != nil { tx.Rollback() log.Printf("error adding result to tx: %v", err) @@ -749,7 +790,7 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT log.Printf("error parsing DERP map, continuing with stale map: %v", err) continue } - staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance) + staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, dstPorts) if len(staleMarkers) < 1 { continue } diff --git a/cmd/stunstamp/stunstamp_linux.go b/cmd/stunstamp/stunstamp_linux.go index f21b0d2ef..898ab19f1 100644 --- a/cmd/stunstamp/stunstamp_linux.go +++ b/cmd/stunstamp/stunstamp_linux.go @@ -66,12 +66,12 @@ func measureRTTKernel(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Durat to4 := dst.IP.To4() if to4 != nil { to = &unix.SockaddrInet4{ - Port: 3478, + Port: dst.Port, } copy(to.(*unix.SockaddrInet4).Addr[:], to4) } else { to = &unix.SockaddrInet6{ - Port: 3478, + Port: dst.Port, } copy(to.(*unix.SockaddrInet6).Addr[:], dst.IP) }