From 20691894f5722deb756e397c6407eb5fd5273126 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 9 Aug 2024 08:03:58 -0700 Subject: [PATCH] cmd/stunstamp: refactor to support multiple protocols (#13063) 'stun' has been removed from metric names and replaced with a protocol label. This refactor is preparation work for HTTPS & ICMP support. Updates tailscale/corp#22114 Signed-off-by: Jordan Whited --- cmd/stunstamp/stunstamp.go | 345 +++++++++++++++++++---------- cmd/stunstamp/stunstamp_default.go | 8 +- cmd/stunstamp/stunstamp_linux.go | 25 ++- 3 files changed, 242 insertions(+), 136 deletions(-) diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index a4d25922a..e2b034e32 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -42,7 +42,9 @@ var ( flagIPv6 = flag.Bool("ipv6", false, "probe IPv6 addresses") flagRemoteWriteURL = flag.String("rw-url", "", "prometheus remote write URL") flagInstance = flag.String("instance", "", "instance label value; defaults to hostname if unspecified") - flagDstPorts = flag.String("dst-ports", "", "comma-separated list of destination ports to monitor") + flagSTUNDstPorts = flag.String("stun-dst-ports", "", "comma-separated list of STUN destination ports to monitor") + flagHTTPSDstPorts = flag.String("https-dst-ports", "", "comma-separated list of HTTPS destination ports to monitor") + flagICMP = flag.Bool("icmp", false, "probe ICMP") ) const ( @@ -89,12 +91,21 @@ func (t timestampSource) String() string { } } +type protocol string + +const ( + protocolSTUN protocol = "stun" + protocolICMP protocol = "icmp" + protocolHTTPS protocol = "https" +) + // resultKey contains the stable dimensions and their values for a given // timeseries, i.e. not time and not rtt/timeout. type resultKey struct { meta nodeMeta timestampSource timestampSource connStability connStability + protocol protocol dstPort int } @@ -104,7 +115,7 @@ type result struct { rtt *time.Duration // nil signifies failure, e.g. timeout } -func measureRTT(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) { +func measureSTUNRTT(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { uconn, ok := conn.(*net.UDPConn) if !ok { return 0, fmt.Errorf("unexpected conn type: %T", conn) @@ -116,7 +127,10 @@ func measureRTT(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, e txID := stun.NewTxID() req := stun.Request(txID) txAt := time.Now() - _, err = uconn.WriteToUDP(req, dst) + _, err = uconn.WriteToUDP(req, &net.UDPAddr{ + IP: dst.Addr().AsSlice(), + Port: int(dst.Port()), + }) if err != nil { return 0, fmt.Errorf("error writing to udp socket: %w", err) } @@ -153,11 +167,11 @@ type nodeMeta struct { addr netip.Addr } -type measureFn func(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) +type measureFn func(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) -// probe measures STUN round trip time for the node described by meta over -// conn against dstPort. It may return a nil duration and nil error if the -// STUN request timed out. A non-nil error indicates an unrecoverable or +// probe measures round trip time for the node described by meta over +// conn against dstPort using fn. It may return a nil duration and nil error in +// the event of a timeout. A non-nil error indicates an unrecoverable or // non-temporary error. func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) { ua := &net.UDPAddr{ @@ -166,7 +180,7 @@ func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (* } time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx - rtt, err := fn(conn, ua) + rtt, err := fn(conn, netip.AddrPortFrom(meta.addr, uint16(dstPort))) if err != nil { if isTemporaryOrTimeoutErr(err) { log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) @@ -237,43 +251,71 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node return stale, nil } -func getStableConns(stableConns map[netip.Addr]map[int][2]io.ReadWriteCloser, addr netip.Addr, dstPort int) ([2]io.ReadWriteCloser, error) { - conns := [2]io.ReadWriteCloser{} - byDstPort, ok := stableConns[addr] - if ok { - conns, ok = byDstPort[dstPort] - if ok { - return conns, nil +func newConn(source timestampSource, protocol protocol) (io.ReadWriteCloser, error) { + switch protocol { + case protocolSTUN: + if source == timestampSourceKernel { + return getUDPConnKernelTimestamp() + } else { + return net.ListenUDP("udp", &net.UDPAddr{}) } + case protocolICMP: + // TODO(jwhited): implement + return nil, errors.New("unimplemented protocol") + case protocolHTTPS: + // TODO(jwhited): implement + return nil, errors.New("unimplemented protocol") + } + return nil, errors.New("unknown protocol") +} + +type stableConnKey struct { + node netip.Addr + protocol protocol + port int +} + +func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr netip.Addr, protocol protocol, dstPort int) ([2]io.ReadWriteCloser, error) { + if !protocolSupportsStableConn(protocol) { + return [2]io.ReadWriteCloser{}, nil } - if supportsKernelTS() { - kconn, err := getConnKernelTimestamp() + conns, ok := stableConns[stableConnKey{addr, protocol, dstPort}] + if ok { + return conns, nil + } + + if protocolSupportsKernelTS(protocol) { + kconn, err := newConn(timestampSourceKernel, protocol) if err != nil { return conns, err } conns[timestampSourceKernel] = kconn } - uconn, err := net.ListenUDP("udp", &net.UDPAddr{}) + uconn, err := newConn(timestampSourceUserspace, protocol) if err != nil { - if supportsKernelTS() { + if protocolSupportsKernelTS(protocol) { conns[timestampSourceKernel].Close() } return conns, err } conns[timestampSourceUserspace] = uconn - if byDstPort == nil { - byDstPort = make(map[int][2]io.ReadWriteCloser) - } - byDstPort[dstPort] = conns - stableConns[addr] = byDstPort return conns, nil } -// probeNodes measures the round-trip time for STUN binding requests against the -// DERP nodes described by nodeMetaByAddr while using/updating stableConns for -// UDP sockets that should be recycled across runs. It returns the results or -// an error if one occurs. -func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Addr]map[int][2]io.ReadWriteCloser, dstPorts []int) ([]result, error) { +func protocolSupportsStableConn(p protocol) bool { + if p == protocolICMP { + // no value for ICMP + return false + } + return true +} + +// probeNodes measures the round-trip time for the protocols and ports described +// by portsByProtocol against the DERP nodes described by nodeMetaByAddr. +// stableConns are used to recycle connections across calls to probeNodes. +// probeNodes is also responsible for trimming stableConns based on node +// lifetime in nodeMetaByAddr. It returns the results or an error if one occurs. +func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]io.ReadWriteCloser, portsByProtocol map[protocol][]int) ([]result, error) { wg := sync.WaitGroup{} results := make([]result, 0) resultsCh := make(chan result) @@ -283,23 +325,20 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad at := time.Now() addrsToProbe := make(map[netip.Addr]bool) - doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, dstPort int) { + doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, protocol protocol, dstPort int) { defer wg.Done() r := result{ key: resultKey{ meta: meta, timestampSource: source, dstPort: dstPort, + protocol: protocol, }, at: at, } if conn == nil { var err error - if source == timestampSourceKernel { - conn, err = getConnKernelTimestamp() - } else { - conn, err = net.ListenUDP("udp", &net.UDPAddr{}) - } + conn, err = newConn(source, protocol) if err != nil { select { case <-doneCh: @@ -312,9 +351,17 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad } else { r.key.connStability = stableConn } - fn := measureRTT - if source == timestampSourceKernel { - fn = measureRTTKernel + var fn measureFn + switch protocol { + case protocolSTUN: + fn = measureSTUNRTT + if source == timestampSourceKernel { + fn = measureSTUNRTTKernel + } + case protocolICMP: + // TODO(jwhited): implement + case protocolHTTPS: + // TODO(jwhited): implement } rtt, err := probe(meta, conn, fn, dstPort) if err != nil { @@ -334,37 +381,47 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad for _, meta := range nodeMetaByAddr { addrsToProbe[meta.addr] = true - for _, port := range dstPorts { - stable, err := getStableConns(stableConns, meta.addr, port) - if err != nil { - close(doneCh) - wg.Wait() - return nil, err - } + for p, ports := range portsByProtocol { + for _, port := range ports { + stable, err := getStableConns(stableConns, meta.addr, p, port) + if err != nil { + close(doneCh) + wg.Wait() + return nil, err + } - wg.Add(2) - numProbes += 2 - go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, port) - go doProbe(nil, meta, timestampSourceUserspace, port) - if supportsKernelTS() { - wg.Add(2) - numProbes += 2 - go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, port) - go doProbe(nil, meta, timestampSourceKernel, port) + if protocolSupportsStableConn(p) { + wg.Add(1) + numProbes++ + go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, p, port) + } + wg.Add(1) + numProbes++ + go doProbe(nil, meta, timestampSourceUserspace, p, port) + + if protocolSupportsKernelTS(p) { + if protocolSupportsStableConn(p) { + wg.Add(1) + numProbes++ + go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, p, port) + } + + wg.Add(1) + numProbes++ + go doProbe(nil, meta, timestampSourceKernel, p, port) + } } } } // cleanup conns we no longer need - for k, byDstPort := range stableConns { - if !addrsToProbe[k] { - for _, conns := range byDstPort { - if conns[timestampSourceKernel] != nil { - conns[timestampSourceKernel].Close() - } - conns[timestampSourceUserspace].Close() - delete(stableConns, k) + for k, conns := range stableConns { + if !addrsToProbe[k.node] { + if conns[timestampSourceKernel] != nil { + conns[timestampSourceKernel].Close() } + conns[timestampSourceUserspace].Close() + delete(stableConns, k) } } @@ -391,11 +448,11 @@ const ( ) const ( - rttMetricName = "stunstamp_derp_stun_rtt_ns" - timeoutsMetricName = "stunstamp_derp_stun_timeouts_total" + rttMetricName = "stunstamp_derp_rtt_ns" + timeoutsMetricName = "stunstamp_derp_timeouts_total" ) -func timeSeriesLabels(metricName string, meta nodeMeta, instance string, source timestampSource, stability connStability, dstPort int) []prompb.Label { +func timeSeriesLabels(metricName string, meta nodeMeta, instance string, source timestampSource, stability connStability, protocol protocol, dstPort int) []prompb.Label { addressFamily := "ipv4" if meta.addr.Is6() { addressFamily = "ipv6" @@ -425,6 +482,10 @@ func timeSeriesLabels(metricName string, meta nodeMeta, instance string, source Name: "hostname", Value: meta.hostname, }) + labels = append(labels, prompb.Label{ + Name: "protocol", + Value: string(protocol), + }) labels = append(labels, prompb.Label{ Name: "dst_port", Value: strconv.Itoa(dstPort), @@ -453,53 +514,61 @@ const ( staleNaN uint64 = 0x7ff0000000000002 ) -func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, dstPorts []int) []prompb.TimeSeries { +func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, portsByProtocol map[protocol][]int) []prompb.TimeSeries { staleMarkers := make([]prompb.TimeSeries, 0) now := time.Now() - for _, s := range stale { - for _, dstPort := range dstPorts { - samples := []prompb.Sample{ - { - Timestamp: now.UnixMilli(), - Value: math.Float64frombits(staleNaN), - }, - } - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, dstPort), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, dstPort), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, dstPort), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, dstPort), - Samples: samples, - }) - if supportsKernelTS() { - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, dstPort), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, dstPort), - Samples: samples, - }) + + for p, ports := range portsByProtocol { + for _, port := range ports { + for _, s := range stale { + samples := []prompb.Sample{ + { + Timestamp: now.UnixMilli(), + Value: math.Float64frombits(staleNaN), + }, + } staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, dstPort), + Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port), Samples: samples, }) staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, dstPort), + Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port), Samples: samples, }) + if protocolSupportsStableConn(p) { + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, p, port), + Samples: samples, + }) + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, p, port), + Samples: samples, + }) + } + if protocolSupportsKernelTS(p) { + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, p, port), + Samples: samples, + }) + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, p, port), + Samples: samples, + }) + if protocolSupportsStableConn(p) { + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, p, port), + Samples: samples, + }) + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, p, port), + Samples: samples, + }) + } + } } } } + return staleMarkers } @@ -513,7 +582,7 @@ func resultsToPromTimeSeries(results []result, instance string, timeouts map[res for _, r := range results { timeoutsCount := timeouts[r.key] // a non-existent key will return a zero val seenKeys[r.key] = true - rttLabels := timeSeriesLabels(rttMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.dstPort) + rttLabels := timeSeriesLabels(rttMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.protocol, r.key.dstPort) rttSamples := make([]prompb.Sample, 1) rttSamples[0].Timestamp = r.at.UnixMilli() if r.rtt != nil { @@ -528,7 +597,7 @@ func resultsToPromTimeSeries(results []result, instance string, timeouts map[res } all = append(all, rttTS) timeouts[r.key] = timeoutsCount - timeoutsLabels := timeSeriesLabels(timeoutsMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.dstPort) + timeoutsLabels := timeSeriesLabels(timeoutsMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.protocol, r.key.dstPort) timeoutsSamples := make([]prompb.Sample, 1) timeoutsSamples[0].Timestamp = r.at.UnixMilli() timeoutsSamples[0].Value = float64(timeoutsCount) @@ -620,22 +689,56 @@ func remoteWriteTimeSeries(client *remoteWriteClient, tsCh chan []prompb.TimeSer } } +func getPortsFromFlag(f string) ([]int, error) { + if len(f) == 0 { + return nil, nil + } + split := strings.Split(f, ",") + slices.Sort(split) + split = slices.Compact(split) + ports := make([]int, 0) + for _, portStr := range split { + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + ports = append(ports, int(port)) + } + return ports, nil +} + func main() { flag.Parse() - if len(*flagDstPorts) == 0 { - log.Fatal("dst-ports flag is unset") - } - dstPortsSplit := strings.Split(*flagDstPorts, ",") - slices.Sort(dstPortsSplit) - dstPortsSplit = slices.Compact(dstPortsSplit) - dstPorts := make([]int, 0, len(dstPortsSplit)) - for _, d := range dstPortsSplit { - i, err := strconv.ParseUint(d, 10, 16) - if err != nil { - log.Fatal("invalid dst-ports") + + portsByProtocol := make(map[protocol][]int) + stunPorts, err := getPortsFromFlag(*flagSTUNDstPorts) + if err != nil { + log.Fatalf("invalid stun-dst-ports flag value: %v", err) + } + if len(stunPorts) > 0 { + portsByProtocol[protocolSTUN] = stunPorts + } + httpsPorts, err := getPortsFromFlag(*flagHTTPSDstPorts) + if err != nil { + log.Fatalf("invalid https-dst-ports flag value: %v", err) + } + if len(httpsPorts) > 0 { + portsByProtocol[protocolHTTPS] = httpsPorts + } + if *flagICMP { + portsByProtocol[protocolICMP] = []int{0} + } + if len(portsByProtocol) == 0 { + log.Fatal("nothing to probe") + } + + // TODO(jwhited): remove protocol restriction + for k := range portsByProtocol { + if k != protocolSTUN { + log.Fatal("HTTPS & ICMP are not yet supported") } - dstPorts = append(dstPorts, int(i)) } + if len(*flagDERPMap) < 1 { log.Fatal("derp-map flag is unset") } @@ -645,7 +748,7 @@ func main() { if len(*flagRemoteWriteURL) < 1 { log.Fatal("rw-url flag is unset") } - _, err := url.Parse(*flagRemoteWriteURL) + _, err = url.Parse(*flagRemoteWriteURL) if err != nil { log.Fatalf("invalid rw-url flag value: %v", err) } @@ -707,7 +810,7 @@ func main() { for _, v := range nodeMetaByAddr { staleMeta = append(staleMeta, v) } - staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, dstPorts) + staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, portsByProtocol) if len(staleMarkers) > 0 { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) rwc.write(ctx, staleMarkers) @@ -723,8 +826,8 @@ func main() { // in a higher probability of the packets traversing the same underlay path. // Comparison of stable and unstable 5-tuple results can shed light on // differences between paths where hashing (multipathing/load balancing) - // comes into play. - stableConns := make(map[netip.Addr]map[int][2]io.ReadWriteCloser) + // comes into play. The inner 2 element array index is timestampSource. + stableConns := make(map[stableConnKey][2]io.ReadWriteCloser) // timeouts holds counts of timeout events. Values are persisted for the // lifetime of the related node in the DERP map. @@ -738,7 +841,7 @@ func main() { for { select { case <-probeTicker.C: - results, err := probeNodes(nodeMetaByAddr, stableConns, dstPorts) + results, err := probeNodes(nodeMetaByAddr, stableConns, portsByProtocol) if err != nil { log.Printf("unrecoverable error while probing: %v", err) shutdown() @@ -761,7 +864,7 @@ func main() { log.Printf("error parsing DERP map, continuing with stale map: %v", err) continue } - staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, dstPorts) + staleMarkers := staleMarkersFromNodeMeta(staleMeta, *flagInstance, portsByProtocol) if len(staleMarkers) < 1 { continue } diff --git a/cmd/stunstamp/stunstamp_default.go b/cmd/stunstamp/stunstamp_default.go index 2fb69dc68..707035306 100644 --- a/cmd/stunstamp/stunstamp_default.go +++ b/cmd/stunstamp/stunstamp_default.go @@ -8,18 +8,18 @@ package main import ( "errors" "io" - "net" + "net/netip" "time" ) -func getConnKernelTimestamp() (io.ReadWriteCloser, error) { +func getUDPConnKernelTimestamp() (io.ReadWriteCloser, error) { return nil, errors.New("unimplemented") } -func measureRTTKernel(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) { +func measureSTUNRTTKernel(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { return 0, errors.New("unimplemented") } -func supportsKernelTS() bool { +func protocolSupportsKernelTS(_ protocol) bool { return false } diff --git a/cmd/stunstamp/stunstamp_linux.go b/cmd/stunstamp/stunstamp_linux.go index 898ab19f1..1545c067f 100644 --- a/cmd/stunstamp/stunstamp_linux.go +++ b/cmd/stunstamp/stunstamp_linux.go @@ -10,7 +10,7 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "time" "github.com/mdlayher/socket" @@ -24,7 +24,7 @@ const ( unix.SOF_TIMESTAMPING_SOFTWARE // report software timestamps ) -func getConnKernelTimestamp() (io.ReadWriteCloser, error) { +func getUDPConnKernelTimestamp() (io.ReadWriteCloser, error) { sconn, err := socket.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP, "udp", nil) if err != nil { return nil, err @@ -56,24 +56,23 @@ func parseTimestampFromCmsgs(oob []byte) (time.Time, error) { return time.Time{}, errors.New("failed to parse timestamp from cmsgs") } -func measureRTTKernel(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) { +func measureSTUNRTTKernel(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { sconn, ok := conn.(*socket.Conn) if !ok { return 0, fmt.Errorf("conn of unexpected type: %T", conn) } var to unix.Sockaddr - to4 := dst.IP.To4() - if to4 != nil { + if dst.Addr().Is4() { to = &unix.SockaddrInet4{ - Port: dst.Port, + Port: int(dst.Port()), } - copy(to.(*unix.SockaddrInet4).Addr[:], to4) + copy(to.(*unix.SockaddrInet4).Addr[:], dst.Addr().AsSlice()) } else { to = &unix.SockaddrInet6{ - Port: dst.Port, + Port: int(dst.Port()), } - copy(to.(*unix.SockaddrInet6).Addr[:], dst.IP) + copy(to.(*unix.SockaddrInet6).Addr[:], dst.Addr().AsSlice()) } txID := stun.NewTxID() @@ -138,6 +137,10 @@ func measureRTTKernel(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Durat } -func supportsKernelTS() bool { - return true +func protocolSupportsKernelTS(p protocol) bool { + if p == protocolSTUN { + return true + } + // TODO: jwhited support ICMP + return false }