From e6b84f2159bae0c8e26d093048cf313045039229 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 17 May 2020 09:51:38 -0700 Subject: [PATCH] all: make client use server-provided DERP map, add DERP region support Instead of hard-coding the DERP map (except for cmd/tailscale netcheck for now), get it from the control server at runtime. And make the DERP map support multiple nodes per region with clients picking the first one that's available. (The server will balance the order presented to clients for load balancing) This deletes the stunner package, merging it into the netcheck package instead, to minimize all the config hooks that would've been required. Also fix some test flakes & races. Fixes #387 (Don't hard-code the DERP map) Updates #388 (Add DERP region support) Fixes #399 (wgengine: flaky tests) Signed-off-by: Brad Fitzpatrick --- cmd/tailscale/netcheck.go | 26 +- control/controlclient/direct.go | 8 + control/controlclient/netmap.go | 4 + derp/derphttp/derphttp_client.go | 253 ++++++- derp/derpmap/derpmap.go | 180 ++--- go.mod | 1 - go.sum | 2 - ipn/local.go | 8 +- netcheck/netcheck.go | 945 ++++++++++++++++++--------- netcheck/netcheck_test.go | 324 ++++++--- stun/stuntest/stuntest.go | 60 +- stunner/stunner.go | 310 --------- stunner/stunner_test.go | 154 ----- tailcfg/derpmap.go | 32 + tailcfg/tailcfg.go | 7 +- wgengine/magicsock/magicsock.go | 198 +++--- wgengine/magicsock/magicsock_test.go | 68 +- wgengine/userspace.go | 26 +- wgengine/watchdog.go | 5 +- wgengine/wgengine.go | 7 +- 20 files changed, 1428 insertions(+), 1190 deletions(-) delete mode 100644 stunner/stunner.go delete mode 100644 stunner/stunner_test.go diff --git a/cmd/tailscale/netcheck.go b/cmd/tailscale/netcheck.go index c58103dce..7de01f72a 100644 --- a/cmd/tailscale/netcheck.go +++ b/cmd/tailscale/netcheck.go @@ -9,6 +9,7 @@ import ( "fmt" "log" "sort" + "time" "github.com/peterbourgon/ff/v2/ffcli" "tailscale.com/derp/derpmap" @@ -26,12 +27,12 @@ var netcheckCmd = &ffcli.Command{ func runNetcheck(ctx context.Context, args []string) error { c := &netcheck.Client{ - DERP: derpmap.Prod(), Logf: logger.WithPrefix(log.Printf, "netcheck: "), DNSCache: dnscache.Get(), } - report, err := c.GetReport(ctx) + dm := derpmap.Prod() + report, err := c.GetReport(ctx, dm) if err != nil { log.Fatalf("netcheck: %v", err) } @@ -55,18 +56,23 @@ func runNetcheck(ctx context.Context, args []string) error { // When DERP latency checking failed, // magicsock will try to pick the DERP server that // most of your other nodes are also using - if len(report.DERPLatency) == 0 { + if len(report.RegionLatency) == 0 { fmt.Printf("\t* Nearest DERP: unknown (no response to latency probes)\n") } else { - fmt.Printf("\t* Nearest DERP: %v (%v)\n", report.PreferredDERP, c.DERP.LocationOfID(report.PreferredDERP)) + fmt.Printf("\t* Nearest DERP: %v (%v)\n", report.PreferredDERP, dm.Regions[report.PreferredDERP].RegionCode) fmt.Printf("\t* DERP latency:\n") - var ss []string - for s := range report.DERPLatency { - ss = append(ss, s) + var rids []int + for rid := range dm.Regions { + rids = append(rids, rid) } - sort.Strings(ss) - for _, s := range ss { - fmt.Printf("\t\t- %s = %v\n", s, report.DERPLatency[s]) + sort.Ints(rids) + for _, rid := range rids { + d, ok := report.RegionLatency[rid] + var latency string + if ok { + latency = d.Round(time.Millisecond / 10).String() + } + fmt.Printf("\t\t- %v, %3s = %s\n", rid, dm.Regions[rid].RegionCode, latency) } } return nil diff --git a/control/controlclient/direct.go 
b/control/controlclient/direct.go index 0e7afb6f8..9078e9374 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -541,6 +541,8 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM } }() + var lastDERPMap *tailcfg.DERPMap + // If allowStream, then the server will use an HTTP long poll to // return incremental results. There is always one response right // away, followed by a delay, and eventually others. @@ -582,6 +584,11 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM } vlogf("netmap: got new map") + if resp.DERPMap != nil { + vlogf("netmap: new map contains DERP map") + lastDERPMap = resp.DERPMap + } + nm := &NetworkMap{ NodeKey: tailcfg.NodeKey(persist.PrivateNodeKey.Public()), PrivateKey: persist.PrivateNodeKey, @@ -597,6 +604,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM DNSDomains: resp.SearchPaths, Hostinfo: resp.Node.Hostinfo, PacketFilter: c.parsePacketFilter(resp.PacketFilter), + DERPMap: lastDERPMap, } for _, profile := range resp.UserProfiles { nm.UserProfiles[profile.ID] = profile diff --git a/control/controlclient/netmap.go b/control/controlclient/netmap.go index b78dd6afe..764c56b38 100644 --- a/control/controlclient/netmap.go +++ b/control/controlclient/netmap.go @@ -33,6 +33,10 @@ type NetworkMap struct { Hostinfo tailcfg.Hostinfo PacketFilter filter.Matches + // DERPMap is the last DERP server map received. It's reused + // between updates and should not be modified. + DERPMap *tailcfg.DERPMap + // ACLs User tailcfg.UserID diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index bcaa08ea4..c63644bb1 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -24,9 +24,11 @@ import ( "sync" "time" + "inet.af/netaddr" "tailscale.com/derp" "tailscale.com/net/dnscache" "tailscale.com/net/tlsdial" + "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -43,7 +45,10 @@ type Client struct { privateKey key.Private logf logger.Logf - url *url.URL + + // Either url or getRegion is non-nil: + url *url.URL + getRegion func() *tailcfg.DERPRegion ctx context.Context // closed via cancelCtx in Client.Close cancelCtx context.CancelFunc @@ -55,8 +60,22 @@ type Client struct { client *derp.Client } +// NewRegionClient returns a new DERP-over-HTTP client. It connects lazily. +// To trigger a connection, use Connect. +func NewRegionClient(privateKey key.Private, logf logger.Logf, getRegion func() *tailcfg.DERPRegion) *Client { + ctx, cancel := context.WithCancel(context.Background()) + c := &Client{ + privateKey: privateKey, + logf: logf, + getRegion: getRegion, + ctx: ctx, + cancelCtx: cancel, + } + return c +} + // NewClient returns a new DERP-over-HTTP client. It connects lazily. -// To trigger a connection use Connect. +// To trigger a connection, use Connect. 
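As a usage illustration (not part of the patch itself), a consumer that tracks the latest server-provided DERP map could hand NewRegionClient a closure over it; the mapHolder type and regionClient helper below are hypothetical names:

```go
package example

import (
	"log"
	"sync"

	"tailscale.com/derp/derphttp"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
)

// mapHolder is a hypothetical guard around the most recent
// control-server-provided DERP map; it is not part of this patch.
type mapHolder struct {
	mu sync.Mutex
	dm *tailcfg.DERPMap
}

// regionClient returns a lazily-connecting client that always dials
// whatever nodes the latest map lists for regionID.
func (h *mapHolder) regionClient(priv key.Private, regionID int) *derphttp.Client {
	return derphttp.NewRegionClient(priv, log.Printf, func() *tailcfg.DERPRegion {
		h.mu.Lock()
		defer h.mu.Unlock()
		if h.dm == nil {
			return nil // connect() then fails with "DERP region not available"
		}
		return h.dm.Regions[regionID]
	})
}
```

Returning nil from the closure keeps connection attempts failing cleanly until a map arrives, which fits the lazy-connect design described above.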
func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) { u, err := url.Parse(serverURL) if err != nil { @@ -65,6 +84,7 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Cli if urlPort(u) == "" { return nil, fmt.Errorf("derphttp.NewClient: invalid URL scheme %q", u.Scheme) } + ctx, cancel := context.WithCancel(context.Background()) c := &Client{ privateKey: privateKey, @@ -101,6 +121,37 @@ func urlPort(u *url.URL) string { return "" } +func (c *Client) targetString(reg *tailcfg.DERPRegion) string { + if c.url != nil { + return c.url.String() + } + return fmt.Sprintf("region %d (%v)", reg.RegionID, reg.RegionCode) +} + +func (c *Client) useHTTPS() bool { + if c.url != nil && c.url.Scheme == "http" { + return false + } + return true +} + +func (c *Client) tlsServerName(node *tailcfg.DERPNode) string { + if c.url != nil { + return c.url.Host + } + if node.CertName != "" { + return node.CertName + } + return node.HostName +} + +func (c *Client) urlString(node *tailcfg.DERPNode) string { + if c.url != nil { + return c.url.String() + } + return fmt.Sprintf("https://%s/derp", node.HostName) +} + func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) { c.mu.Lock() defer c.mu.Unlock() @@ -111,8 +162,6 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien return c.client, nil } - c.logf("%s: connecting to %v", caller, c.url) - // timeout is the fallback maximum time (if ctx doesn't limit // it further) to do all of: DNS + TCP + TLS + HTTP Upgrade + // DERP upgrade. @@ -132,46 +181,42 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien }() defer cancel() + var reg *tailcfg.DERPRegion // nil when using c.url to dial + if c.getRegion != nil { + reg = c.getRegion() + if reg == nil { + return nil, errors.New("DERP region not available") + } + } + var tcpConn net.Conn + defer func() { if err != nil { if ctx.Err() != nil { err = fmt.Errorf("%v: %v", ctx.Err(), err) } - err = fmt.Errorf("%s connect to %v: %v", caller, c.url, err) + err = fmt.Errorf("%s connect to %v: %v", caller, c.targetString(reg), err) if tcpConn != nil { go tcpConn.Close() } } }() - host := c.url.Hostname() - hostOrIP := host - - var stdDialer dialer = new(net.Dialer) - var dialer = stdDialer - if wrapDialer != nil { - dialer = wrapDialer(dialer) - } - - if c.DNSCache != nil { - ip, err := c.DNSCache.LookupIP(ctx, host) - if err == nil { - hostOrIP = ip.String() - } - if err != nil && dialer == stdDialer { - // Return an error if we're not using a dial - // proxy that can do DNS lookups for us. - return nil, err - } + var node *tailcfg.DERPNode // nil when using c.url to dial + if c.url != nil { + c.logf("%s: connecting to %v", caller, c.url) + tcpConn, err = c.dialURL(ctx) + } else { + c.logf("%s: connecting to derp-%d (%v)", caller, reg.RegionID, reg.RegionCode) + tcpConn, node, err = c.dialRegion(ctx, reg) } - - tcpConn, err = dialer.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url))) if err != nil { - return nil, fmt.Errorf("dial of %q: %v", host, err) + return nil, err } - // Now that we have a TCP connection, force close it. + // Now that we have a TCP connection, force close it if the + // TLS handshake + DERP setup takes too long. 
done := make(chan struct{}) defer close(done) go func() { @@ -195,15 +240,19 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien }() var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to - if c.url.Scheme == "https" { - httpConn = tls.Client(tcpConn, tlsdial.Config(c.url.Host, c.TLSConfig)) + if c.useHTTPS() { + tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig) + if node != nil && node.DERPTestPort != 0 { + tlsConf.InsecureSkipVerify = true + } + httpConn = tls.Client(tcpConn, tlsConf) } else { httpConn = tcpConn } brw := bufio.NewReadWriter(bufio.NewReader(httpConn), bufio.NewWriter(httpConn)) - req, err := http.NewRequest("GET", c.url.String(), nil) + req, err := http.NewRequest("GET", c.urlString(node), nil) if err != nil { return nil, err } @@ -243,6 +292,148 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien return c.client, nil } +func (c *Client) dialURL(ctx context.Context) (net.Conn, error) { + host := c.url.Hostname() + hostOrIP := host + + var stdDialer dialer = new(net.Dialer) + var dialer = stdDialer + if wrapDialer != nil { + dialer = wrapDialer(dialer) + } + + if c.DNSCache != nil { + ip, err := c.DNSCache.LookupIP(ctx, host) + if err == nil { + hostOrIP = ip.String() + } + if err != nil && dialer == stdDialer { + // Return an error if we're not using a dial + // proxy that can do DNS lookups for us. + return nil, err + } + } + + tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url))) + if err != nil { + return nil, fmt.Errorf("dial of %v: %v", host, err) + } + return tcpConn, nil +} + +// dialRegion returns a TCP connection to the provided region, trying +// each node in order (with dialNode) until one connects or ctx is +// done. +func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) { + if len(reg.Nodes) == 0 { + return nil, nil, fmt.Errorf("no nodes for %s", c.targetString(reg)) + } + var firstErr error + for _, n := range reg.Nodes { + if n.STUNOnly { + continue + } + c, err := c.dialNode(ctx, n) + if err == nil { + return c, n, nil + } + if firstErr == nil { + firstErr = err + } + } + return nil, nil, firstErr +} + +func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) { + var stdDialer dialer = new(net.Dialer) + var dialer = stdDialer + if wrapDialer != nil { + dialer = wrapDialer(dialer) + } + return dialer.DialContext(ctx, proto, addr) +} + +// shouldDialProto reports whether an explicitly provided IPv4 or IPv6 +// address (given in s) is valid. An empty value means to dial, but to +// use DNS. The predicate function reports whether the non-empty +// string s contained a valid IP address of the right family. +func shouldDialProto(s string, pred func(netaddr.IP) bool) bool { + if s == "" { + return true + } + ip, _ := netaddr.ParseIP(s) + return pred(ip) +} + +const dialNodeTimeout = 1500 * time.Millisecond + +// dialNode returns a TCP connection to node n, racing IPv4 and IPv6 +// (both as applicable) against each other. +// A node is only given dialNodeTimeout to connect. +// +// TODO(bradfitz): longer if no options remain perhaps? ... Or longer +// overall but have dialRegion start overlapping races? 
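The three-way contract of shouldDialProto is easiest to see with concrete inputs; a small standalone sketch (the function body is copied from above; the sample values are illustrative, with "none" being the sentinel stuntest uses later in this patch):

```go
package main

import (
	"fmt"

	"inet.af/netaddr"
)

// Copied from the patch above so the example runs standalone.
func shouldDialProto(s string, pred func(netaddr.IP) bool) bool {
	if s == "" {
		return true
	}
	ip, _ := netaddr.ParseIP(s)
	return pred(ip)
}

func main() {
	fmt.Println(shouldDialProto("", netaddr.IP.Is4))        // true: no explicit IP, dial using DNS
	fmt.Println(shouldDialProto("1.2.3.4", netaddr.IP.Is4)) // true: explicit IPv4 address
	fmt.Println(shouldDialProto("none", netaddr.IP.Is4))    // false: non-IP value disables this family
	fmt.Println(shouldDialProto("1.2.3.4", netaddr.IP.Is6)) // false: wrong family
}
```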
+func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, error) { + type res struct { + c net.Conn + err error + } + resc := make(chan res) // must be unbuffered + ctx, cancel := context.WithTimeout(ctx, dialNodeTimeout) + defer cancel() + + nwait := 0 + startDial := func(dstPrimary, proto string) { + nwait++ + go func() { + dst := dstPrimary + if dst == "" { + dst = n.HostName + } + port := "443" + if n.DERPTestPort != 0 { + port = fmt.Sprint(n.DERPTestPort) + } + c, err := c.dialContext(ctx, proto, net.JoinHostPort(dst, port)) + select { + case resc <- res{c, err}: + case <-ctx.Done(): + if c != nil { + c.Close() + } + } + }() + } + if shouldDialProto(n.IPv4, netaddr.IP.Is4) { + startDial(n.IPv4, "tcp4") + } + if shouldDialProto(n.IPv6, netaddr.IP.Is6) { + startDial(n.IPv6, "tcp6") + } + if nwait == 0 { + return nil, errors.New("both IPv4 and IPv6 are explicitly disabled for node") + } + + var firstErr error + for { + select { + case res := <-resc: + nwait-- + if res.err == nil { + return res.c, nil + } + if firstErr == nil { + firstErr = res.err + } + if nwait == 0 { + return nil, firstErr + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + func (c *Client) Send(dstKey key.Public, b []byte) error { client, err := c.connect(context.TODO(), "derphttp.Client.Send") if err != nil { diff --git a/derp/derpmap/derpmap.go b/derp/derpmap/derpmap.go index e0c55eb4b..587f8b61c 100644 --- a/derp/derpmap/derpmap.go +++ b/derp/derpmap/derpmap.go @@ -7,151 +7,59 @@ package derpmap import ( "fmt" - "net" + "strings" - "tailscale.com/types/structs" + "tailscale.com/tailcfg" ) -// World is a set of DERP server. -type World struct { - servers []*Server - ids []int - byID map[int]*Server - stun4 []string - stun6 []string -} - -func (w *World) IDs() []int { return w.ids } -func (w *World) STUN4() []string { return w.stun4 } -func (w *World) STUN6() []string { return w.stun6 } -func (w *World) ServerByID(id int) *Server { return w.byID[id] } - -// LocationOfID returns the geographic name of a node, if present. -func (w *World) LocationOfID(id int) string { - if s, ok := w.byID[id]; ok { - return s.Geo +func derpNode(suffix, v4, v6 string) *tailcfg.DERPNode { + return &tailcfg.DERPNode{ + Name: suffix, // updated later + RegionID: 0, // updated later + IPv4: v4, + IPv6: v6, } - return "" } -func (w *World) NodeIDOfSTUNServer(server string) int { - // TODO: keep reverse map? Small enough to not matter for now. - for _, s := range w.servers { - if s.STUN4 == server || s.STUN6 == server { - return s.ID - } +func derpRegion(id int, code string, nodes ...*tailcfg.DERPNode) *tailcfg.DERPRegion { + region := &tailcfg.DERPRegion{ + RegionID: id, + RegionCode: code, + Nodes: nodes, } - return 0 -} - -// ForeachServer calls fn for each DERP server, in an unspecified order. -func (w *World) ForeachServer(fn func(*Server)) { - for _, s := range w.byID { - fn(s) + for _, n := range nodes { + n.Name = fmt.Sprintf("%d%s", id, n.Name) + n.RegionID = id + n.HostName = fmt.Sprintf("derp%s.tailscale.com", strings.TrimSuffix(n.Name, "a")) } + return region } -// Prod returns the production DERP nodes. -func Prod() *World { - return prod -} - -func NewTestWorld(stun ...string) *World { - w := &World{} - for i, s := range stun { - w.add(&Server{ - ID: i + 1, - Geo: fmt.Sprintf("Testopolis-%d", i+1), - STUN4: s, - }) - } - return w -} - -func NewTestWorldWith(servers ...*Server) *World { - w := &World{} - for _, s := range servers { - w.add(s) - } - return w -} - -var prod = new(World) // ... 
a dazzling place I never knew - -func addProd(id int, geo string) { - prod.add(&Server{ - ID: id, - Geo: geo, - HostHTTPS: fmt.Sprintf("derp%v.tailscale.com", id), - STUN4: fmt.Sprintf("derp%v.tailscale.com:3478", id), - STUN6: fmt.Sprintf("derp%v-v6.tailscale.com:3478", id), - }) -} - -func (w *World) add(s *Server) { - if s.ID == 0 { - panic("ID required") - } - if _, dup := w.byID[s.ID]; dup { - panic("duplicate prod server") - } - if w.byID == nil { - w.byID = make(map[int]*Server) - } - w.byID[s.ID] = s - w.ids = append(w.ids, s.ID) - w.servers = append(w.servers, s) - if s.STUN4 != "" { - w.stun4 = append(w.stun4, s.STUN4) - if _, _, err := net.SplitHostPort(s.STUN4); err != nil { - panic("not a host:port: " + s.STUN4) - } - } - if s.STUN6 != "" { - w.stun6 = append(w.stun6, s.STUN6) - if _, _, err := net.SplitHostPort(s.STUN6); err != nil { - panic("not a host:port: " + s.STUN6) - } - } -} - -func init() { - addProd(1, "New York") - addProd(2, "San Francisco") - addProd(3, "Singapore") - addProd(4, "Frankfurt") - addProd(5, "Sydney") -} - -// Server is configuration for a DERP server. -type Server struct { - _ structs.Incomparable - - ID int - - // HostHTTPS is the HTTPS hostname. - HostHTTPS string - - // STUN4 is the host:port of the IPv4 STUN server on this DERP - // node. Required. - STUN4 string - - // STUN6 optionally provides the IPv6 host:port of the STUN - // server on the DERP node. - // It should be an IPv6-only address for now. (We currently make lazy - // assumptions that the server names are unique.) - STUN6 string - - // Geo is a human-readable geographic region name of this server. - Geo string -} - -func (s *Server) String() string { - if s == nil { - return "" - } - if s.Geo != "" { - return fmt.Sprintf("%v (%v)", s.HostHTTPS, s.Geo) +// Prod returns Tailscale's map of relay servers. +// +// This list is only used by cmd/tailscale's netcheck subcommand. In +// normal operation the Tailscale nodes get this sent to them from the +// control server. +// +// This list is subject to change and should not be relied on. 
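To make the helpers above concrete: derpRegion stamps Name, RegionID, and HostName onto each node, and the TrimSuffix call means the first ("a") node keeps the legacy per-region hostname. A small illustration of just that naming logic (the "b" node is hypothetical):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Node "a" in region 1: the "a" suffix is trimmed, so the first
	// node keeps the legacy per-region hostname.
	name := fmt.Sprintf("%d%s", 1, "a")
	fmt.Println(name, fmt.Sprintf("derp%s.tailscale.com", strings.TrimSuffix(name, "a")))
	// 1a derp1.tailscale.com

	// A hypothetical second node "b" would keep its suffix:
	name = fmt.Sprintf("%d%s", 1, "b")
	fmt.Println(name, fmt.Sprintf("derp%s.tailscale.com", strings.TrimSuffix(name, "a")))
	// 1b derp1b.tailscale.com
}
```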
+func Prod() *tailcfg.DERPMap { + return &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: derpRegion(1, "nyc", + derpNode("a", "159.89.225.99", "2604:a880:400:d1::828:b001"), + ), + 2: derpRegion(2, "sfo", + derpNode("a", "167.172.206.31", "2604:a880:2:d1::c5:7001"), + ), + 3: derpRegion(3, "sin", + derpNode("a", "68.183.179.66", "2400:6180:0:d1::67d:8001"), + ), + 4: derpRegion(4, "fra", + derpNode("a", "167.172.182.26", "2a03:b0c0:3:e0::36e:9001"), + ), + 5: derpRegion(5, "syd", + derpNode("a", "103.43.75.49", "2001:19f0:5801:10b7:5400:2ff:feaa:284c"), + ), + }, } - return s.HostHTTPS } diff --git a/go.mod b/go.mod index f3ab8dc99..6d57af576 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,6 @@ require ( golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e golang.org/x/sys v0.0.0-20200501052902-10377860bb8e golang.org/x/time v0.0.0-20191024005414-555d28b269f0 - gortc.io/stun v1.22.1 inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc rsc.io/goversion v1.2.0 ) diff --git a/go.sum b/go.sum index e091d13bd..1ccc12628 100644 --- a/go.sum +++ b/go.sum @@ -142,8 +142,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gortc.io/stun v1.22.1 h1:96mOdDATYRqhYB+TZdenWBg4CzL2Ye5kPyBXQ8KAB+8= -gortc.io/stun v1.22.1/go.mod h1:XD5lpONVyjvV3BgOyJFNo0iv6R2oZB4L+weMqxts+zg= inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc h1:We3b/z+7i9LV4Ls0yWve5vYIlnAPSPeqxKVgZseRDBs= inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww= rsc.io/goversion v1.2.0 h1:SPn+NLTiAG7w30IRK/DKp1BjvpWabYgxlLp/+kx5J8w= diff --git a/ipn/local.go b/ipn/local.go index 1a4e091f4..e608fc601 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -240,10 +240,8 @@ func (b *LocalBackend) Start(opts Options) error { b.notify = opts.Notify b.netMapCache = nil persist := b.prefs.Persist - wantDERP := !b.prefs.DisableDERP b.mu.Unlock() - b.e.SetDERPEnabled(wantDERP) b.updateFilter(nil) var err error @@ -307,11 +305,17 @@ func (b *LocalBackend) Start(opts Options) error { b.logf("netmap diff:\n%v", diff) } } + disableDERP := b.prefs != nil && b.prefs.DisableDERP b.netMapCache = newSt.NetMap b.mu.Unlock() b.send(Notify{NetMap: newSt.NetMap}) b.updateFilter(newSt.NetMap) + if disableDERP { + b.e.SetDERPMap(nil) + } else { + b.e.SetDERPMap(newSt.NetMap.DERPMap) + } } if newSt.URL != "" { b.logf("Received auth URL: %.20v...", newSt.URL) diff --git a/netcheck/netcheck.go b/netcheck/netcheck.go index bbd1f163b..f74bae5ec 100644 --- a/netcheck/netcheck.go +++ b/netcheck/netcheck.go @@ -20,26 +20,28 @@ import ( "time" "github.com/tcnksm/go-httpstat" - "golang.org/x/sync/errgroup" - "tailscale.com/derp/derpmap" + "inet.af/netaddr" "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" "tailscale.com/stun" - "tailscale.com/stunner" + "tailscale.com/syncs" + "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/opt" ) type Report struct { - UDP bool // UDP works - IPv6 bool // IPv6 works - MappingVariesByDestIP opt.Bool // for IPv4 - HairPinning opt.Bool // for IPv4 - PreferredDERP int // or 0 for unknown - DERPLatency map[string]time.Duration // keyed by STUN host:port + UDP bool // UDP works + IPv6 bool // IPv6 works + MappingVariesByDestIP opt.Bool // for IPv4 + HairPinning opt.Bool // for IPv4 + PreferredDERP 
int // or 0 for unknown + RegionLatency map[int]time.Duration // keyed by DERP Region ID + RegionV4Latency map[int]time.Duration // keyed by DERP Region ID + RegionV6Latency map[int]time.Duration // keyed by DERP Region ID GlobalV4 string // ip:port of global IPv4 - GlobalV6 string // [ip]:port of global IPv6 // TODO + GlobalV6 string // [ip]:port of global IPv6 // TODO: update Clone when adding new fields } @@ -49,40 +51,50 @@ func (r *Report) Clone() *Report { return nil } r2 := *r - if r2.DERPLatency != nil { - r2.DERPLatency = map[string]time.Duration{} - for k, v := range r.DERPLatency { - r2.DERPLatency[k] = v - } - } + r2.RegionLatency = cloneDurationMap(r2.RegionLatency) + r2.RegionV4Latency = cloneDurationMap(r2.RegionV4Latency) + r2.RegionV6Latency = cloneDurationMap(r2.RegionV6Latency) return &r2 } +func cloneDurationMap(m map[int]time.Duration) map[int]time.Duration { + if m == nil { + return nil + } + m2 := make(map[int]time.Duration, len(m)) + for k, v := range m { + m2[k] = v + } + return m2 +} + // Client generates a netcheck Report. type Client struct { - // DERP is the DERP world to use. - DERP *derpmap.World - // DNSCache optionally specifies a DNSCache to use. // If nil, a DNS cache is not used. DNSCache *dnscache.Resolver // Logf optionally specifies where to log to. + // If nil, log.Printf is used. Logf logger.Logf // TimeNow, if non-nil, is used instead of time.Now. TimeNow func() time.Time + // GetSTUNConn4 optionally provides a func to return the + // connection to use for sending & receiving IPv4 packets. If + // nil, an ephemeral one is created as needed. GetSTUNConn4 func() STUNConn + + // GetSTUNConn6 is like GetSTUNConn4, but for IPv6. GetSTUNConn6 func() STUNConn - mu sync.Mutex // guards following - prev map[time.Time]*Report // some previous reports - last *Report // most recent report - s4 *stunner.Stunner - s6 *stunner.Stunner - hairTX stun.TxID - gotHairSTUN chan *net.UDPAddr // non-nil if we're in GetReport + mu sync.Mutex // guards following + nextFull bool // do a full region scan, even if last != nil + prev map[time.Time]*Report // some previous reports + last *Report // most recent report + lastFull time.Time // time of last full (non-incremental) report + curState *reportState // non-nil if we're in a call to GetReport } // STUNConn is the interface required by the netcheck Client when @@ -102,16 +114,14 @@ func (c *Client) logf(format string, a ...interface{}) { // handleHairSTUN reports whether pkt (from src) was our magic hairpin // probe packet that we sent to ourselves. -func (c *Client) handleHairSTUN(pkt []byte, src *net.UDPAddr) bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.handleHairSTUNLocked(pkt, src) -} - func (c *Client) handleHairSTUNLocked(pkt []byte, src *net.UDPAddr) bool { - if tx, err := stun.ParseBindingRequest(pkt); err == nil && tx == c.hairTX { + rs := c.curState + if rs == nil { + return false + } + if tx, err := stun.ParseBindingRequest(pkt); err == nil && tx == rs.hairTX { select { - case c.gotHairSTUN <- src: + case rs.gotHairSTUN <- src: default: } return true @@ -119,381 +129,570 @@ return false }
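For callers that own the UDP socket, as magicsock does later in this patch, GetSTUNConn4 and ReceiveSTUNPacket together form the contract: supply the conn, and deliver inbound STUN packets yourself, since netcheck only reads from sockets it opened itself. A hedged sketch; runNetcheckOverConn is a hypothetical helper:

```go
package example

import (
	"context"
	"log"
	"net"

	"tailscale.com/netcheck"
	"tailscale.com/stun"
	"tailscale.com/tailcfg"
)

func runNetcheckOverConn(ctx context.Context, conn *net.UDPConn, dm *tailcfg.DERPMap) (*netcheck.Report, error) {
	c := &netcheck.Client{
		Logf:         log.Printf,
		GetSTUNConn4: func() netcheck.STUNConn { return conn },
	}
	// Because we supplied the conn, we must also route inbound
	// STUN packets back into the client ourselves.
	go func() {
		var buf [64 << 10]byte
		for {
			n, addr, err := conn.ReadFromUDP(buf[:])
			if err != nil {
				return
			}
			if stun.Is(buf[:n]) {
				c.ReceiveSTUNPacket(buf[:n], addr)
			}
		}
	}()
	return c.GetReport(ctx, dm)
}
```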
+// MakeNextReportFull forces the next GetReport call to be a full +// (non-incremental) probe of all DERP regions. +func (c *Client) MakeNextReportFull() { + c.mu.Lock() + c.nextFull = true + c.mu.Unlock() +} + func (c *Client) ReceiveSTUNPacket(pkt []byte, src *net.UDPAddr) { if src == nil || src.IP == nil { panic("bogus src") } c.mu.Lock() - if c.handleHairSTUNLocked(pkt, src) { c.mu.Unlock() return } + rs := c.curState + c.mu.Unlock() - var st *stunner.Stunner - if src.IP.To4() != nil { - st = c.s4 - } else { - st = c.s6 + if rs == nil { + return } - c.mu.Unlock() + tx, addr, port, err := stun.ParseResponse(pkt) + if err != nil { + if _, err := stun.ParseBindingRequest(pkt); err == nil { + // This was probably our own netcheck hairpin + // check probe coming in late. Ignore. + return + } + c.logf("netcheck: received unexpected STUN message response from %v: %v", src, err) + return + } - if st != nil { - st.Receive(pkt, src) + rs.mu.Lock() + onDone, ok := rs.inFlight[tx] + if ok { + delete(rs.inFlight, tx) + } + rs.mu.Unlock() + if ok { + if ipp, ok := netaddr.FromStdAddr(addr, int(port), ""); ok { + onDone(ipp) + } } } -// pickSubset selects a subset of IPv4 and IPv6 STUN server addresses -// to hit based on history. +// probeProto is the protocol used to time a node's latency. +type probeProto uint8 + +const ( + probeIPv4 probeProto = iota // STUN IPv4 + probeIPv6 // STUN IPv6 + probeHTTPS // HTTPS +) + +type probe struct { + // delay is when the probe is started, relative to the time + // that GetReport is called. One probe in each probePlan + // should have a delay of 0. Non-zero values are for retries + // on UDP loss or timeout. + delay time.Duration + + // node is the name of the node. DERP node names are globally + // unique so there's no region ID. + node string + + // proto is how the node should be probed. + proto probeProto + + // wait is how long to wait until the probe is considered failed. + // 0 means to use a default value. + wait time.Duration +} + +// probePlan is a set of node probes to run. +// The map key is a descriptive name, only used for tests. // -// maxTries is the max number of tries per server. +// The values are logically an unordered set of tests to run concurrently. +// In practice there's some order to them based on their delay fields, +// but multiple probes can have the same delay time or be running concurrently +// both within and between sets. // -// The caller owns the returned values. -func (c *Client) pickSubset() (stuns4, stuns6 []string, maxTries map[string]int, err error) { - c.mu.Lock() - defer c.mu.Unlock() +// A set of probes is done once either one of the probes completes, or +// the next probe to run wouldn't yield any new information not +// already discovered by any previous probe in any set. +type probePlan map[string][]probe + +// sortRegions returns the regions of dm sorted +// from fastest to slowest (based on the 'last' report), +// ending in regions that have no data. +func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) { + prev = make([]*tailcfg.DERPRegion, 0, len(dm.Regions)) + for _, reg := range dm.Regions { + prev = append(prev, reg) + } + sort.Slice(prev, func(i, j int) bool { + da, db := last.RegionLatency[prev[i].RegionID], last.RegionLatency[prev[j].RegionID] + if db == 0 && da != 0 { + // Non-zero sorts before zero. + return true + } + if da == 0 { + // Zero can't sort before anything else. + return false + } + return da < db + }) + return prev +}
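The inFlight map used by ReceiveSTUNPacket above is a transaction-ID demultiplexer: each probe registers a callback under its STUN TxID and the callback fires at most once, outside the lock. A standalone toy model of the pattern (types and names here are illustrative, not the netcheck ones):

```go
package main

import (
	"fmt"
	"sync"
)

type txID [12]byte

// demux routes each response to the callback registered for its
// transaction ID, at most once.
type demux struct {
	mu       sync.Mutex
	inFlight map[txID]func(addr string)
}

func (d *demux) expect(tx txID, onDone func(addr string)) {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.inFlight == nil {
		d.inFlight = make(map[txID]func(addr string))
	}
	d.inFlight[tx] = onDone
}

func (d *demux) deliver(tx txID, addr string) {
	d.mu.Lock()
	onDone, ok := d.inFlight[tx]
	if ok {
		delete(d.inFlight, tx)
	}
	d.mu.Unlock()
	if ok {
		onDone(addr) // invoked without the lock held, as in ReceiveSTUNPacket
	}
}

func main() {
	var d demux
	tx := txID{1}
	d.expect(tx, func(addr string) { fmt.Println("mapped address:", addr) })
	d.deliver(tx, "203.0.113.7:41641")
}
```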
+// numIncrementalRegions is the number of fastest regions to +// periodically re-query during incremental netcheck reports. (During +// a full report, all regions are scanned.) +const numIncrementalRegions = 3 + +// makeProbePlan generates the probe plan for a DERPMap, given the most +// recent report and whether IPv6 is configured on an interface. +func makeProbePlan(dm *tailcfg.DERPMap, have6if bool, last *Report) (plan probePlan) { + if last == nil || len(last.RegionLatency) == 0 { + return makeProbePlanInitial(dm, have6if) + } + plan = make(probePlan) + had4 := len(last.RegionV4Latency) > 0 + had6 := len(last.RegionV6Latency) > 0 + hadBoth := have6if && had4 && had6 + for ri, reg := range sortRegions(dm, last) { + if ri == numIncrementalRegions { + break + } + var p4, p6 []probe + do4 := true + do6 := have6if + + // By default, each region only gets one STUN packet sent, + // except the fastest two from the previous round. + tries := 1 + isFastestTwo := ri < 2 + + if isFastestTwo { + tries = 2 + } else if hadBoth { + // For dual stack machines, make the 3rd & slower regions + // alternate between IPv4 and IPv6 probes. + if ri%2 == 0 { + do4, do6 = true, false + } else { + do4, do6 = false, true + } + } + if !isFastestTwo && !had6 { + do6 = false + } - const defaultMaxTries = 2 - maxTries = map[string]int{} + for try := 0; try < tries; try++ { + if len(reg.Nodes) == 0 { + // Shouldn't be possible. + continue + } + if try != 0 && !had6 { + do6 = false + } + n := reg.Nodes[try%len(reg.Nodes)] + prevLatency := last.RegionLatency[reg.RegionID] * 120 / 100 + if prevLatency == 0 { + prevLatency = 200 * time.Millisecond + } + delay := time.Duration(try) * prevLatency + if do4 { + p4 = append(p4, probe{delay: delay, node: n.Name, proto: probeIPv4}) + } + if do6 { + p6 = append(p6, probe{delay: delay, node: n.Name, proto: probeIPv6}) + } + } + if len(p4) > 0 { + plan[fmt.Sprintf("region-%d-v4", reg.RegionID)] = p4 + } + if len(p6) > 0 { + plan[fmt.Sprintf("region-%d-v6", reg.RegionID)] = p6 + } + } + return plan +} - var prev4, prev6 []string // sorted fastest to slowest - if c.last != nil { - condAppend := func(dst []string, server string) []string { - if server != "" && c.last.DERPLatency[server] != 0 { - return append(dst, server) +func makeProbePlanInitial(dm *tailcfg.DERPMap, have6if bool) (plan probePlan) { + plan = make(probePlan) + + // initialSTUNTimeout is only 100ms because some extra retransmits + // when starting up is tolerable. + const initialSTUNTimeout = 100 * time.Millisecond + + for _, reg := range dm.Regions { + var p4 []probe + var p6 []probe + for try := 0; try < 3; try++ { + n := reg.Nodes[try%len(reg.Nodes)] + delay := time.Duration(try) * initialSTUNTimeout + if nodeMight4(n) { + p4 = append(p4, probe{delay: delay, node: n.Name, proto: probeIPv4}) + } + if have6if && nodeMight6(n) { + p6 = append(p6, probe{delay: delay, node: n.Name, proto: probeIPv6}) } - return dst } - c.DERP.ForeachServer(func(s *derpmap.Server) { - prev4 = condAppend(prev4, s.STUN4) - prev6 = condAppend(prev6, s.STUN6) - }) - sort.Slice(prev4, func(i, j int) bool { return c.last.DERPLatency[prev4[i]] < c.last.DERPLatency[prev4[j]] }) - sort.Slice(prev6, func(i, j int) bool { return c.last.DERPLatency[prev6[i]] < c.last.DERPLatency[prev6[j]] }) + if len(p4) > 0 { + plan[fmt.Sprintf("region-%d-v4", reg.RegionID)] = p4 + } + if len(p6) > 0 { + plan[fmt.Sprintf("region-%d-v6", reg.RegionID)] = p6 + } + } + return plan +}
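The retry pacing above deserves a worked example: incremental retries are spaced at 120% of the region's last observed latency, with a 200ms default when unknown, so a region last seen at 50ms gets its second probe at +60ms. A tiny illustration with made-up values:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	lastLatency := 50 * time.Millisecond // region's latency from the previous report
	prevLatency := lastLatency * 120 / 100
	if prevLatency == 0 {
		prevLatency = 200 * time.Millisecond // no history: fall back to the default
	}
	for try := 0; try < 2; try++ {
		fmt.Printf("try %d starts at +%v\n", try, time.Duration(try)*prevLatency)
	}
	// try 0 starts at +0s
	// try 1 starts at +60ms
}
```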
+// nodeMight6 reports whether n might reply to STUN over IPv6 based on +// its config alone, without DNS lookups. It returns false only if +// IPv6 is explicitly disabled (the IPv6 field is set to something +// that isn't an IPv6 address, such as "none"). +func nodeMight6(n *tailcfg.DERPNode) bool { + if n.IPv6 == "" { + return true + } + ip, _ := netaddr.ParseIP(n.IPv6) + return ip.Is6() +} + +// nodeMight4 reports whether n might reply to STUN over IPv4 based on +// its config alone, without DNS lookups. It returns false only if +// IPv4 is explicitly disabled. +func nodeMight4(n *tailcfg.DERPNode) bool { + if n.IPv4 == "" { + return true } + ip, _ := netaddr.ParseIP(n.IPv4) + return ip.Is4() +} - c.DERP.ForeachServer(func(s *derpmap.Server) { - if s.STUN4 == "" { +// readPackets reads STUN packets from pc until there's an error or ctx is done. +// In either case, it closes pc. +func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + case <-done: + } + pc.Close() + }() + + var buf [64 << 10]byte + for { + n, addr, err := pc.ReadFrom(buf[:]) + if err != nil { + if ctx.Err() != nil { + return + } + c.logf("ReadFrom: %v", err) return } - // STUN against all DERP's IPv4 endpoints, but - // if the previous report had results from - // more than 2 servers, only do 1 try against - // all but the first two. - stuns4 = append(stuns4, s.STUN4) - tries := defaultMaxTries - if len(prev4) > 2 && !stringsContains(prev4[:2], s.STUN4) { - tries = 1 - } - maxTries[s.STUN4] = tries - if s.STUN6 != "" && tries == defaultMaxTries { - // For IPv6, we mostly care whether the user has IPv6 at all, - // so we don't need to send to all servers. The IPv4 timing - // information is enough for now. (We don't yet support IPv6-only) - // So only add the two fastest ones, or all if this is a fresh one. - stuns6 = append(stuns6, s.STUN6) - maxTries[s.STUN6] = 1 + ua, ok := addr.(*net.UDPAddr) + if !ok { + c.logf("ReadFrom: unexpected addr %T", addr) + continue } - }) + pkt := buf[:n] + if !stun.Is(pkt) { + continue + } + c.ReceiveSTUNPacket(pkt, ua) + } +} - if len(stuns4) == 0 { - // TODO: make this work? if we ever need it - // to. Requirement for self-hosted Tailscale might be - // to run a DERP+STUN server co-resident with the - // Control server. - return nil, nil, nil, errors.New("netcheck: GetReport: no STUN servers, no Report") +// reportState holds the state for a single invocation of Client.GetReport. +type reportState struct { + c *Client + hairTX stun.TxID + gotHairSTUN chan *net.UDPAddr + hairTimeout chan struct{} // closed on timeout + pc4 STUNConn + pc6 STUNConn + pc4Hair net.PacketConn + + mu sync.Mutex + sentHairCheck bool + report *Report // to be returned by GetReport + inFlight map[stun.TxID]func(netaddr.IPPort) // called without c.mu held + gotEP4 string +} + +func (rs *reportState) anyUDP() bool { + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.report.UDP +} + +func (rs *reportState) haveRegionLatency(regionID int) bool { + rs.mu.Lock() + defer rs.mu.Unlock() + _, ok := rs.report.RegionLatency[regionID] + return ok +}
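The hairTX, gotHairSTUN, and hairTimeout fields above drive hairpin detection: a STUN request is sent from the side socket (pc4Hair) to our own discovered public endpoint, and HairPinning is true only if the NAT loops the packet back to the main socket before a 500ms timer fires. A toy model of just the signal-versus-timeout resolution used by waitHairCheck below:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	gotHairSTUN := make(chan struct{}, 1)
	hairTimeout := make(chan struct{})
	time.AfterFunc(500*time.Millisecond, func() { close(hairTimeout) })

	// Simulate the looped-back probe arriving on the main socket.
	// (Remove this send to see the timeout path instead.)
	gotHairSTUN <- struct{}{}

	select {
	case <-gotHairSTUN:
		fmt.Println("HairPinning = true")
	case <-hairTimeout:
		fmt.Println("HairPinning = false")
	}
}
```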
+// probeWouldHelp reports whether executing the given probe would +// yield any new information. +// The given node is provided just because the sole caller already has it +// and it saves a lookup. +func (rs *reportState) probeWouldHelp(probe probe, node *tailcfg.DERPNode) bool { + rs.mu.Lock() + defer rs.mu.Unlock() + + // If the probe is for a region we don't yet know about, that + // would help. + if _, ok := rs.report.RegionLatency[node.RegionID]; !ok { + return true + } + + // If the probe is for IPv6 and we don't yet have an IPv6 + // report, that would help. + if probe.proto == probeIPv6 && len(rs.report.RegionV6Latency) == 0 { + return true + } + + // For IPv4, we need at least two IPv4 results overall to + // determine whether we're behind a NAT that shows us as + // different source IPs and/or ports depending on who we're + // talking to. If we don't have two results yet + // (MappingVariesByDestIP is blank), then another IPv4 probe + // would be good. + if probe.proto == probeIPv4 && rs.report.MappingVariesByDestIP == "" { + return true + } + + // Otherwise not interesting. + return false +} + +func (rs *reportState) startHairCheckLocked(dst netaddr.IPPort) { + if rs.sentHairCheck { + return + } + rs.sentHairCheck = true + rs.pc4Hair.WriteTo(stun.Request(rs.hairTX), dst.UDPAddr()) + time.AfterFunc(500*time.Millisecond, func() { close(rs.hairTimeout) }) +} + +func (rs *reportState) waitHairCheck(ctx context.Context) { + rs.mu.Lock() + defer rs.mu.Unlock() + if !rs.sentHairCheck { + return + } + ret := rs.report + + select { + case <-rs.gotHairSTUN: + ret.HairPinning.Set(true) + case <-rs.hairTimeout: + ret.HairPinning.Set(false) + default: + select { + case <-rs.gotHairSTUN: + ret.HairPinning.Set(true) + case <-rs.hairTimeout: + ret.HairPinning.Set(false) + case <-ctx.Done(): + } + } +} + +// addNodeLatency updates rs to note that node's latency is d. If ipp +// is non-zero (for all but HTTPS replies), it's recorded as our UDP +// IP:port. +func (rs *reportState) addNodeLatency(node *tailcfg.DERPNode, ipp netaddr.IPPort, d time.Duration) { + var ipPortStr string + if ipp != (netaddr.IPPort{}) { + ipPortStr = net.JoinHostPort(ipp.IP.String(), fmt.Sprint(ipp.Port)) + } + + rs.mu.Lock() + defer rs.mu.Unlock() + ret := rs.report + + ret.UDP = true + updateLatency(&ret.RegionLatency, node.RegionID, d) + + switch { + case ipp.IP.Is6(): + updateLatency(&ret.RegionV6Latency, node.RegionID, d) + ret.IPv6 = true + ret.GlobalV6 = ipPortStr + // TODO: track MappingVariesByDestIP for IPv6 + // too? Would be sad if so, but who knows. + case ipp.IP.Is4(): + updateLatency(&ret.RegionV4Latency, node.RegionID, d) + if rs.gotEP4 == "" { + rs.gotEP4 = ipPortStr + ret.GlobalV4 = ipPortStr + rs.startHairCheckLocked(ipp) + } else { + if rs.gotEP4 != ipPortStr { + ret.MappingVariesByDestIP.Set(true) + } else if ret.MappingVariesByDestIP == "" { + ret.MappingVariesByDestIP.Set(false) + } + } } - sort.Strings(stuns4) - sort.Strings(stuns6) - return stuns4, stuns6, maxTries, nil } // GetReport gets a report. // // It may not be called concurrently with itself. -func (c *Client) GetReport(ctx context.Context) (*Report, error) { +func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, error) { // Mask user context with ours that we guarantee to cancel so // we can depend on it being closed in goroutines later.
// (User ctx might be context.Background, etc) ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - if c.DERP == nil { - return nil, errors.New("netcheck: GetReport: Client.DERP is nil") + if dm == nil { + return nil, errors.New("netcheck: GetReport: DERP map is nil") } c.mu.Lock() - if c.gotHairSTUN != nil { + if c.curState != nil { c.mu.Unlock() return nil, errors.New("invalid concurrent call to GetReport") } - hairTX := stun.NewTxID() // random payload - c.hairTX = hairTX - gotHairSTUN := make(chan *net.UDPAddr, 1) - c.gotHairSTUN = gotHairSTUN + rs := &reportState{ + c: c, + report: new(Report), + inFlight: map[stun.TxID]func(netaddr.IPPort){}, + hairTX: stun.NewTxID(), // random payload + gotHairSTUN: make(chan *net.UDPAddr, 1), + hairTimeout: make(chan struct{}), + } + c.curState = rs + last := c.last + now := c.timeNow() + if c.nextFull || now.Sub(c.lastFull) > 5*time.Minute { + last = nil // causes makeProbePlan below to do a full (initial) plan + c.nextFull = false + c.lastFull = now + } c.mu.Unlock() defer func() { c.mu.Lock() defer c.mu.Unlock() - c.s4 = nil - c.s6 = nil - c.gotHairSTUN = nil + c.curState = nil }() - stuns4, stuns6, maxTries, err := c.pickSubset() - if err != nil { - return nil, err - } - - closeOnCtx := func(c io.Closer) { - <-ctx.Done() - c.Close() - } - v6iface, err := interfaces.HaveIPv6GlobalAddress() if err != nil { c.logf("interfaces: %v", err) } // Create a UDP4 socket used for sending to our discovered IPv4 address. - pc4Hair, err := net.ListenPacket("udp4", ":0") + rs.pc4Hair, err = net.ListenPacket("udp4", ":0") if err != nil { c.logf("udp4: %v", err) return nil, err } - defer pc4Hair.Close() - hairTimeout := make(chan bool, 1) - startHairCheck := func(dstEP string) { - if dst, err := net.ResolveUDPAddr("udp4", dstEP); err == nil { - pc4Hair.WriteTo(stun.Request(hairTX), dst) - time.AfterFunc(500*time.Millisecond, func() { hairTimeout <- true }) - } - } - - var ( - mu sync.Mutex - ret = &Report{ - DERPLatency: map[string]time.Duration{}, - } - gotEP = map[string]string{} // server -> ipPort - gotEP4 string - ) - anyV6 := func() bool { - mu.Lock() - defer mu.Unlock() - return ret.IPv6 - } - anyV4 := func() bool { - mu.Lock() - defer mu.Unlock() - return gotEP4 != "" - } - add := func(server, ipPort string, d time.Duration) { - ua, err := net.ResolveUDPAddr("udp", ipPort) - if err != nil { - c.logf("[unexpected] STUN addr %q", ipPort) - return - } - isV6 := ua.IP.To4() == nil - - mu.Lock() - defer mu.Unlock() - ret.UDP = true - ret.DERPLatency[server] = d - if isV6 { - ret.IPv6 = true - ret.GlobalV6 = ipPort - // TODO: track MappingVariesByDestIP for IPv6 - // too? Would be sad if so, but who knows. 
- } else { - // IPv4 - if gotEP4 == "" { - gotEP4 = ipPort - ret.GlobalV4 = ipPort - startHairCheck(ipPort) - } else { - if gotEP4 != ipPort { - ret.MappingVariesByDestIP.Set(true) - } else if ret.MappingVariesByDestIP == "" { - ret.MappingVariesByDestIP.Set(false) - } - } - } - gotEP[server] = ipPort - } - - var pc4, pc6 STUNConn + defer rs.pc4Hair.Close() if f := c.GetSTUNConn4; f != nil { - pc4 = f() + rs.pc4 = f() } else { u4, err := net.ListenPacket("udp4", ":0") if err != nil { c.logf("udp4: %v", err) return nil, err } - pc4 = u4 - go closeOnCtx(u4) + rs.pc4 = u4 + go c.readPackets(ctx, u4) } if v6iface { if f := c.GetSTUNConn6; f != nil { - pc6 = f() + rs.pc6 = f() } else { u6, err := net.ListenPacket("udp6", ":0") if err != nil { c.logf("udp6: %v", err) } else { - pc6 = u6 - go closeOnCtx(u6) + rs.pc6 = u6 + go c.readPackets(ctx, u6) } } } - reader := func(s *stunner.Stunner, pc STUNConn) { - var buf [64 << 10]byte - for { - n, addr, err := pc.ReadFrom(buf[:]) - if err != nil { - if ctx.Err() != nil { - return - } - c.logf("ReadFrom: %v", err) - return - } - ua, ok := addr.(*net.UDPAddr) - if !ok { - c.logf("ReadFrom: unexpected addr %T", addr) - continue - } - if c.handleHairSTUN(buf[:n], ua) { - continue - } - s.Receive(buf[:n], ua) - } - - } - - var grp errgroup.Group - - s4 := &stunner.Stunner{ - Send: pc4.WriteTo, - Endpoint: add, - Servers: stuns4, - Logf: c.logf, - DNSCache: dnscache.Get(), - MaxTries: maxTries, - } - - c.mu.Lock() - c.s4 = s4 - c.mu.Unlock() + plan := makeProbePlan(dm, v6iface, last) - grp.Go(func() error { - err := s4.Run(ctx) - if errors.Is(err, context.DeadlineExceeded) { - if !anyV4() { - c.logf("netcheck: no IPv4 UDP STUN replies") + wg := syncs.NewWaitGroupChan() + wg.Add(len(plan)) + for _, probeSet := range plan { + setCtx, cancelSet := context.WithCancel(ctx) + go func(probeSet []probe) { + for _, probe := range probeSet { + go rs.runProbe(setCtx, dm, probe, cancelSet) } - return nil - } - return err - }) - if c.GetSTUNConn4 == nil { - go reader(s4, pc4) + <-setCtx.Done() + wg.Decr() + }(probeSet) } - if pc6 != nil && len(stuns6) > 0 { - s6 := &stunner.Stunner{ - Endpoint: add, - Send: pc6.WriteTo, - Servers: stuns6, - Logf: c.logf, - OnlyIPv6: true, - DNSCache: dnscache.Get(), - MaxTries: maxTries, - } - - c.mu.Lock() - c.s6 = s6 - c.mu.Unlock() - - grp.Go(func() error { - err := s6.Run(ctx) - if errors.Is(err, context.DeadlineExceeded) { - if !anyV6() { - // IPv6 seemed like it was configured, but actually failed. - // Just log and return a nil error. - c.logf("IPv6 seemed configured, but no UDP STUN replies") - } - return nil - } - // Otherwise must be some invalid use of Stunner. - return err // - }) - if c.GetSTUNConn6 == nil { - go reader(s6, pc6) - } + select { + case <-ctx.Done(): + case <-wg.DoneChan(): } - err = grp.Wait() - if err != nil { - return nil, err - } + rs.waitHairCheck(ctx) - mu.Lock() - // Check hairpinning. - if ret.MappingVariesByDestIP == "false" && gotEP4 != "" { - select { - case <-gotHairSTUN: - ret.HairPinning.Set(true) - case <-hairTimeout: - ret.HairPinning.Set(false) - } - } - mu.Unlock() - - // Try HTTPS latency check if UDP is blocked and all checkings failed - if !anyV4() { - c.logf("netcheck: UDP is blocked, try HTTPS") + // Try HTTPS latency check if all STUN probes failed due to UDP presumably being blocked. 
+ if !rs.anyUDP() { var wg sync.WaitGroup - for _, server := range stuns4 { - server := server - if _, ok := ret.DERPLatency[server]; ok { - continue + var need []*tailcfg.DERPRegion + for rid, reg := range dm.Regions { + if !rs.haveRegionLatency(rid) && regionHasDERPNode(reg) { + need = append(need, reg) } - - wg.Add(1) - go func() { + } + if len(need) > 0 { + wg.Add(len(need)) + c.logf("netcheck: UDP is blocked, trying HTTPS") + } + for _, reg := range need { + go func(reg *tailcfg.DERPRegion) { defer wg.Done() - if d, err := c.measureHTTPSLatency(server); err != nil { - c.logf("netcheck: measuring HTTPS latency of %v: %v", server, err) + if d, err := c.measureHTTPSLatency(reg); err != nil { + c.logf("netcheck: measuring HTTPS latency of %v (%d): %v", reg.RegionCode, reg.RegionID, err) } else { - mu.Lock() - ret.DERPLatency[server] = d - mu.Unlock() + rs.mu.Lock() + rs.report.RegionLatency[reg.RegionID] = d + rs.mu.Unlock() } - }() + }(reg) } wg.Wait() } - report := ret.Clone() + rs.mu.Lock() + report := rs.report.Clone() + rs.mu.Unlock() c.addReportHistoryAndSetPreferredDERP(report) - c.logConciseReport(report) + c.logConciseReport(report, dm) return report, nil } -func (c *Client) measureHTTPSLatency(server string) (time.Duration, error) { - host, _, err := net.SplitHostPort(server) - if err != nil { - return 0, err +// TODO: have caller pass in context +func (c *Client) measureHTTPSLatency(reg *tailcfg.DERPRegion) (time.Duration, error) { + if len(reg.Nodes) == 0 { + return 0, errors.New("no nodes") } + node := reg.Nodes[0] // TODO: use all nodes per region + host := node.HostName + // TODO: connect using provided IPv4/IPv6; use a Transport & set the dialer var result httpstat.Result hctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &result), 5*time.Second) @@ -522,7 +721,7 @@ func (c *Client) measureHTTPSLatency(server string) (time.Duration, error) { return result.ServerProcessing, nil } -func (c *Client) logConciseReport(r *Report) { +func (c *Client) logConciseReport(r *Report, dm *tailcfg.DERPMap) { buf := bytes.NewBuffer(make([]byte, 0, 256)) // empirically: 5 DERPs + IPv6 == ~233 bytes fmt.Fprintf(buf, "udp=%v", r.UDP) fmt.Fprintf(buf, " v6=%v", r.IPv6) @@ -537,21 +736,20 @@ func (c *Client) logConciseReport(r *Report) { fmt.Fprintf(buf, " derp=%v", r.PreferredDERP) if r.PreferredDERP != 0 { fmt.Fprintf(buf, " derpdist=") - for i, id := range c.DERP.IDs() { + for i, rid := range dm.RegionIDs() { if i != 0 { buf.WriteByte(',') } - s := c.DERP.ServerByID(id) needComma := false - if d := r.DERPLatency[s.STUN4]; d != 0 { - fmt.Fprintf(buf, "%dv4:%v", id, d.Round(time.Millisecond)) + if d := r.RegionV4Latency[rid]; d != 0 { + fmt.Fprintf(buf, "%dv4:%v", rid, d.Round(time.Millisecond)) needComma = true } - if d := r.DERPLatency[s.STUN6]; d != 0 { + if d := r.RegionV6Latency[rid]; d != 0 { if needComma { buf.WriteByte(',') } - fmt.Fprintf(buf, "%dv6:%v", id, d.Round(time.Millisecond)) + fmt.Fprintf(buf, "%dv6:%v", rid, d.Round(time.Millisecond)) } } }
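Piecing together the format strings above, a successful IPv4-only report through region 1 would log roughly as follows; the values are invented for illustration:

```go
package main

import (
	"bytes"
	"fmt"
	"time"
)

func main() {
	// Re-derive the concise log line using the same format verbs
	// as logConciseReport above, with sample values.
	buf := new(bytes.Buffer)
	fmt.Fprintf(buf, "udp=%v", true)
	fmt.Fprintf(buf, " v6=%v", false)
	fmt.Fprintf(buf, " derp=%v", 1)
	fmt.Fprintf(buf, " derpdist=")
	fmt.Fprintf(buf, "%dv4:%v", 1, (12 * time.Millisecond).Round(time.Millisecond))
	fmt.Println(buf.String())
	// udp=true v6=false derp=1 derpdist=1v4:12ms
}
```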
@@ -581,15 +779,15 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { const maxAge = 5 * time.Minute - // STUN host:port => its best recent latency in last maxAge - bestRecent := map[string]time.Duration{} + // region ID => its best recent latency in last maxAge + bestRecent := map[int]time.Duration{} for t, pr := range c.prev { if now.Sub(t) > maxAge { delete(c.prev, t) continue } - for hp, d := range pr.DERPLatency { + for hp, d := range pr.RegionLatency { if bd, ok := bestRecent[hp]; !ok || d < bd { bestRecent[hp] = d } } } @@ -599,18 +797,133 @@ // Then, pick which currently-alive DERP server from the // current report has the best latency over the past maxAge. var bestAny time.Duration - for hp := range r.DERPLatency { + for hp := range r.RegionLatency { best := bestRecent[hp] if r.PreferredDERP == 0 || best < bestAny { bestAny = best - r.PreferredDERP = c.DERP.NodeIDOfSTUNServer(hp) + r.PreferredDERP = hp + } + } +} + +func updateLatency(mp *map[int]time.Duration, regionID int, d time.Duration) { + if *mp == nil { + *mp = make(map[int]time.Duration) + } + m := *mp + if prev, ok := m[regionID]; !ok || d < prev { + m[regionID] = d + } +} + +func namedNode(dm *tailcfg.DERPMap, nodeName string) *tailcfg.DERPNode { + if dm == nil { + return nil + } + for _, r := range dm.Regions { + for _, n := range r.Nodes { + if n.Name == nodeName { + return n + } + } + } + return nil +} + +func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe probe, cancelSet func()) { + c := rs.c + node := namedNode(dm, probe.node) + if node == nil { + c.logf("netcheck.runProbe: named node %q not found", probe.node) + return + } + + if probe.delay > 0 { + delayTimer := time.NewTimer(probe.delay) + select { + case <-delayTimer.C: + case <-ctx.Done(): + delayTimer.Stop() + return + } + } + + if !rs.probeWouldHelp(probe, node) { + cancelSet() + return + } + + addr := c.nodeAddr(ctx, node, probe.proto) + if addr == nil { + return + } + + txID := stun.NewTxID() + req := stun.Request(txID) + + sent := time.Now() // after DNS lookup above + + rs.mu.Lock() + rs.inFlight[txID] = func(ipp netaddr.IPPort) { + rs.addNodeLatency(node, ipp, time.Since(sent)) + cancelSet() // abort other nodes in this set + } + rs.mu.Unlock() + + switch probe.proto { + case probeIPv4: + rs.pc4.WriteTo(req, addr) + case probeIPv6: + rs.pc6.WriteTo(req, addr) + default: + panic("bad probe proto " + fmt.Sprint(probe.proto)) + } +} + +// nodeAddr returns the UDP address to send a STUN probe to for node n, +// using the given probe protocol (probeIPv4 or probeIPv6). +// If it returns nil, the node is skipped. +func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeProto) *net.UDPAddr { + port := n.STUNPort + if port == 0 { + port = 3478 + } + if port < 0 || port > 1<<16-1 { + return nil + } + switch proto { + case probeIPv4: + if n.IPv4 != "" { + ip, _ := netaddr.ParseIP(n.IPv4) + if !ip.Is4() { + return nil + } + return netaddr.IPPort{ip, uint16(port)}.UDPAddr() + } + case probeIPv6: + if n.IPv6 != "" { + ip, _ := netaddr.ParseIP(n.IPv6) + if !ip.Is6() { + return nil + } + return netaddr.IPPort{ip, uint16(port)}.UDPAddr() + } + default: + return nil + } + + // TODO(bradfitz): add singleflight+dnscache here.
+ addrs, _ := net.DefaultResolver.LookupIPAddr(ctx, n.HostName) + for _, a := range addrs { + if (a.IP.To4() != nil) == (proto == probeIPv4) { + return &net.UDPAddr{IP: a.IP, Port: port} } } + return nil } -func stringsContains(ss []string, s string) bool { - for _, v := range ss { - if s == v { +func regionHasDERPNode(r *tailcfg.DERPRegion) bool { + for _, n := range r.Nodes { + if !n.STUNOnly { return true } } diff --git a/netcheck/netcheck_test.go b/netcheck/netcheck_test.go index 6e9894df9..58a4eef74 100644 --- a/netcheck/netcheck_test.go +++ b/netcheck/netcheck_test.go @@ -9,28 +9,34 @@ import ( "fmt" "net" "reflect" + "sort" + "strconv" + "strings" "testing" "time" - "tailscale.com/derp/derpmap" "tailscale.com/stun" "tailscale.com/stun/stuntest" + "tailscale.com/tailcfg" ) func TestHairpinSTUN(t *testing.T) { + tx := stun.NewTxID() c := &Client{ - hairTX: stun.NewTxID(), - gotHairSTUN: make(chan *net.UDPAddr, 1), + curState: &reportState{ + hairTX: tx, + gotHairSTUN: make(chan *net.UDPAddr, 1), + }, } - req := stun.Request(c.hairTX) + req := stun.Request(tx) if !stun.Is(req) { t.Fatal("expected STUN message") } - if !c.handleHairSTUN(req, nil) { + if !c.handleHairSTUNLocked(req, nil) { t.Fatal("expected true") } select { - case <-c.gotHairSTUN: + case <-c.curState.gotHairSTUN: default: t.Fatal("expected value") } @@ -41,25 +47,24 @@ func TestBasic(t *testing.T) { defer cleanup() c := &Client{ - DERP: derpmap.NewTestWorld(stunAddr), Logf: t.Logf, } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - r, err := c.GetReport(ctx) + r, err := c.GetReport(ctx, stuntest.DERPMapOf(stunAddr.String())) if err != nil { t.Fatal(err) } if !r.UDP { t.Error("want UDP") } - if len(r.DERPLatency) != 1 { - t.Errorf("expected 1 key in DERPLatency; got %+v", r.DERPLatency) + if len(r.RegionLatency) != 1 { + t.Errorf("expected 1 key in DERPLatency; got %+v", r.RegionLatency) } - if _, ok := r.DERPLatency[stunAddr]; !ok { - t.Errorf("expected key %q in DERPLatency; got %+v", stunAddr, r.DERPLatency) + if _, ok := r.RegionLatency[1]; !ok { + t.Errorf("expected key 1 in DERPLatency; got %+v", r.RegionLatency) } if r.GlobalV4 == "" { t.Error("expected GlobalV4 set") @@ -78,20 +83,20 @@ func TestWorksWhenUDPBlocked(t *testing.T) { stunAddr := blackhole.LocalAddr().String() + dm := stuntest.DERPMapOf(stunAddr) + dm.Regions[1].Nodes[0].STUNOnly = true + c := &Client{ - DERP: derpmap.NewTestWorld(stunAddr), Logf: t.Logf, } ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() - r, err := c.GetReport(ctx) + r, err := c.GetReport(ctx, dm) if err != nil { t.Fatal(err) } - want := &Report{ - DERPLatency: map[string]time.Duration{}, - } + want := new(Report) if !reflect.DeepEqual(r, want) { t.Errorf("mismatch\n got: %+v\nwant: %+v\n", r, want) @@ -99,30 +104,24 @@ func TestWorksWhenUDPBlocked(t *testing.T) { } func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { - derps := derpmap.NewTestWorldWith( - &derpmap.Server{ - ID: 1, - STUN4: "d1:1", - }, - &derpmap.Server{ - ID: 2, - STUN4: "d2:1", - }, - &derpmap.Server{ - ID: 3, - STUN4: "d3:1", - }, - ) // report returns a *Report from (DERP host, time.Duration)+ pairs. 
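TestBasic above drives GetReport with stuntest.DERPMapOf, which (as defined later in this patch) fabricates one STUN-only, single-node region per address; an illustrative peek at its shape:

```go
package main

import (
	"fmt"

	"tailscale.com/stun/stuntest"
)

func main() {
	dm := stuntest.DERPMapOf("127.0.0.1:3478")
	reg := dm.Regions[1]
	node := reg.Nodes[0]
	// One region (ID 1) holding one STUN-only node named "1a":
	fmt.Println(reg.RegionID, node.Name, node.IPv4, node.STUNPort, node.STUNOnly)
	// 1 1a 127.0.0.1 3478 true
}
```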
report := func(a ...interface{}) *Report { - r := &Report{DERPLatency: map[string]time.Duration{}} + r := &Report{RegionLatency: map[int]time.Duration{}} for i := 0; i < len(a); i += 2 { - k := a[i].(string) + ":1" + s := a[i].(string) + if !strings.HasPrefix(s, "d") { + t.Fatalf("invalid derp server key %q", s) + } + regionID, err := strconv.Atoi(s[1:]) + if err != nil { + t.Fatalf("invalid derp server key %q", s) + } + switch v := a[i+1].(type) { case time.Duration: - r.DERPLatency[k] = v + r.RegionLatency[regionID] = v case int: - r.DERPLatency[k] = time.Second * time.Duration(v) + r.RegionLatency[regionID] = time.Second * time.Duration(v) default: panic(fmt.Sprintf("unexpected type %T", v)) } @@ -194,7 +193,6 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { t.Run(tt.name, func(t *testing.T) { fakeTime := time.Unix(123, 0) c := &Client{ - DERP: derps, TimeNow: func() time.Time { return fakeTime }, } for _, s := range tt.steps { @@ -212,81 +210,217 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { } } -func TestPickSubset(t *testing.T) { - derps := derpmap.NewTestWorldWith( - &derpmap.Server{ - ID: 1, - STUN4: "d1:4", - STUN6: "d1:6", +func TestMakeProbePlan(t *testing.T) { + // basicMap has 5 regions. each region has a number of nodes + // equal to the region number (1 has 1a, 2 has 2a and 2b, etc.) + basicMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{}, + } + for rid := 1; rid <= 5; rid++ { + var nodes []*tailcfg.DERPNode + for nid := 0; nid < rid; nid++ { + nodes = append(nodes, &tailcfg.DERPNode{ + Name: fmt.Sprintf("%d%c", rid, 'a'+rune(nid)), + RegionID: rid, + HostName: fmt.Sprintf("derp%d-%d", rid, nid), + IPv4: fmt.Sprintf("%d.0.0.%d", rid, nid), + IPv6: fmt.Sprintf("%d::%d", rid, nid), + }) + } + basicMap.Regions[rid] = &tailcfg.DERPRegion{ + RegionID: rid, + Nodes: nodes, + } + } + + const ms = time.Millisecond + p := func(name string, c rune, d ...time.Duration) probe { + var proto probeProto + switch c { + case 4: + proto = probeIPv4 + case 6: + proto = probeIPv6 + case 'h': + proto = probeHTTPS + } + pr := probe{node: name, proto: proto} + if len(d) == 1 { + pr.delay = d[0] + } else if len(d) > 1 { + panic("too many args") + } + return pr + } + tests := []struct { + name string + dm *tailcfg.DERPMap + have6if bool + last *Report + want probePlan + }{ + { + name: "initial_v6", + dm: basicMap, + have6if: true, + last: nil, // initial + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 100*ms), p("1a", 4, 200*ms)}, // all a + "region-1-v6": []probe{p("1a", 6), p("1a", 6, 100*ms), p("1a", 6, 200*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 100*ms), p("2a", 4, 200*ms)}, // a -> b -> a + "region-2-v6": []probe{p("2a", 6), p("2b", 6, 100*ms), p("2a", 6, 200*ms)}, + "region-3-v4": []probe{p("3a", 4), p("3b", 4, 100*ms), p("3c", 4, 200*ms)}, // a -> b -> c + "region-3-v6": []probe{p("3a", 6), p("3b", 6, 100*ms), p("3c", 6, 200*ms)}, + "region-4-v4": []probe{p("4a", 4), p("4b", 4, 100*ms), p("4c", 4, 200*ms)}, + "region-4-v6": []probe{p("4a", 6), p("4b", 6, 100*ms), p("4c", 6, 200*ms)}, + "region-5-v4": []probe{p("5a", 4), p("5b", 4, 100*ms), p("5c", 4, 200*ms)}, + "region-5-v6": []probe{p("5a", 6), p("5b", 6, 100*ms), p("5c", 6, 200*ms)}, + }, }, - &derpmap.Server{ - ID: 2, - STUN4: "d2:4", - STUN6: "d2:6", + { + name: "initial_no_v6", + dm: basicMap, + have6if: false, + last: nil, // initial + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 100*ms), p("1a", 4, 200*ms)}, // all a + "region-2-v4": 
[]probe{p("2a", 4), p("2b", 4, 100*ms), p("2a", 4, 200*ms)}, // a -> b -> a + "region-3-v4": []probe{p("3a", 4), p("3b", 4, 100*ms), p("3c", 4, 200*ms)}, // a -> b -> c + "region-4-v4": []probe{p("4a", 4), p("4b", 4, 100*ms), p("4c", 4, 200*ms)}, + "region-5-v4": []probe{p("5a", 4), p("5b", 4, 100*ms), p("5c", 4, 200*ms)}, + }, }, - &derpmap.Server{ - ID: 3, - STUN4: "d3:4", - STUN6: "d3:6", + { + name: "second_v4_no_6if", + dm: basicMap, + have6if: false, + last: &Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + // Pretend 5 is missing + }, + RegionV4Latency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + }, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-3-v4": []probe{p("3a", 4)}, + }, }, - ) - tests := []struct { - name string - last *Report - want4 []string - want6 []string - wantTries map[string]int - }{ { - name: "fresh", - last: nil, - want4: []string{"d1:4", "d2:4", "d3:4"}, - want6: []string{"d1:6", "d2:6", "d3:6"}, - wantTries: map[string]int{ - "d1:4": 2, - "d2:4": 2, - "d3:4": 2, - "d1:6": 1, - "d2:6": 1, - "d3:6": 1, + name: "second_v4_only_with_6if", + dm: basicMap, + have6if: true, + last: &Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + // Pretend 5 is missing + }, + RegionV4Latency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + }, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)}, + "region-1-v6": []probe{p("1a", 6)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-2-v6": []probe{p("2a", 6)}, + "region-3-v4": []probe{p("3a", 4)}, }, }, { - name: "1_and_3_closest", + name: "second_mixed", + dm: basicMap, + have6if: true, last: &Report{ - DERPLatency: map[string]time.Duration{ - "d1:4": 15 * time.Millisecond, - "d2:4": 300 * time.Millisecond, - "d3:4": 25 * time.Millisecond, + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + // Pretend 5 is missing + }, + RegionV4Latency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + }, + RegionV6Latency: map[int]time.Duration{ + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, }, }, - want4: []string{"d1:4", "d2:4", "d3:4"}, - want6: []string{"d1:6", "d3:6"}, - wantTries: map[string]int{ - "d1:4": 2, - "d3:4": 2, - "d2:4": 1, - "d1:6": 1, - "d3:6": 1, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)}, + "region-1-v6": []probe{p("1a", 6), p("1a", 6, 12*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)}, + "region-3-v4": []probe{p("3a", 4)}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Client{DERP: derps, last: tt.last} - got4, got6, gotTries, err := c.pickSubset() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got4, tt.want4) { - t.Errorf("stuns4 = %q; want %q", got4, tt.want4) - } - if !reflect.DeepEqual(got6, tt.want6) { - t.Errorf("stuns6 = %q; want %q", got6, tt.want6) - } - if !reflect.DeepEqual(gotTries, 
tt.wantTries) {
-				t.Errorf("tries = %v; want %v", gotTries, tt.wantTries)
+			got := makeProbePlan(tt.dm, tt.have6if, tt.last)
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want)
 			}
 		})
 	}
 }
+
+func (plan probePlan) String() string {
+	var sb strings.Builder
+	keys := []string{}
+	for k := range plan {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+
+	for _, key := range keys {
+		fmt.Fprintf(&sb, "[%s]", key)
+		pv := plan[key]
+		for _, p := range pv {
+			fmt.Fprintf(&sb, " %v", p)
+		}
+		sb.WriteByte('\n')
+	}
+	return sb.String()
+}
+
+func (p probe) String() string {
+	wait := ""
+	if p.wait > 0 {
+		wait = "+" + p.wait.String()
+	}
+	delay := ""
+	if p.delay > 0 {
+		delay = "@" + p.delay.String()
+	}
+	return fmt.Sprintf("%s-%s%s%s", p.node, p.proto, delay, wait)
+}
+
+func (p probeProto) String() string {
+	switch p {
+	case probeIPv4:
+		return "v4"
+	case probeIPv6:
+		return "v6"
+	case probeHTTPS:
+		return "https"
+	}
+	return "?"
+}
diff --git a/stun/stuntest/stuntest.go b/stun/stuntest/stuntest.go
index ba244fc37..b53db0e0a 100644
--- a/stun/stuntest/stuntest.go
+++ b/stun/stuntest/stuntest.go
@@ -6,12 +6,16 @@
 package stuntest
 
 import (
+	"fmt"
 	"net"
+	"strconv"
 	"strings"
 	"sync"
 	"testing"
 
+	"inet.af/netaddr"
 	"tailscale.com/stun"
+	"tailscale.com/tailcfg"
 )
 
 type stunStats struct {
@@ -20,7 +24,7 @@ type stunStats struct {
 	readIPv6 int
 }
 
-func Serve(t *testing.T) (addr string, cleanupFn func()) {
+func Serve(t *testing.T) (addr *net.UDPAddr, cleanupFn func()) {
 	t.Helper()
 
 	// TODO(crawshaw): use stats to test re-STUN logic
@@ -30,13 +34,13 @@ func Serve(t *testing.T) (addr string, cleanupFn func()) {
 	if err != nil {
 		t.Fatalf("failed to open STUN listener: %v", err)
 	}
-
-	stunAddr := pc.LocalAddr().String()
-	stunAddr = strings.Replace(stunAddr, "0.0.0.0:", "127.0.0.1:", 1)
-
+	addr = &net.UDPAddr{
+		IP:   net.ParseIP("127.0.0.1"),
+		Port: pc.LocalAddr().(*net.UDPAddr).Port,
+	}
 	doneCh := make(chan struct{})
 	go runSTUN(t, pc, &stats, doneCh)
-	return stunAddr, func() {
+	return addr, func() {
 		pc.Close()
 		<-doneCh
 	}
@@ -79,3 +83,47 @@ func runSTUN(t *testing.T, pc net.PacketConn, stats *stunStats, done chan<- stru
 		}
 	}
 }
+
+func DERPMapOf(stun ...string) *tailcfg.DERPMap {
+	m := &tailcfg.DERPMap{
+		Regions: map[int]*tailcfg.DERPRegion{},
+	}
+	for i, hostPortStr := range stun {
+		regionID := i + 1
+		host, portStr, err := net.SplitHostPort(hostPortStr)
+		if err != nil {
+			panic(fmt.Sprintf("bogus STUN hostport: %q", hostPortStr))
+		}
+		port, err := strconv.Atoi(portStr)
+		if err != nil {
+			panic(fmt.Sprintf("bogus port %q in %q", portStr, hostPortStr))
+		}
+		var ipv4, ipv6 string
+		ip, err := netaddr.ParseIP(host)
+		if err != nil {
+			panic(fmt.Sprintf("bogus non-IP STUN host %q in %q", host, hostPortStr))
+		}
+		if ip.Is4() {
+			ipv4 = host
+			ipv6 = "none"
+		}
+		if ip.Is6() {
+			ipv6 = host
+			ipv4 = "none"
+		}
+		node := &tailcfg.DERPNode{
+			Name:     fmt.Sprint(regionID) + "a",
+			RegionID: regionID,
+			HostName: fmt.Sprintf("d%d.invalid", regionID),
+			IPv4:     ipv4,
+			IPv6:     ipv6,
+			STUNPort: port,
+			STUNOnly: true,
+		}
+		m.Regions[regionID] = &tailcfg.DERPRegion{
+			RegionID: regionID,
+			Nodes:    []*tailcfg.DERPNode{node},
+		}
+	}
+	return m
+}
diff --git a/stunner/stunner.go b/stunner/stunner.go
deleted file mode 100644
index a99ad3c49..000000000
--- a/stunner/stunner.go
+++ /dev/null
@@ -1,310 +0,0 @@
-// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
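Note on the expected plans in the tests above: an initial report probes every region three times, at 0ms/100ms/200ms, rotating through the region's nodes; once a prior report exists, the slowest regions get fewer probes (or none), and a region's follow-up probe is scheduled at roughly 120% of its last measured latency (10ms becomes 12ms, 20ms becomes 24ms in the tables). A minimal sketch of that pacing rule as the test data implies it; the helper name is illustrative, not part of this patch:

	package main

	import (
		"fmt"
		"time"
	)

	// followUpDelay matches the 10ms->12ms and 20ms->24ms expectations in
	// the test tables above: the second probe of a region waits ~120% of
	// the previously observed round-trip time. (Inferred behavior; the
	// real constant lives in netcheck.go, which is not shown in this hunk.)
	func followUpDelay(lastRTT time.Duration) time.Duration {
		return lastRTT * 6 / 5
	}

	func main() {
		fmt.Println(followUpDelay(10*time.Millisecond), followUpDelay(20*time.Millisecond)) // 12ms 24ms
	}
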
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stunner - -import ( - "context" - "errors" - "fmt" - "math/rand" - "net" - "strconv" - "strings" - "sync" - "time" - - "tailscale.com/net/dnscache" - "tailscale.com/stun" - "tailscale.com/types/structs" -) - -// Stunner sends a STUN request to several servers and handles a response. -// -// It is designed to used on a connection owned by other code and so does -// not directly reference a net.Conn of any sort. Instead, the user should -// provide Send function to send packets, and call Receive when a new -// STUN response is received. -// -// In response, a Stunner will call Endpoint with any endpoints determined -// for the connection. (An endpoint may be reported multiple times if -// multiple servers are provided.) -type Stunner struct { - // Send sends a packet. - // It will typically be a PacketConn.WriteTo method value. - Send func([]byte, net.Addr) (int, error) // sends a packet - - // Endpoint is called whenever a STUN response is received. - // The server is the STUN server that replied, endpoint is the ip:port - // from the STUN response, and d is the duration that the STUN request - // took on the wire (not including DNS lookup time. - Endpoint func(server, endpoint string, d time.Duration) - - // onPacket is the internal version of Endpoint that does de-dup. - // It's set by Run. - onPacket func(server, endpoint string, d time.Duration) - - Servers []string // STUN servers to contact - - // DNSCache optionally specifies a DNSCache to use. - // If nil, a DNS cache is not used. - DNSCache *dnscache.Resolver - - // Logf optionally specifies a log function. If nil, logging is disabled. - Logf func(format string, args ...interface{}) - - // OnlyIPv6 controls whether IPv6 is exclusively used. - // If false, only IPv4 is used. There is currently no mixed mode. - OnlyIPv6 bool - - // MaxTries optionally provides a mapping from server name to the maximum - // number of tries that should be made for a given server. - // If nil or a server is not present in the map, the default is 1. - // Values less than 1 are ignored. - MaxTries map[string]int - - mu sync.Mutex - inFlight map[stun.TxID]request -} - -func (s *Stunner) addTX(tx stun.TxID, server string) { - s.mu.Lock() - defer s.mu.Unlock() - if _, dup := s.inFlight[tx]; dup { - panic("unexpected duplicate STUN TransactionID") - } - s.inFlight[tx] = request{sent: time.Now(), server: server} -} - -func (s *Stunner) removeTX(tx stun.TxID) (request, bool) { - s.mu.Lock() - defer s.mu.Unlock() - if s.inFlight == nil { - return request{}, false - } - r, ok := s.inFlight[tx] - if ok { - delete(s.inFlight, tx) - } else { - s.logf("stunner: got STUN packet for unknown TxID %x", tx) - } - return r, ok -} - -type request struct { - _ structs.Incomparable - sent time.Time - server string -} - -func (s *Stunner) logf(format string, args ...interface{}) { - if s.Logf != nil { - s.Logf(format, args...) - } -} - -// Receive delivers a STUN packet to the stunner. -func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { - if !stun.Is(p) { - s.logf("[unexpected] stunner: received non-STUN packet") - return - } - now := time.Now() - tx, addr, port, err := stun.ParseResponse(p) - if err != nil { - if _, err := stun.ParseBindingRequest(p); err == nil { - // This was probably our own netcheck hairpin - // check probe coming in late. Ignore. 
- return - } - s.logf("stunner: received unexpected STUN message response from %v: %v", fromAddr, err) - return - } - r, ok := s.removeTX(tx) - if !ok { - return - } - d := now.Sub(r.sent) - - host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port)) - s.onPacket(r.server, host, d) -} - -func (s *Stunner) resolver() *net.Resolver { - return net.DefaultResolver -} - -// cleanUpPostRun zeros out some fields, mostly for debugging (so -// things crash or race+fail if there's a sender still running.) -func (s *Stunner) cleanUpPostRun() { - s.mu.Lock() - s.inFlight = nil - s.mu.Unlock() -} - -// Run starts a Stunner and blocks until all servers either respond -// or are tried multiple times and timeout. -// It can not be called concurrently with itself. -func (s *Stunner) Run(ctx context.Context) error { - for _, server := range s.Servers { - if _, _, err := net.SplitHostPort(server); err != nil { - return fmt.Errorf("Stunner.Run: invalid server %q (in Server list %q)", server, s.Servers) - } - } - if len(s.Servers) == 0 { - return errors.New("stunner: no Servers") - } - - s.inFlight = make(map[stun.TxID]request) - defer s.cleanUpPostRun() - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - type sender struct { - ctx context.Context - cancel context.CancelFunc - } - var ( - needMu sync.Mutex - need = make(map[string]sender) // keyed by server; deleted when done - allDone = make(chan struct{}) // closed when need is empty - ) - s.onPacket = func(server, endpoint string, d time.Duration) { - needMu.Lock() - defer needMu.Unlock() - sender, ok := need[server] - if !ok { - return - } - sender.cancel() - delete(need, server) - s.Endpoint(server, endpoint, d) - if len(need) == 0 { - close(allDone) - } - } - - var wg sync.WaitGroup - for _, server := range s.Servers { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - need[server] = sender{ctx, cancel} - } - needMu.Lock() - for server, sender := range need { - wg.Add(1) - server, ctx := server, sender.ctx - go func() { - defer wg.Done() - s.sendPackets(ctx, server) - }() - } - needMu.Unlock() - var err error - select { - case <-ctx.Done(): - err = ctx.Err() - case <-allDone: - cancel() - } - wg.Wait() - - var missing []string - needMu.Lock() - for server := range need { - missing = append(missing, server) - } - needMu.Unlock() - - if len(missing) == 0 || err == nil { - return nil - } - return fmt.Errorf("got STUN error: %w; missing replies from: %v", err, strings.Join(missing, ", ")) -} - -func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, error) { - hostStr, portStr, err := net.SplitHostPort(server) - if err != nil { - return nil, err - } - addrPort, err := strconv.Atoi(portStr) - if err != nil { - return nil, fmt.Errorf("port: %v", err) - } - if addrPort == 0 { - addrPort = 3478 - } - addr := &net.UDPAddr{Port: addrPort} - - var ipAddrs []net.IPAddr - if s.DNSCache != nil { - ip, err := s.DNSCache.LookupIP(ctx, hostStr) - if err != nil { - return nil, err - } - ipAddrs = []net.IPAddr{{IP: ip}} - } else { - ipAddrs, err = s.resolver().LookupIPAddr(ctx, hostStr) - if err != nil { - return nil, fmt.Errorf("lookup ip addr (%q): %v", hostStr, err) - } - } - - for _, ipAddr := range ipAddrs { - ip4 := ipAddr.IP.To4() - if ip4 != nil { - if s.OnlyIPv6 { - continue - } - addr.IP = ip4 - break - } else if s.OnlyIPv6 { - addr.IP = ipAddr.IP - addr.Zone = ipAddr.Zone - } - } - if addr.IP == nil { - if s.OnlyIPv6 { - return nil, fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, 
ipAddrs) - } - return nil, fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) - } - return addr, nil -} - -// maxTriesForServer returns the maximum number of STUN queries that -// will be sent to server (for one call to Run). The default is 1. -func (s *Stunner) maxTriesForServer(server string) int { - if v, ok := s.MaxTries[server]; ok && v > 0 { - return v - } - return 1 -} - -func (s *Stunner) sendPackets(ctx context.Context, server string) error { - addr, err := s.serverAddr(ctx, server) - if err != nil { - return err - } - maxTries := s.maxTriesForServer(server) - for i := 0; i < maxTries; i++ { - txID := stun.NewTxID() - req := stun.Request(txID) - s.addTX(txID, server) - _, err = s.Send(req, addr) - if err != nil { - return fmt.Errorf("send: %v", err) - } - - select { - case <-ctx.Done(): - // Ignore error. The caller deals with handling contexts. - // We only use it to dermine when to stop spraying STUN packets. - return nil - case <-time.After(time.Millisecond * time.Duration(50+rand.Intn(200))): - } - } - return nil -} diff --git a/stunner/stunner_test.go b/stunner/stunner_test.go deleted file mode 100644 index a3555f1be..000000000 --- a/stunner/stunner_test.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stunner - -import ( - "context" - "errors" - "fmt" - "net" - "sort" - "testing" - "time" - - "gortc.io/stun" -) - -func TestStun(t *testing.T) { - conn1, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer conn1.Close() - conn2, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer conn2.Close() - stunServers := []string{ - conn1.LocalAddr().String(), conn2.LocalAddr().String(), - } - - epCh := make(chan string, 16) - - localConn, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - - s := &Stunner{ - Send: localConn.WriteTo, - Endpoint: func(server, ep string, d time.Duration) { epCh <- ep }, - Servers: stunServers, - MaxTries: map[string]int{ - stunServers[0]: 2, - stunServers[1]: 2, - }, - } - - stun1Err := make(chan error) - go func() { - stun1Err <- startSTUN(conn1, s.Receive) - }() - stun2Err := make(chan error) - go func() { - stun2Err <- startSTUNDrop1(conn2, s.Receive) - }() - - errCh := make(chan error) - go func() { - errCh <- s.Run(context.Background()) - }() - - var eps []string - select { - case ep := <-epCh: - eps = append(eps, ep) - case <-time.After(100 * time.Millisecond): - t.Fatal("missing first endpoint response") - } - select { - case ep := <-epCh: - eps = append(eps, ep) - case <-time.After(500 * time.Millisecond): - t.Fatal("missing second endpoint response") - } - sort.Strings(eps) - if want := "1.2.3.4:1234"; eps[0] != want { - t.Errorf("eps[0]=%q, want %q", eps[0], want) - } - if want := "4.5.6.7:4567"; eps[1] != want { - t.Errorf("eps[1]=%q, want %q", eps[1], want) - } - - if err := <-errCh; err != nil { - t.Fatal(err) - } -} - -func startSTUNDrop1(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error { - if _, _, err := conn.ReadFrom(make([]byte, 1024)); err != nil { - return fmt.Errorf("first stun server read failed: %v", err) - } - req := new(stun.Message) - res := new(stun.Message) - - p := make([]byte, 1024) - n, addr, err := conn.ReadFrom(p) - if err != nil { - return err - } - p = p[:n] - if !stun.IsMessage(p) { - return 
errors.New("not a STUN message")
-	}
-	if _, err := req.Write(p); err != nil {
-		return err
-	}
-	mappedAddr := &stun.XORMappedAddress{
-		IP:   net.ParseIP("1.2.3.4"),
-		Port: 1234,
-	}
-	software := stun.NewSoftware("endpointer")
-	err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
-	if err != nil {
-		return err
-	}
-	writeTo(res.Raw, addr.(*net.UDPAddr))
-	return nil
-}
-
-func startSTUN(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error {
-	req := new(stun.Message)
-	res := new(stun.Message)
-
-	p := make([]byte, 1024)
-	n, addr, err := conn.ReadFrom(p)
-	if err != nil {
-		return err
-	}
-	p = p[:n]
-	if !stun.IsMessage(p) {
-		return errors.New("not a STUN message")
-	}
-	if _, err := req.Write(p); err != nil {
-		return err
-	}
-	mappedAddr := &stun.XORMappedAddress{
-		IP:   net.ParseIP("4.5.6.7"),
-		Port: 4567,
-	}
-	software := stun.NewSoftware("endpointer")
-	err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
-	if err != nil {
-		return err
-	}
-	writeTo(res.Raw, addr.(*net.UDPAddr))
-	return nil
-}
-
-// TODO: test retry timeout (overwrite the retryDurations)
-// TODO: test canceling context passed to Run
-// TODO: test sending bad packets
diff --git a/tailcfg/derpmap.go b/tailcfg/derpmap.go
index c7553545c..69aa92157 100644
--- a/tailcfg/derpmap.go
+++ b/tailcfg/derpmap.go
@@ -4,6 +4,8 @@
 
 package tailcfg
 
+import "sort"
+
 // DERPMap describes the set of DERP packet relay servers that are available.
 type DERPMap struct {
 	// Regions is the set of geographic regions running DERP node(s).
@@ -14,6 +16,16 @@ type DERPMap struct {
 	Regions map[int]*DERPRegion
 }
 
+// RegionIDs returns the sorted region IDs.
+func (m *DERPMap) RegionIDs() []int {
+	ret := make([]int, 0, len(m.Regions))
+	for rid := range m.Regions {
+		ret = append(ret, rid)
+	}
+	sort.Ints(ret)
+	return ret
+}
+
 // DERPRegion is a geographic region running DERP relay node(s).
 //
 // Client nodes discover which region they're closest to, advertise
@@ -85,9 +97,29 @@ type DERPNode struct {
 	// IPv4 optionally forces an IPv4 address to use, instead of using DNS.
 	// If empty, A record(s) from DNS lookups of HostName are used.
+	// If the string is not an IPv4 address, IPv4 is not used; the
+	// conventional string to disable IPv4 (and not use DNS) is
+	// "none".
 	IPv4 string `json:",omitempty"`
 
 	// IPv6 optionally forces an IPv6 address to use, instead of using DNS.
 	// If empty, AAAA record(s) from DNS lookups of HostName are used.
+	// If the string is not an IPv6 address, IPv6 is not used; the
+	// conventional string to disable IPv6 (and not use DNS) is
+	// "none".
 	IPv6 string `json:",omitempty"`
+
+	// STUNPort optionally specifies the STUN port to use.
+	// Zero means 3478.
+	// To disable STUN on this node, use -1.
+	STUNPort int `json:",omitempty"`
+
+	// STUNOnly marks a node as only a STUN server and not a DERP
+	// server.
+	STUNOnly bool `json:",omitempty"`
+
+	// DERPTestPort is used in tests to override the port, instead
+	// of using the default port of 443. If non-zero, TLS
+	// verification is skipped.
+	DERPTestPort int `json:",omitempty"`
 }
diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go
index f1ddb1ee8..5afe05129 100644
--- a/tailcfg/tailcfg.go
+++ b/tailcfg/tailcfg.go
@@ -315,8 +315,9 @@ type NetInfo struct {
 	LinkType string // "wired", "wifi", "mobile" (LTE, 4G, 3G, etc)
 
 	// DERPLatency is the fastest recent time to reach various
-	// DERP STUN servers, in seconds. The map key is the DERP
-	// server's STUN host:port.
+	// DERP STUN servers, in seconds. The map key is of the form
+	// "regionID-v4" or "regionID-v6"; it was previously the DERP
+	// server's STUN host:port.
 	//
 	// This should only be updated rarely, or when there's a
 	// material change, as any change here also gets uploaded to
@@ -336,7 +337,7 @@ func (ni *NetInfo) String() string {
 }
 
 // BasicallyEqual reports whether ni and ni2 are basically equal, ignoring
-// changes in DERPLatency.
+// changes in DERP ServerLatency & RegionLatency.
 func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool {
 	if (ni == nil) != (ni2 == nil) {
 		return false
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index d9e3a4800..1b8cb0cc4 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -9,7 +9,6 @@ package magicsock
 import (
 	"bytes"
 	"context"
-	"crypto/tls"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -17,6 +16,7 @@ import (
 	"math/rand"
 	"net"
 	"os"
+	"reflect"
 	"sort"
 	"strconv"
 	"strings"
@@ -32,7 +32,6 @@ import (
 	"inet.af/netaddr"
 	"tailscale.com/derp"
 	"tailscale.com/derp/derphttp"
-	"tailscale.com/derp/derpmap"
 	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/net/dnscache"
 	"tailscale.com/net/interfaces"
@@ -55,7 +54,6 @@ type Conn struct {
 	epFunc       func(endpoints []string)
 	logf         logger.Logf
 	sendLogLimit *rate.Limiter
-	derps        *derpmap.World
 	netChecker   *netcheck.Client
 
 	// bufferedIPv4From and bufferedIPv4Packet are owned by
@@ -76,7 +74,8 @@ type Conn struct {
 
 	mu sync.Mutex // guards all following fields
 
-	closed bool
+	started bool
+	closed  bool
 
 	endpointsUpdateWaiter *sync.Cond
 	endpointsUpdateActive bool
@@ -104,13 +103,12 @@ type Conn struct {
 	netInfoFunc func(*tailcfg.NetInfo) // nil until set
 	netInfoLast *tailcfg.NetInfo
 
-	wantDerp      bool
-	privateKey    key.Private
-	myDerp        int           // nearest DERP server; 0 means none/unknown
-	derpStarted   chan struct{} // closed on first connection to DERP; for tests
-	activeDerp    map[int]activeDerp
-	prevDerp      map[int]*syncs.WaitGroupChan
-	derpTLSConfig *tls.Config // normally nil; used by tests
+	derpMap     *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled
+	privateKey  key.Private
+	myDerp      int                // nearest DERP region ID; 0 means none/unknown
+	derpStarted chan struct{}      // closed on first connection to DERP; for tests
+	activeDerp  map[int]activeDerp // DERP regionID -> connection to a node in that region
+	prevDerp    map[int]*syncs.WaitGroupChan
 
 	// derpRoute contains optional alternate routes to use as an
 	// optimization instead of contacting a peer via their home
@@ -196,14 +194,9 @@ type Options struct {
 	// Zero means to pick one automatically.
 	Port uint16
 
-	// DERPs, if non-nil, is used instead of derpmap.Prod.
-	DERPs *derpmap.World
-
 	// EndpointsFunc optionally provides a func to be called when
 	// endpoints change. The called func does not own the slice.
 	EndpointsFunc func(endpoint []string)
-
-	derpTLSConfig *tls.Config // normally nil; used by tests
 }
 
 func (o *Options) logf() logger.Logf {
@@ -220,37 +213,39 @@ func (o *Options) endpointsFunc() func([]string) {
 	return o.EndpointsFunc
 }
 
-// Listen creates a magic Conn listening on opts.Port.
-// As the set of possible endpoints for a Conn changes, the
-// callback opts.EndpointsFunc is called.
-func Listen(opts Options) (*Conn, error) {
+// newConn is the error-free, network-listening-side-effect-free
+// version of NewConn. Mostly for tests.
+func newConn() *Conn { c := &Conn{ - pconnPort: opts.Port, - logf: opts.logf(), - epFunc: opts.endpointsFunc(), - sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), - addrsByUDP: make(map[netaddr.IPPort]*AddrSet), - addrsByKey: make(map[key.Public]*AddrSet), - wantDerp: true, - derpRecvCh: make(chan derpReadResult), - udpRecvCh: make(chan udpReadResult), - derpTLSConfig: opts.derpTLSConfig, - derpStarted: make(chan struct{}), - derps: opts.DERPs, - peerLastDerp: make(map[key.Public]int), + sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), + addrsByUDP: make(map[netaddr.IPPort]*AddrSet), + addrsByKey: make(map[key.Public]*AddrSet), + derpRecvCh: make(chan derpReadResult), + udpRecvCh: make(chan udpReadResult), + derpStarted: make(chan struct{}), + peerLastDerp: make(map[key.Public]int), } c.endpointsUpdateWaiter = sync.NewCond(&c.mu) + return c +} + +// NewConn creates a magic Conn listening on opts.Port. +// As the set of possible endpoints for a Conn changes, the +// callback opts.EndpointsFunc is called. +// +// It doesn't start doing anything until Start is called. +func NewConn(opts Options) (*Conn, error) { + c := newConn() + c.pconnPort = opts.Port + c.logf = opts.logf() + c.epFunc = opts.endpointsFunc() if err := c.initialBind(); err != nil { return nil, err } c.connCtx, c.connCtxCancel = context.WithCancel(context.Background()) - if c.derps == nil { - c.derps = derpmap.Prod() - } c.netChecker = &netcheck.Client{ - DERP: c.derps, Logf: logger.WithPrefix(c.logf, "netcheck: "), GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 }, } @@ -259,6 +254,18 @@ func Listen(opts Options) (*Conn, error) { } c.ignoreSTUNPackets() + + return c, nil +} + +func (c *Conn) Start() { + c.mu.Lock() + if c.started { + panic("duplicate Start call") + } + c.started = true + c.mu.Unlock() + c.ReSTUN("initial") // We assume that LinkChange notifications are plumbed through well @@ -267,8 +274,6 @@ func Listen(opts Options) (*Conn, error) { go c.periodicReSTUN() } go c.periodicDerpCleanup() - - return c, nil } func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() } @@ -278,10 +283,6 @@ func (c *Conn) ignoreSTUNPackets() { c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) } -// runs in its own goroutine until ctx is shut down. -// Whenever c.startEpUpdate receives a value, it starts an -// STUN endpoint lookup. -// // c.mu must NOT be held. 
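Note on the Listen -> NewConn/Start split above: construction is now side-effect free, and DERP configuration arrives separately, so a caller wires things up in three steps. A hedged usage sketch (caller code, not part of this patch; the DERP map would come from the control server):

	package main

	import (
		"log"

		"tailscale.com/tailcfg"
		"tailscale.com/wgengine/magicsock"
	)

	func bringUp(dm *tailcfg.DERPMap) (*magicsock.Conn, error) {
		conn, err := magicsock.NewConn(magicsock.Options{
			Port:          0, // zero picks a port automatically
			Logf:          log.Printf,
			EndpointsFunc: func(eps []string) { log.Printf("endpoints: %q", eps) },
		})
		if err != nil {
			return nil, err
		}
		conn.Start()        // kicks off ReSTUN("initial") and the periodic workers
		conn.SetDERPMap(dm) // nil keeps DERP disabled; a changed map triggers ReSTUN
		return conn, nil
	}

Calling ReSTUN (directly or via SetDERPMap) before Start now panics, which is what the new started field guards.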
func (c *Conn) updateEndpoints(why string) { defer func() { @@ -326,7 +327,11 @@ func (c *Conn) setEndpoints(endpoints []string) (changed bool) { } func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { - if DisableSTUNForTesting { + c.mu.Lock() + dm := c.derpMap + c.mu.Unlock() + + if DisableSTUNForTesting || dm == nil { return new(netcheck.Report), nil } @@ -336,7 +341,7 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { c.stunReceiveFunc.Store(c.netChecker.ReceiveSTUNPacket) defer c.ignoreSTUNPackets() - report, err := c.netChecker.GetReport(ctx) + report, err := c.netChecker.GetReport(ctx, dm) if err != nil { return nil, err } @@ -346,8 +351,11 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { MappingVariesByDestIP: report.MappingVariesByDestIP, HairPinning: report.HairPinning, } - for server, d := range report.DERPLatency { - ni.DERPLatency[server] = d.Seconds() + for rid, d := range report.RegionV4Latency { + ni.DERPLatency[fmt.Sprintf("%d-v4", rid)] = d.Seconds() + } + for rid, d := range report.RegionV6Latency { + ni.DERPLatency[fmt.Sprintf("%d-v6", rid)] = d.Seconds() } ni.WorkingIPv6.Set(report.IPv6) ni.WorkingUDP.Set(report.UDP) @@ -380,9 +388,12 @@ func (c *Conn) pickDERPFallback() int { c.mu.Lock() defer c.mu.Unlock() - ids := c.derps.IDs() + if !c.wantDerpLocked() { + return 0 + } + ids := c.derpMap.RegionIDs() if len(ids) == 0 { - // No DERP nodes registered. + // No DERP regions in non-nil map. return 0 } @@ -458,7 +469,7 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) { func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { c.mu.Lock() defer c.mu.Unlock() - if !c.wantDerp { + if !c.wantDerpLocked() { c.myDerp = 0 return false } @@ -476,7 +487,7 @@ func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { // On change, notify all currently connected DERP servers and // start connecting to our home DERP if we are not already. - c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derps.ServerByID(derpNum).Geo) + c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derpMap.Regions[derpNum].RegionCode) for i, ad := range c.activeDerp { go ad.c.NotePreferred(i == c.myDerp) } @@ -791,11 +802,11 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de if !addr.IP.Equal(derpMagicIP) { return nil } - nodeID := addr.Port + regionID := addr.Port c.mu.Lock() defer c.mu.Unlock() - if !c.wantDerp || c.closed { + if !c.wantDerpLocked() || c.closed { return nil } if c.privateKey.IsZero() { @@ -807,10 +818,10 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de // first. If so, might as well use it. (It's a little // arbitrary whether we use this one vs. the reverse route // below when we have both.) 
- ad, ok := c.activeDerp[nodeID] + ad, ok := c.activeDerp[regionID] if ok { *ad.lastWrite = time.Now() - c.setPeerLastDerpLocked(peer, nodeID, nodeID) + c.setPeerLastDerpLocked(peer, regionID, regionID) return ad.writeCh } @@ -823,7 +834,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de if !peer.IsZero() && debugUseDerpRoute { if r, ok := c.derpRoute[peer]; ok { if ad, ok := c.activeDerp[r.derpID]; ok && ad.c == r.dc { - c.setPeerLastDerpLocked(peer, r.derpID, nodeID) + c.setPeerLastDerpLocked(peer, r.derpID, regionID) *ad.lastWrite = time.Now() return ad.writeCh } @@ -834,7 +845,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de if !peer.IsZero() { why = peerShort(peer) } - c.logf("magicsock: adding connection to derp-%v for %v", nodeID, why) + c.logf("magicsock: adding connection to derp-%v for %v", regionID, why) firstDerp := false if c.activeDerp == nil { @@ -842,22 +853,23 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de c.activeDerp = make(map[int]activeDerp) c.prevDerp = make(map[int]*syncs.WaitGroupChan) } - derpSrv := c.derps.ServerByID(nodeID) - if derpSrv == nil || derpSrv.HostHTTPS == "" { + if c.derpMap == nil || c.derpMap.Regions[regionID] == nil { return nil } // Note that derphttp.NewClient does not dial the server // so it is safe to do under the mu lock. - dc, err := derphttp.NewClient(c.privateKey, "https://"+derpSrv.HostHTTPS+"/derp", c.logf) - if err != nil { - c.logf("magicsock: derphttp.NewClient: node %d, host %q invalid? err: %v", nodeID, derpSrv.HostHTTPS, err) - return nil - } + dc := derphttp.NewRegionClient(c.privateKey, c.logf, func() *tailcfg.DERPRegion { + c.mu.Lock() + defer c.mu.Unlock() + if c.derpMap == nil { + return nil + } + return c.derpMap.Regions[regionID] + }) - dc.NotePreferred(c.myDerp == nodeID) + dc.NotePreferred(c.myDerp == regionID) dc.DNSCache = dnscache.Get() - dc.TLSConfig = c.derpTLSConfig ctx, cancel := context.WithCancel(c.connCtx) ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop) @@ -868,21 +880,21 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de ad.lastWrite = new(time.Time) *ad.lastWrite = time.Now() ad.createTime = time.Now() - c.activeDerp[nodeID] = ad + c.activeDerp[regionID] = ad c.logActiveDerpLocked() - c.setPeerLastDerpLocked(peer, nodeID, nodeID) + c.setPeerLastDerpLocked(peer, regionID, regionID) // Build a startGate for the derp reader+writer // goroutines, so they don't start running until any // previous generation is closed. startGate := syncs.ClosedChan() - if prev := c.prevDerp[nodeID]; prev != nil { + if prev := c.prevDerp[regionID]; prev != nil { startGate = prev.DoneChan() } // And register a WaitGroup(Chan) for this generation. wg := syncs.NewWaitGroupChan() wg.Add(2) - c.prevDerp[nodeID] = wg + c.prevDerp[regionID] = wg if firstDerp { startGate = c.derpStarted @@ -899,37 +911,37 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de } // setPeerLastDerpLocked notes that peer is now being written to via -// provided DERP node nodeID, and that that advertises a DERP home -// node of homeID. +// the provided DERP regionID, and that the peer advertises a DERP +// home region ID of homeID. // // If there's any change, it logs. // // c.mu must be held. 
-func (c *Conn) setPeerLastDerpLocked(peer key.Public, nodeID, homeID int) { +func (c *Conn) setPeerLastDerpLocked(peer key.Public, regionID, homeID int) { if peer.IsZero() { return } old := c.peerLastDerp[peer] - if old == nodeID { + if old == regionID { return } - c.peerLastDerp[peer] = nodeID + c.peerLastDerp[peer] = regionID var newDesc string switch { - case nodeID == homeID && nodeID == c.myDerp: + case regionID == homeID && regionID == c.myDerp: newDesc = "shared home" - case nodeID == homeID: + case regionID == homeID: newDesc = "their home" - case nodeID == c.myDerp: + case regionID == c.myDerp: newDesc = "our home" - case nodeID != homeID: + case regionID != homeID: newDesc = "alt" } if old == 0 { - c.logf("magicsock: derp route for %s set to derp-%d (%s)", peerShort(peer), nodeID, newDesc) + c.logf("magicsock: derp route for %s set to derp-%d (%s)", peerShort(peer), regionID, newDesc) } else { - c.logf("magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peerShort(peer), old, nodeID, newDesc) + c.logf("magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peerShort(peer), old, regionID, newDesc) } } @@ -1284,18 +1296,27 @@ func (c *Conn) UpdatePeers(newPeers map[key.Public]struct{}) { } } -// SetDERPEnabled controls whether DERP is used. -// New connections have it enabled by default. -func (c *Conn) SetDERPEnabled(wantDerp bool) { +// SetDERPMap controls which (if any) DERP servers are used. +// A nil value means to disable DERP; it's disabled by default. +func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { c.mu.Lock() defer c.mu.Unlock() - c.wantDerp = wantDerp - if !wantDerp { + if reflect.DeepEqual(dm, c.derpMap) { + return + } + + c.derpMap = dm + if dm == nil { c.closeAllDerpLocked("derp-disabled") + return } + + go c.ReSTUN("derp-map-update") } +func (c *Conn) wantDerpLocked() bool { return c.derpMap != nil } + // c.mu must be held. func (c *Conn) closeAllDerpLocked(why string) { if len(c.activeDerp) == 0 { @@ -1352,7 +1373,7 @@ func (c *Conn) logEndpointChange(endpoints []string, reasons map[string]string) } // c.mu must be held. -func (c *Conn) foreachActiveDerpSortedLocked(fn func(nodeID int, ad activeDerp)) { +func (c *Conn) foreachActiveDerpSortedLocked(fn func(regionID int, ad activeDerp)) { if len(c.activeDerp) < 2 { for id, ad := range c.activeDerp { fn(id, ad) @@ -1473,6 +1494,9 @@ func (c *Conn) periodicDerpCleanup() { func (c *Conn) ReSTUN(why string) { c.mu.Lock() defer c.mu.Unlock() + if !c.started { + panic("call to ReSTUN before Start") + } if c.closed { // raced with a shutdown. 
return diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 9b7b30aac..992768adc 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -27,6 +27,7 @@ import ( "tailscale.com/derp/derphttp" "tailscale.com/derp/derpmap" "tailscale.com/stun/stuntest" + "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -54,7 +55,7 @@ func (c *Conn) WaitReady(t *testing.T) { } } -func TestListen(t *testing.T) { +func TestNewConn(t *testing.T) { tstest.PanicOnLog() rc := tstest.NewResourceCheck() defer rc.Assert(t) @@ -70,9 +71,8 @@ func TestListen(t *testing.T) { defer stunCleanupFn() port := pickPort(t) - conn, err := Listen(Options{ + conn, err := NewConn(Options{ Port: port, - DERPs: derpmap.NewTestWorld(stunAddr), EndpointsFunc: epFunc, Logf: t.Logf, }) @@ -80,6 +80,8 @@ func TestListen(t *testing.T) { t.Fatal(err) } defer conn.Close() + conn.Start() + conn.SetDERPMap(stuntest.DERPMapOf(stunAddr.String())) go func() { var pkt [64 << 10]byte @@ -136,9 +138,8 @@ func TestPickDERPFallback(t *testing.T) { rc := tstest.NewResourceCheck() defer rc.Assert(t) - c := &Conn{ - derps: derpmap.Prod(), - } + c := newConn() + c.derpMap = derpmap.Prod() a := c.pickDERPFallback() if a == 0 { t.Fatalf("pickDERPFallback returned 0") @@ -156,7 +157,8 @@ func TestPickDERPFallback(t *testing.T) { // distribution over nodes works. got := map[int]int{} for i := 0; i < 50; i++ { - c = &Conn{derps: derpmap.Prod()} + c = newConn() + c.derpMap = derpmap.Prod() got[c.pickDERPFallback()]++ } t.Logf("distribution: %v", got) @@ -236,7 +238,7 @@ func parseCIDR(t *testing.T, addr string) wgcfg.CIDR { return cidr } -func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr string, cleanupFn func()) { +func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr *net.TCPAddr, cleanupFn func()) { var serverPrivateKey key.Private if _, err := crand.Read(serverPrivateKey[:]); err != nil { t.Fatal(err) @@ -250,14 +252,13 @@ func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr string, clean httpsrv.StartTLS() logf("DERP server URL: %s", httpsrv.URL) - addr = strings.TrimPrefix(httpsrv.URL, "https://") cleanupFn = func() { httpsrv.CloseClientConnections() httpsrv.Close() s.Close() } - return s, addr, cleanupFn + return s, httpsrv.Listener.Addr().(*net.TCPAddr), cleanupFn } // devLogger returns a wireguard-go device.Logger that writes @@ -286,13 +287,14 @@ func TestDeviceStartStop(t *testing.T) { rc := tstest.NewResourceCheck() defer rc.Assert(t) - conn, err := Listen(Options{ + conn, err := NewConn(Options{ EndpointsFunc: func(eps []string) {}, Logf: t.Logf, }) if err != nil { t.Fatal(err) } + conn.Start() defer conn.Close() tun := tuntest.NewChannelTUN() @@ -337,48 +339,58 @@ func TestTwoDevicePing(t *testing.T) { // all log using the "current" t.Logf function. Sigh. logf, setT := makeNestable(t) - // Wipe default DERP list, add local server. - // (Do it now, or derpHost will try to connect to derp1.tailscale.com.) 
derpServer, derpAddr, derpCleanupFn := runDERP(t, logf) defer derpCleanupFn() - stunAddr, stunCleanupFn := stuntest.Serve(t) defer stunCleanupFn() - derps := derpmap.NewTestWorldWith(&derpmap.Server{ - ID: 1, - HostHTTPS: derpAddr, - STUN4: stunAddr, - Geo: "Testopolis", - }) + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: &tailcfg.DERPRegion{ + RegionID: 1, + RegionCode: "test", + Nodes: []*tailcfg.DERPNode{ + { + Name: "t1", + RegionID: 1, + HostName: "test-node.unused", + IPv4: "127.0.0.1", + IPv6: "none", + STUNPort: stunAddr.Port, + DERPTestPort: derpAddr.Port, + }, + }, + }, + }, + } epCh1 := make(chan []string, 16) - conn1, err := Listen(Options{ - Logf: logger.WithPrefix(logf, "conn1: "), - DERPs: derps, + conn1, err := NewConn(Options{ + Logf: logger.WithPrefix(logf, "conn1: "), EndpointsFunc: func(eps []string) { epCh1 <- eps }, - derpTLSConfig: &tls.Config{InsecureSkipVerify: true}, }) if err != nil { t.Fatal(err) } defer conn1.Close() + conn1.Start() + conn1.SetDERPMap(derpMap) epCh2 := make(chan []string, 16) - conn2, err := Listen(Options{ - Logf: logger.WithPrefix(logf, "conn2: "), - DERPs: derps, + conn2, err := NewConn(Options{ + Logf: logger.WithPrefix(logf, "conn2: "), EndpointsFunc: func(eps []string) { epCh2 <- eps }, - derpTLSConfig: &tls.Config{InsecureSkipVerify: true}, }) if err != nil { t.Fatal(err) } defer conn2.Close() + conn2.Start() + conn2.SetDERPMap(derpMap) ports := []uint16{conn1.LocalPort(), conn2.LocalPort()} cfgs := makeConfigs(t, ports) diff --git a/wgengine/userspace.go b/wgengine/userspace.go index e0a4fda34..9a960fde6 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "log" @@ -49,7 +50,7 @@ const minimalMTU = 1280 type userspaceEngine struct { logf logger.Logf reqCh chan struct{} - waitCh chan struct{} + waitCh chan struct{} // chan is closed when first Close call completes; contrast with closing bool tundev *tstun.TUN wgdev *device.Device router router.Router @@ -61,6 +62,7 @@ type userspaceEngine struct { lastCfg wgcfg.Config mu sync.Mutex // guards following; see lock order comment below + closing bool // Close was called (even if we're still closing) statusCallback StatusCallback peerSequence []wgcfg.Key endpoints []string @@ -149,7 +151,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R Port: listenPort, EndpointsFunc: endpointsFn, } - e.magicConn, err = magicsock.Listen(magicsockOpts) + e.magicConn, err = magicsock.NewConn(magicsockOpts) if err != nil { tundev.Close() return nil, fmt.Errorf("wgengine: %v", err) @@ -210,6 +212,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R // routers do not Read or Write, but do access native interfaces. e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap()) if err != nil { + e.magicConn.Close() return nil, err } @@ -235,16 +238,19 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R e.wgdev.Up() if err := e.router.Up(); err != nil { + e.magicConn.Close() e.wgdev.Close() return nil, err } // TODO(danderson): we should delete this. It's pointless to apply // a no-op settings here. 
if err := e.router.Set(nil); err != nil { + e.magicConn.Close() e.wgdev.Close() return nil, err } e.linkMon.Start() + e.magicConn.Start() return e, nil } @@ -407,6 +413,13 @@ func (e *userspaceEngine) getStatus() (*Status, error) { e.wgLock.Lock() defer e.wgLock.Unlock() + e.mu.Lock() + closing := e.closing + e.mu.Unlock() + if closing { + return nil, errors.New("engine closing; no status") + } + if e.wgdev == nil { // RequestStatus was invoked before the wgengine has // finished initializing. This can happen when wgegine @@ -553,6 +566,11 @@ func (e *userspaceEngine) RequestStatus() { func (e *userspaceEngine) Close() { e.mu.Lock() + if e.closing { + e.mu.Unlock() + return + } + e.closing = true for key, cancel := range e.pingers { delete(e.pingers, key) cancel() @@ -614,8 +632,8 @@ func (e *userspaceEngine) SetNetInfoCallback(cb NetInfoCallback) { e.magicConn.SetNetInfoCallback(cb) } -func (e *userspaceEngine) SetDERPEnabled(v bool) { - e.magicConn.SetDERPEnabled(v) +func (e *userspaceEngine) SetDERPMap(dm *tailcfg.DERPMap) { + e.magicConn.SetDERPMap(dm) } func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 9d409ece7..ef9393a47 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -12,6 +12,7 @@ import ( "github.com/tailscale/wireguard-go/wgcfg" "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" ) @@ -88,8 +89,8 @@ func (e *watchdogEngine) RequestStatus() { func (e *watchdogEngine) LinkChange(isExpensive bool) { e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) }) } -func (e *watchdogEngine) SetDERPEnabled(v bool) { - e.watchdog("SetDERPEnabled", func() { e.wrap.SetDERPEnabled(v) }) +func (e *watchdogEngine) SetDERPMap(m *tailcfg.DERPMap) { + e.watchdog("SetDERPMap", func() { e.wrap.SetDERPMap(m) }) } func (e *watchdogEngine) Close() { e.watchdog("Close", e.wrap.Close) diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index 3a229c713..81dcee80e 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -95,9 +95,10 @@ type Engine interface { // action on. LinkChange(isExpensive bool) - // SetDERPEnabled controls whether DERP is enabled. - // It starts enabled by default. - SetDERPEnabled(bool) + // SetDERPMap controls which (if any) DERP servers are used. + // If nil, DERP is disabled. It starts disabled until a DERP map + // is configured. + SetDERPMap(*tailcfg.DERPMap) // SetNetInfoCallback sets the function to call when a // new NetInfo summary is available.
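Note on the Engine interface change above: DERP now starts disabled, and the first SetDERPMap call is what turns it on. A sketch of a minimal single-region map of the shape a control server would now serve, built from the DERPNode fields documented in tailcfg/derpmap.go (the region code and hostname here are illustrative):

	package main

	import "tailscale.com/tailcfg"

	// minimalDERPMap builds a one-region, one-node DERP map. Passing it
	// to Engine.SetDERPMap enables DERP; passing nil disables it again
	// and closes any active DERP connections.
	func minimalDERPMap() *tailcfg.DERPMap {
		return &tailcfg.DERPMap{
			Regions: map[int]*tailcfg.DERPRegion{
				1: {
					RegionID:   1,
					RegionCode: "ex",
					Nodes: []*tailcfg.DERPNode{{
						Name:     "1a",
						RegionID: 1,
						HostName: "derp1.example.com",
						IPv6:     "none", // skip AAAA lookups, per the field docs
					}},
				},
			},
		}
	}

With multiple nodes per region, clients pick the first node that is available, so a server can reorder Nodes to balance load without clients needing new logic.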