cmd/stunstamp: refactor to support multiple protocols (#13063)

'stun' has been removed from metric names and replaced with a protocol
label. This refactor is preparation work for HTTPS & ICMP support.

Updates tailscale/corp#22114

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/13082/head
Jordan Whited 3 months ago committed by GitHub
parent f23932bd98
commit 20691894f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -42,7 +42,9 @@ var (
flagIPv6 = flag.Bool("ipv6", false, "probe IPv6 addresses") flagIPv6 = flag.Bool("ipv6", false, "probe IPv6 addresses")
flagRemoteWriteURL = flag.String("rw-url", "", "prometheus remote write URL") flagRemoteWriteURL = flag.String("rw-url", "", "prometheus remote write URL")
flagInstance = flag.String("instance", "", "instance label value; defaults to hostname if unspecified") flagInstance = flag.String("instance", "", "instance label value; defaults to hostname if unspecified")
flagDstPorts = flag.String("dst-ports", "", "comma-separated list of destination ports to monitor") flagSTUNDstPorts = flag.String("stun-dst-ports", "", "comma-separated list of STUN destination ports to monitor")
flagHTTPSDstPorts = flag.String("https-dst-ports", "", "comma-separated list of HTTPS destination ports to monitor")
flagICMP = flag.Bool("icmp", false, "probe ICMP")
) )
const ( const (
@ -89,12 +91,21 @@ func (t timestampSource) String() string {
} }
} }
type protocol string
const (
protocolSTUN protocol = "stun"
protocolICMP protocol = "icmp"
protocolHTTPS protocol = "https"
)
// resultKey contains the stable dimensions and their values for a given // resultKey contains the stable dimensions and their values for a given
// timeseries, i.e. not time and not rtt/timeout. // timeseries, i.e. not time and not rtt/timeout.
type resultKey struct { type resultKey struct {
meta nodeMeta meta nodeMeta
timestampSource timestampSource timestampSource timestampSource
connStability connStability connStability connStability
protocol protocol
dstPort int dstPort int
} }
@ -104,7 +115,7 @@ type result struct {
rtt *time.Duration // nil signifies failure, e.g. timeout rtt *time.Duration // nil signifies failure, e.g. timeout
} }
func measureRTT(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) { func measureSTUNRTT(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) {
uconn, ok := conn.(*net.UDPConn) uconn, ok := conn.(*net.UDPConn)
if !ok { if !ok {
return 0, fmt.Errorf("unexpected conn type: %T", conn) return 0, fmt.Errorf("unexpected conn type: %T", conn)
@ -116,7 +127,10 @@ func measureRTT(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, e
txID := stun.NewTxID() txID := stun.NewTxID()
req := stun.Request(txID) req := stun.Request(txID)
txAt := time.Now() txAt := time.Now()
_, err = uconn.WriteToUDP(req, dst) _, err = uconn.WriteToUDP(req, &net.UDPAddr{
IP: dst.Addr().AsSlice(),
Port: int(dst.Port()),
})
if err != nil { if err != nil {
return 0, fmt.Errorf("error writing to udp socket: %w", err) return 0, fmt.Errorf("error writing to udp socket: %w", err)
} }
@ -153,11 +167,11 @@ type nodeMeta struct {
addr netip.Addr addr netip.Addr
} }
type measureFn func(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) type measureFn func(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error)
// probe measures STUN round trip time for the node described by meta over // probe measures round trip time for the node described by meta over
// conn against dstPort. It may return a nil duration and nil error if the // conn against dstPort using fn. It may return a nil duration and nil error in
// STUN request timed out. A non-nil error indicates an unrecoverable or // the event of a timeout. A non-nil error indicates an unrecoverable or
// non-temporary error. // non-temporary error.
func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) { func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) {
ua := &net.UDPAddr{ ua := &net.UDPAddr{
@ -166,7 +180,7 @@ func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*
} }
time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx
rtt, err := fn(conn, ua) rtt, err := fn(conn, netip.AddrPortFrom(meta.addr, uint16(dstPort)))
if err != nil { if err != nil {
if isTemporaryOrTimeoutErr(err) { if isTemporaryOrTimeoutErr(err) {
log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err)
@ -237,43 +251,71 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node
return stale, nil return stale, nil
} }
func getStableConns(stableConns map[netip.Addr]map[int][2]io.ReadWriteCloser, addr netip.Addr, dstPort int) ([2]io.ReadWriteCloser, error) { func newConn(source timestampSource, protocol protocol) (io.ReadWriteCloser, error) {
conns := [2]io.ReadWriteCloser{} switch protocol {
byDstPort, ok := stableConns[addr] case protocolSTUN:
if ok { if source == timestampSourceKernel {
conns, ok = byDstPort[dstPort] return getUDPConnKernelTimestamp()
if ok { } else {
return conns, nil return net.ListenUDP("udp", &net.UDPAddr{})
} }
case protocolICMP:
// TODO(jwhited): implement
return nil, errors.New("unimplemented protocol")
case protocolHTTPS:
// TODO(jwhited): implement
return nil, errors.New("unimplemented protocol")
}
return nil, errors.New("unknown protocol")
}
type stableConnKey struct {
node netip.Addr
protocol protocol
port int
}
func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr netip.Addr, protocol protocol, dstPort int) ([2]io.ReadWriteCloser, error) {
if !protocolSupportsStableConn(protocol) {
return [2]io.ReadWriteCloser{}, nil
} }
if supportsKernelTS() { conns, ok := stableConns[stableConnKey{addr, protocol, dstPort}]
kconn, err := getConnKernelTimestamp() if ok {
return conns, nil
}
if protocolSupportsKernelTS(protocol) {
kconn, err := newConn(timestampSourceKernel, protocol)
if err != nil { if err != nil {
return conns, err return conns, err
} }
conns[timestampSourceKernel] = kconn conns[timestampSourceKernel] = kconn
} }
uconn, err := net.ListenUDP("udp", &net.UDPAddr{}) uconn, err := newConn(timestampSourceUserspace, protocol)
if err != nil { if err != nil {
if supportsKernelTS() { if protocolSupportsKernelTS(protocol) {
conns[timestampSourceKernel].Close() conns[timestampSourceKernel].Close()
} }
return conns, err return conns, err
} }
conns[timestampSourceUserspace] = uconn conns[timestampSourceUserspace] = uconn
if byDstPort == nil {
byDstPort = make(map[int][2]io.ReadWriteCloser)
}
byDstPort[dstPort] = conns
stableConns[addr] = byDstPort
return conns, nil return conns, nil
} }
// probeNodes measures the round-trip time for STUN binding requests against the func protocolSupportsStableConn(p protocol) bool {
// DERP nodes described by nodeMetaByAddr while using/updating stableConns for if p == protocolICMP {
// UDP sockets that should be recycled across runs. It returns the results or // no value for ICMP
// an error if one occurs. return false
func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Addr]map[int][2]io.ReadWriteCloser, dstPorts []int) ([]result, error) { }
return true
}
// probeNodes measures the round-trip time for the protocols and ports described
// by portsByProtocol against the DERP nodes described by nodeMetaByAddr.
// stableConns are used to recycle connections across calls to probeNodes.
// probeNodes is also responsible for trimming stableConns based on node
// lifetime in nodeMetaByAddr. It returns the results or an error if one occurs.
func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]io.ReadWriteCloser, portsByProtocol map[protocol][]int) ([]result, error) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
results := make([]result, 0) results := make([]result, 0)
resultsCh := make(chan result) resultsCh := make(chan result)
@ -283,23 +325,20 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad
at := time.Now() at := time.Now()
addrsToProbe := make(map[netip.Addr]bool) addrsToProbe := make(map[netip.Addr]bool)
doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, dstPort int) { doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, protocol protocol, dstPort int) {
defer wg.Done() defer wg.Done()
r := result{ r := result{
key: resultKey{ key: resultKey{
meta: meta, meta: meta,
timestampSource: source, timestampSource: source,
dstPort: dstPort, dstPort: dstPort,
protocol: protocol,
}, },
at: at, at: at,
} }
if conn == nil { if conn == nil {
var err error var err error
if source == timestampSourceKernel { conn, err = newConn(source, protocol)
conn, err = getConnKernelTimestamp()
} else {
conn, err = net.ListenUDP("udp", &net.UDPAddr{})
}
if err != nil { if err != nil {
select { select {
case <-doneCh: case <-doneCh:
@ -312,9 +351,17 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad
} else { } else {
r.key.connStability = stableConn r.key.connStability = stableConn
} }
fn := measureRTT var fn measureFn
if source == timestampSourceKernel { switch protocol {
fn = measureRTTKernel case protocolSTUN:
fn = measureSTUNRTT
if source == timestampSourceKernel {
fn = measureSTUNRTTKernel
}
case protocolICMP:
// TODO(jwhited): implement
case protocolHTTPS:
// TODO(jwhited): implement
} }
rtt, err := probe(meta, conn, fn, dstPort) rtt, err := probe(meta, conn, fn, dstPort)
if err != nil { if err != nil {
@ -334,37 +381,47 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad
for _, meta := range nodeMetaByAddr { for _, meta := range nodeMetaByAddr {
addrsToProbe[meta.addr] = true addrsToProbe[meta.addr] = true
for _, port := range dstPorts { for p, ports := range portsByProtocol {
stable, err := getStableConns(stableConns, meta.addr, port) for _, port := range ports {
if err != nil { stable, err := getStableConns(stableConns, meta.addr, p, port)
close(doneCh) if err != nil {
wg.Wait() close(doneCh)
return nil, err wg.Wait()
} return nil, err
}
wg.Add(2) if protocolSupportsStableConn(p) {
numProbes += 2 wg.Add(1)
go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, port) numProbes++
go doProbe(nil, meta, timestampSourceUserspace, port) go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, p, port)
if supportsKernelTS() { }
wg.Add(2) wg.Add(1)
numProbes += 2 numProbes++
go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, port) go doProbe(nil, meta, timestampSourceUserspace, p, port)
go doProbe(nil, meta, timestampSourceKernel, port)
if protocolSupportsKernelTS(p) {
if protocolSupportsStableConn(p) {
wg.Add(1)
numProbes++
go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, p, port)
}
wg.Add(1)
numProbes++
go doProbe(nil, meta, timestampSourceKernel, p, port)
}
} }
} }
} }
// cleanup conns we no longer need // cleanup conns we no longer need
for k, byDstPort := range stableConns { for k, conns := range stableConns {
if !addrsToProbe[k] { if !addrsToProbe[k.node] {
for _, conns := range byDstPort { if conns[timestampSourceKernel] != nil {
if conns[timestampSourceKernel] != nil { conns[timestampSourceKernel].Close()
conns[timestampSourceKernel].Close()
}
conns[timestampSourceUserspace].Close()
delete(stableConns, k)
} }
conns[timestampSourceUserspace].Close()
delete(stableConns, k)
} }
} }
@ -391,11 +448,11 @@ const (
) )
const ( const (
rttMetricName = "stunstamp_derp_stun_rtt_ns" rttMetricName = "stunstamp_derp_rtt_ns"
timeoutsMetricName = "stunstamp_derp_stun_timeouts_total" timeoutsMetricName = "stunstamp_derp_timeouts_total"
) )
func timeSeriesLabels(metricName string, meta nodeMeta, instance string, source timestampSource, stability connStability, dstPort int) []prompb.Label { func timeSeriesLabels(metricName string, meta nodeMeta, instance string, source timestampSource, stability connStability, protocol protocol, dstPort int) []prompb.Label {
addressFamily := "ipv4" addressFamily := "ipv4"
if meta.addr.Is6() { if meta.addr.Is6() {
addressFamily = "ipv6" addressFamily = "ipv6"
@ -425,6 +482,10 @@ func timeSeriesLabels(metricName string, meta nodeMeta, instance string, source
Name: "hostname", Name: "hostname",
Value: meta.hostname, Value: meta.hostname,
}) })
labels = append(labels, prompb.Label{
Name: "protocol",
Value: string(protocol),
})
labels = append(labels, prompb.Label{ labels = append(labels, prompb.Label{
Name: "dst_port", Name: "dst_port",
Value: strconv.Itoa(dstPort), Value: strconv.Itoa(dstPort),
@ -453,53 +514,61 @@ const (
staleNaN uint64 = 0x7ff0000000000002 staleNaN uint64 = 0x7ff0000000000002
) )
func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, dstPorts []int) []prompb.TimeSeries { func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, portsByProtocol map[protocol][]int) []prompb.TimeSeries {
staleMarkers := make([]prompb.TimeSeries, 0) staleMarkers := make([]prompb.TimeSeries, 0)
now := time.Now() now := time.Now()
for _, s := range stale {
for _, dstPort := range dstPorts { for p, ports := range portsByProtocol {
samples := []prompb.Sample{ for _, port := range ports {
{ for _, s := range stale {
Timestamp: now.UnixMilli(), samples := []prompb.Sample{
Value: math.Float64frombits(staleNaN), {
}, Timestamp: now.UnixMilli(),
} Value: math.Float64frombits(staleNaN),
staleMarkers = append(staleMarkers, prompb.TimeSeries{ },
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, dstPort), }
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, dstPort),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, dstPort),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, dstPort),
Samples: samples,
})
if supportsKernelTS() {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, dstPort),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, dstPort),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{ staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, dstPort), Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port),
Samples: samples, Samples: samples,
}) })
staleMarkers = append(staleMarkers, prompb.TimeSeries{ staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, dstPort), Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port),
Samples: samples, Samples: samples,
}) })
if protocolSupportsStableConn(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, p, port),
Samples: samples,
})
}
if protocolSupportsKernelTS(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, p, port),
Samples: samples,
})
if protocolSupportsStableConn(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, p, port),
Samples: samples,
})
}
}
} }
} }
} }
return staleMarkers return staleMarkers
} }
@ -513,7 +582,7 @@ func resultsToPromTimeSeries(results []result, instance string, timeouts map[res
for _, r := range results { for _, r := range results {
timeoutsCount := timeouts[r.key] // a non-existent key will return a zero val timeoutsCount := timeouts[r.key] // a non-existent key will return a zero val
seenKeys[r.key] = true seenKeys[r.key] = true
rttLabels := timeSeriesLabels(rttMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.dstPort) rttLabels := timeSeriesLabels(rttMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.protocol, r.key.dstPort)
rttSamples := make([]prompb.Sample, 1) rttSamples := make([]prompb.Sample, 1)
rttSamples[0].Timestamp = r.at.UnixMilli() rttSamples[0].Timestamp = r.at.UnixMilli()
if r.rtt != nil { if r.rtt != nil {
@ -528,7 +597,7 @@ func resultsToPromTimeSeries(results []result, instance string, timeouts map[res
} }
all = append(all, rttTS) all = append(all, rttTS)
timeouts[r.key] = timeoutsCount timeouts[r.key] = timeoutsCount
timeoutsLabels := timeSeriesLabels(timeoutsMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.dstPort) timeoutsLabels := timeSeriesLabels(timeoutsMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.protocol, r.key.dstPort)
timeoutsSamples := make([]prompb.Sample, 1) timeoutsSamples := make([]prompb.Sample, 1)
timeoutsSamples[0].Timestamp = r.at.UnixMilli() timeoutsSamples[0].Timestamp = r.at.UnixMilli()
timeoutsSamples[0].Value = float64(timeoutsCount) timeoutsSamples[0].Value = float64(timeoutsCount)
@ -620,22 +689,56 @@ func remoteWriteTimeSeries(client *remoteWriteClient, tsCh chan []prompb.TimeSer
} }
} }
func getPortsFromFlag(f string) ([]int, error) {
if len(f) == 0 {
return nil, nil
}
split := strings.Split(f, ",")
slices.Sort(split)
split = slices.Compact(split)
ports := make([]int, 0)
for _, portStr := range split {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
ports = append(ports, int(port))
}
return ports, nil
}
func main() { func main() {
flag.Parse() flag.Parse()
if len(*flagDstPorts) == 0 {
log.Fatal("dst-ports flag is unset") portsByProtocol := make(map[protocol][]int)
} stunPorts, err := getPortsFromFlag(*flagSTUNDstPorts)
dstPortsSplit := strings.Split(*flagDstPorts, ",") if err != nil {
slices.Sort(dstPortsSplit) log.Fatalf("invalid stun-dst-ports flag value: %v", err)
dstPortsSplit = slices.Compact(dstPortsSplit) }
dstPorts := make([]int, 0, len(dstPortsSplit)) if len(stunPorts) > 0 {
for _, d := range dstPortsSplit { portsByProtocol[protocolSTUN] = stunPorts
i, err := strconv.ParseUint(d, 10, 16) }
if err != nil { httpsPorts, err := getPortsFromFlag(*flagHTTPSDstPorts)
log.Fatal("invalid dst-ports") if err != nil {
log.Fatalf("invalid https-dst-ports flag value: %v", err)
}
if len(httpsPorts) > 0 {
portsByProtocol[protocolHTTPS] = httpsPorts
}
if *flagICMP {
portsByProtocol[protocolICMP] = []int{0}
}
if len(portsByProtocol) == 0 {
log.Fatal("nothing to probe")
}
// TODO(jwhited): remove protocol restriction
for k := range portsByProtocol {
if k != protocolSTUN {
log.Fatal("HTTPS & ICMP are not yet supported")
} }
dstPorts = append(dstPorts, int(i))
} }
if len(*flagDERPMap) < 1 { if len(*flagDERPMap) < 1 {
log.Fatal("derp-map flag is unset") log.Fatal("derp-map flag is unset")
} }
@ -645,7 +748,7 @@ func main() {
if len(*flagRemoteWriteURL) < 1 { if len(*flagRemoteWriteURL) < 1 {
log.Fatal("rw-url flag is unset") log.Fatal("rw-url flag is unset")
} }
_, err := url.Parse(*flagRemoteWriteURL) _, err = url.Parse(*flagRemoteWriteURL)
if err != nil { if err != nil {
log.Fatalf("invalid rw-url flag value: %v", err) log.Fatalf("invalid rw-url flag value: %v", err)
} }
@ -707,7 +810,7 @@ func main() {
for _, v := range nodeMetaByAddr { for _, v := range nodeMetaByAddr {
staleMeta = append(staleMeta, v) staleMeta = append(staleMeta, v)
} }
staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, dstPorts) staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, portsByProtocol)
if len(staleMarkers) > 0 { if len(staleMarkers) > 0 {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
rwc.write(ctx, staleMarkers) rwc.write(ctx, staleMarkers)
@ -723,8 +826,8 @@ func main() {
// in a higher probability of the packets traversing the same underlay path. // in a higher probability of the packets traversing the same underlay path.
// Comparison of stable and unstable 5-tuple results can shed light on // Comparison of stable and unstable 5-tuple results can shed light on
// differences between paths where hashing (multipathing/load balancing) // differences between paths where hashing (multipathing/load balancing)
// comes into play. // comes into play. The inner 2 element array index is timestampSource.
stableConns := make(map[netip.Addr]map[int][2]io.ReadWriteCloser) stableConns := make(map[stableConnKey][2]io.ReadWriteCloser)
// timeouts holds counts of timeout events. Values are persisted for the // timeouts holds counts of timeout events. Values are persisted for the
// lifetime of the related node in the DERP map. // lifetime of the related node in the DERP map.
@ -738,7 +841,7 @@ func main() {
for { for {
select { select {
case <-probeTicker.C: case <-probeTicker.C:
results, err := probeNodes(nodeMetaByAddr, stableConns, dstPorts) results, err := probeNodes(nodeMetaByAddr, stableConns, portsByProtocol)
if err != nil { if err != nil {
log.Printf("unrecoverable error while probing: %v", err) log.Printf("unrecoverable error while probing: %v", err)
shutdown() shutdown()
@ -761,7 +864,7 @@ func main() {
log.Printf("error parsing DERP map, continuing with stale map: %v", err) log.Printf("error parsing DERP map, continuing with stale map: %v", err)
continue continue
} }
staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, dstPorts) staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, portsByProtocol)
if len(staleMarkers) < 1 { if len(staleMarkers) < 1 {
continue continue
} }

@ -8,18 +8,18 @@ package main
import ( import (
"errors" "errors"
"io" "io"
"net" "net/netip"
"time" "time"
) )
func getConnKernelTimestamp() (io.ReadWriteCloser, error) { func getUDPConnKernelTimestamp() (io.ReadWriteCloser, error) {
return nil, errors.New("unimplemented") return nil, errors.New("unimplemented")
} }
func measureRTTKernel(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) { func measureSTUNRTTKernel(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) {
return 0, errors.New("unimplemented") return 0, errors.New("unimplemented")
} }
func supportsKernelTS() bool { func protocolSupportsKernelTS(_ protocol) bool {
return false return false
} }

@ -10,7 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"time" "time"
"github.com/mdlayher/socket" "github.com/mdlayher/socket"
@ -24,7 +24,7 @@ const (
unix.SOF_TIMESTAMPING_SOFTWARE // report software timestamps unix.SOF_TIMESTAMPING_SOFTWARE // report software timestamps
) )
func getConnKernelTimestamp() (io.ReadWriteCloser, error) { func getUDPConnKernelTimestamp() (io.ReadWriteCloser, error) {
sconn, err := socket.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP, "udp", nil) sconn, err := socket.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP, "udp", nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -56,24 +56,23 @@ func parseTimestampFromCmsgs(oob []byte) (time.Time, error) {
return time.Time{}, errors.New("failed to parse timestamp from cmsgs") return time.Time{}, errors.New("failed to parse timestamp from cmsgs")
} }
func measureRTTKernel(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) { func measureSTUNRTTKernel(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) {
sconn, ok := conn.(*socket.Conn) sconn, ok := conn.(*socket.Conn)
if !ok { if !ok {
return 0, fmt.Errorf("conn of unexpected type: %T", conn) return 0, fmt.Errorf("conn of unexpected type: %T", conn)
} }
var to unix.Sockaddr var to unix.Sockaddr
to4 := dst.IP.To4() if dst.Addr().Is4() {
if to4 != nil {
to = &unix.SockaddrInet4{ to = &unix.SockaddrInet4{
Port: dst.Port, Port: int(dst.Port()),
} }
copy(to.(*unix.SockaddrInet4).Addr[:], to4) copy(to.(*unix.SockaddrInet4).Addr[:], dst.Addr().AsSlice())
} else { } else {
to = &unix.SockaddrInet6{ to = &unix.SockaddrInet6{
Port: dst.Port, Port: int(dst.Port()),
} }
copy(to.(*unix.SockaddrInet6).Addr[:], dst.IP) copy(to.(*unix.SockaddrInet6).Addr[:], dst.Addr().AsSlice())
} }
txID := stun.NewTxID() txID := stun.NewTxID()
@ -138,6 +137,10 @@ func measureRTTKernel(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Durat
} }
func supportsKernelTS() bool { func protocolSupportsKernelTS(p protocol) bool {
return true if p == protocolSTUN {
return true
}
// TODO: jwhited support ICMP
return false
} }

Loading…
Cancel
Save