diff --git a/go.sum b/go.sum index 6ebffb007..0e44997d5 100644 --- a/go.sum +++ b/go.sum @@ -185,6 +185,7 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac h1:MQEvx39qSf8vyrx3XRaOe+j1UDIzKwkYOVObRgGPVqI= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d h1:/iIZNFGxc/a7C3yWjGcnboV+Tkc7mxr+p6fDztwoxuM= golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/netcheck/netcheck.go b/netcheck/netcheck.go index 9099af0ae..4764d9e64 100644 --- a/netcheck/netcheck.go +++ b/netcheck/netcheck.go @@ -12,7 +12,6 @@ import ( "io" "log" "net" - "strconv" "sync" "time" @@ -69,8 +68,10 @@ type Client struct { GetSTUNConn4 func() STUNConn GetSTUNConn6 func() STUNConn - s4 *stunner.Stunner - s6 *stunner.Stunner + s4 *stunner.Stunner + s6 *stunner.Stunner + hairTX stun.TxID + gotHairSTUN chan *net.UDPAddr } // STUNConn is the interface required by the netcheck Client when @@ -88,11 +89,27 @@ func (c *Client) logf(format string, a ...interface{}) { } } +// handleHairSTUN reports whether pkt (from src) was our magic hairpin +// probe packet that we sent to ourselves. +func (c *Client) handleHairSTUN(pkt []byte, src *net.UDPAddr) bool { + if tx, err := stun.ParseBindingRequest(pkt); err == nil && tx == c.hairTX { + select { + case c.gotHairSTUN <- src: + default: + } + return true + } + return false +} + func (c *Client) ReceiveSTUNPacket(pkt []byte, src *net.UDPAddr) { var st *stunner.Stunner if src == nil || src.IP == nil { panic("bogus src") } + if c.handleHairSTUN(pkt, src) { + return + } if src.IP.To4() != nil { st = c.s4 } else { @@ -116,6 +133,8 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { c.s4 = nil c.s6 = nil }() + c.hairTX = stun.NewTxID() // random payload + c.gotHairSTUN = make(chan *net.UDPAddr, 1) if c.DERP == nil { return nil, errors.New("netcheck: GetReport: Client.DERP is nil") @@ -149,13 +168,28 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { if err != nil { c.logf("interfaces: %v", err) } + + // Create a UDP4 socket used for sending to our discovered IPv4 address. + pc4Hair, err := net.ListenPacket("udp4", ":0") + if err != nil { + c.logf("udp4: %v", err) + return nil, err + } + defer pc4Hair.Close() + hairTimeout := make(chan bool, 1) + startHairCheck := func(dstEP string) { + if dst, err := net.ResolveUDPAddr("udp4", dstEP); err == nil { + pc4Hair.WriteTo(stun.Request(c.hairTX), dst) + time.AfterFunc(500*time.Millisecond, func() { hairTimeout <- true }) + } + } + var ( mu sync.Mutex ret = &Report{ DERPLatency: map[string]time.Duration{}, } gotEP = map[string]string{} // server -> ipPort - gotEPHair = map[string]string{} // server -> ipPort for second UDP4 for hairpinning gotEP4 string bestDerpLatency time.Duration ) @@ -183,6 +217,7 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { if gotEP4 == "" { gotEP4 = ipPort ret.GlobalV4 = ipPort + startHairCheck(ipPort) } else { if gotEP4 != ipPort { ret.MappingVariesByDestIP.Set(true) @@ -198,11 +233,6 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { ret.PreferredDERP = c.DERP.NodeIDOfSTUNServer(server) } } - addHair := func(server, ipPort string, d time.Duration) { - mu.Lock() - defer mu.Unlock() - gotEPHair[server] = ipPort - } var pc4, pc6 STUNConn @@ -218,14 +248,6 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { go closeOnCtx(u4) } - // And a second UDP4 socket to check hairpinning. - pc4Hair, err := net.ListenPacket("udp4", ":0") - if err != nil { - c.logf("udp4: %v", err) - return nil, err - } - go closeOnCtx(pc4Hair) - if v6iface { if f := c.GetSTUNConn6; f != nil { pc6 = f() @@ -240,9 +262,9 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { } } - reader := func(s *stunner.Stunner, pc STUNConn, maxReads int) { + reader := func(s *stunner.Stunner, pc STUNConn) { var buf [64 << 10]byte - for i := 0; i < maxReads; i++ { + for { n, addr, err := pc.ReadFrom(buf[:]) if err != nil { if ctx.Err() != nil { @@ -256,6 +278,9 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { c.logf("ReadFrom: unexpected addr %T", addr) continue } + if c.handleHairSTUN(buf[:n], ua) { + continue + } s.Receive(buf[:n], ua) } @@ -263,7 +288,6 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { var grp errgroup.Group - const unlimited = 9999 // effectively, closed on cancel anyway s4 := &stunner.Stunner{ Send: pc4.WriteTo, Endpoint: add, @@ -274,20 +298,10 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { c.s4 = s4 grp.Go(func() error { return s4.Run(ctx) }) if c.GetSTUNConn4 == nil { - go reader(s4, pc4, unlimited) - } - - s4Hair := &stunner.Stunner{ - Send: pc4Hair.WriteTo, - Endpoint: addHair, - Servers: stuns4, - Logf: c.logf, - DNSCache: dnscache.Get(), + go reader(s4, pc4) } - grp.Go(func() error { return s4Hair.Run(ctx) }) - go reader(s4Hair, pc4Hair, 2) - if pc6 != nil { + if pc6 != nil && len(stuns6) > 0 { s6 := &stunner.Stunner{ Endpoint: add, Send: pc6.WriteTo, @@ -299,7 +313,7 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { c.s6 = s6 grp.Go(func() error { return s6.Run(ctx) }) if c.GetSTUNConn6 == nil { - go reader(s6, pc6, unlimited) + go reader(s6, pc6) } } @@ -312,17 +326,12 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { defer mu.Unlock() // Check hairpinning. - if ret.MappingVariesByDestIP == "false" { - hairIPStr, hairPortStr, _ := net.SplitHostPort(gotEPHair["derp1.tailscale.com:3478"]) - hairIP := net.ParseIP(hairIPStr) - hairPort, _ := strconv.Atoi(hairPortStr) - if hairIP != nil && hairPort != 0 { - tx := stun.NewTxID() // random payload - pc4.WriteTo(tx[:], &net.UDPAddr{IP: hairIP, Port: hairPort}) - var got stun.TxID - pc4Hair.SetReadDeadline(time.Now().Add(1 * time.Second)) - _, _, err := pc4Hair.ReadFrom(got[:]) - ret.HairPinning.Set(err == nil && got == tx) + if ret.MappingVariesByDestIP == "false" && gotEP4 != "" { + select { + case <-c.gotHairSTUN: + ret.HairPinning.Set(true) + case <-hairTimeout: + ret.HairPinning.Set(false) } } diff --git a/netcheck/netcheck_test.go b/netcheck/netcheck_test.go new file mode 100644 index 000000000..a28d9a974 --- /dev/null +++ b/netcheck/netcheck_test.go @@ -0,0 +1,31 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netcheck + +import ( + "net" + "testing" + + "tailscale.com/stun" +) + +func TestHairpinSTUN(t *testing.T) { + c := &Client{ + hairTX: stun.NewTxID(), + gotHairSTUN: make(chan *net.UDPAddr, 1), + } + req := stun.Request(c.hairTX) + if !stun.Is(req) { + t.Fatal("expected STUN message") + } + if !c.handleHairSTUN(req, nil) { + t.Fatal("expected true") + } + select { + case <-c.gotHairSTUN: + default: + t.Fatal("expected value") + } +} diff --git a/stunner/stunner.go b/stunner/stunner.go index 94f75557c..3c21aa80e 100644 --- a/stunner/stunner.go +++ b/stunner/stunner.go @@ -6,9 +6,12 @@ package stunner import ( "context" + "errors" "fmt" + "math/rand" "net" "strconv" + "strings" "sync" "time" @@ -37,6 +40,10 @@ type Stunner struct { // took on the wire (not including DNS lookup time. Endpoint func(server, endpoint string, d time.Duration) + // onPacket is the internal version of Endpoint that does de-dup. + // It's set by Run. + onPacket func(server, endpoint string, d time.Duration) + Servers []string // STUN servers to contact // DNSCache optionally specifies a DNSCache to use. @@ -50,10 +57,6 @@ type Stunner struct { // If false, only IPv4 is used. There is currently no mixed mode. OnlyIPv6 bool - // sessions tracks the state of each server. - // It's keyed by the STUN server (from the Servers field). - sessions map[string]*session - mu sync.Mutex inFlight map[stun.TxID]request } @@ -61,8 +64,8 @@ type Stunner struct { func (s *Stunner) addTX(tx stun.TxID, server string) { s.mu.Lock() defer s.mu.Unlock() - if s.inFlight == nil { - s.inFlight = make(map[stun.TxID]request) + if _, dup := s.inFlight[tx]; dup { + panic("unexpected duplicate STUN TransactionID") } s.inFlight[tx] = request{sent: time.Now(), server: server} } @@ -70,8 +73,15 @@ func (s *Stunner) addTX(tx stun.TxID, server string) { func (s *Stunner) removeTX(tx stun.TxID) (request, bool) { s.mu.Lock() defer s.mu.Unlock() + if s.inFlight == nil { + return request{}, false + } r, ok := s.inFlight[tx] - delete(s.inFlight, tx) + if ok { + delete(s.inFlight, tx) + } else { + s.logf("stunner: got STUN packet for unknown TxID %x", tx) + } return r, ok } @@ -80,11 +90,6 @@ type request struct { server string } -type session struct { - ctx context.Context // closed via call to done when reply received - cancel context.CancelFunc -} - func (s *Stunner) logf(format string, args ...interface{}) { if s.Logf != nil { s.Logf(format, args...) @@ -105,95 +110,113 @@ func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { } r, ok := s.removeTX(tx) if !ok { - s.logf("stunner: got STUN packet for unknown TxID %x", tx) return } d := now.Sub(r.sent) - session := s.sessions[r.server] - if session != nil { - host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port)) - s.Endpoint(r.server, host, d) - session.cancel() - } + host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port)) + s.onPacket(r.server, host, d) } func (s *Stunner) resolver() *net.Resolver { return net.DefaultResolver } +// cleanUpPostRun zeros out some fields, mostly for debugging (so +// things crash or race+fail if there's a sender still running.) +func (s *Stunner) cleanUpPostRun() { + s.mu.Lock() + s.inFlight = nil + s.mu.Unlock() +} + // Run starts a Stunner and blocks until all servers either respond // or are tried multiple times and timeout. -// -// TODO: this always returns success now. It should return errors -// if certain servers are unavailable probably. Or if all are. -// Or some configured threshold are. +// It can not be called concurrently with itself. func (s *Stunner) Run(ctx context.Context) error { - s.sessions = map[string]*session{} for _, server := range s.Servers { if _, _, err := net.SplitHostPort(server); err != nil { return fmt.Errorf("Stunner.Run: invalid server %q (in Server list %q)", server, s.Servers) } - sctx, cancel := context.WithCancel(ctx) - s.sessions[server] = &session{ - ctx: sctx, - cancel: cancel, + } + if len(s.Servers) == 0 { + return errors.New("stunner: no Servers") + } + + s.inFlight = make(map[stun.TxID]request) + defer s.cleanUpPostRun() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + type sender struct { + ctx context.Context + cancel context.CancelFunc + } + var ( + needMu sync.Mutex + need = make(map[string]sender) // keyed by server; deleted when done + allDone = make(chan struct{}) // closed when need is empty + ) + s.onPacket = func(server, endpoint string, d time.Duration) { + needMu.Lock() + defer needMu.Unlock() + sender, ok := need[server] + if !ok { + return + } + sender.cancel() + delete(need, server) + s.Endpoint(server, endpoint, d) + if len(need) == 0 { + close(allDone) } } - // after this point, the s.sessions map is read-only var wg sync.WaitGroup for _, server := range s.Servers { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + need[server] = sender{ctx, cancel} + } + for server, sender := range need { wg.Add(1) - go func(server string) { + server, ctx := server, sender.ctx + go func() { defer wg.Done() - s.runServer(ctx, server) - }(server) + s.sendPackets(ctx, server) + }() + } + var err error + select { + case <-ctx.Done(): + err = ctx.Err() + case <-allDone: + cancel() } wg.Wait() - return nil -} - -func (s *Stunner) runServer(ctx context.Context, server string) { - session := s.sessions[server] - - // If we're using a DNS cache, prime the cache before doing - // any quick timeouts (100ms, etc) so the timeout doesn't - // apply to the first DNS lookup. - if s.DNSCache != nil { - _, _ = s.DNSCache.LookupIP(ctx, server) + var missing []string + needMu.Lock() + for server := range need { + missing = append(missing, server) } + needMu.Unlock() - for i, d := range retryDurations { - ctx, cancel := context.WithTimeout(ctx, d) - err := s.sendSTUN(ctx, server) - if err != nil { - s.logf("stunner: sendSTUN(%q): %v", server, err) - } - - select { - case <-ctx.Done(): - cancel() - case <-session.ctx.Done(): - cancel() - if i > 0 { - s.logf("stunner: slow STUN response from %s: %d retries", server, i) - } - return - } + if len(missing) == 0 || err == nil { + return nil } - s.logf("stunner: no STUN response from %s", server) + return fmt.Errorf("got STUN error: %v; missing replies from: %v", err, strings.Join(missing, ", ")) } -func (s *Stunner) sendSTUN(ctx context.Context, server string) error { - host, port, err := net.SplitHostPort(server) +func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, error) { + hostStr, portStr, err := net.SplitHostPort(server) if err != nil { - return err + return nil, err } - addrPort, err := strconv.Atoi(port) + addrPort, err := strconv.Atoi(portStr) if err != nil { - return fmt.Errorf("port: %v", err) + return nil, fmt.Errorf("port: %v", err) } if addrPort == 0 { addrPort = 3478 @@ -202,17 +225,18 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error { var ipAddrs []net.IPAddr if s.DNSCache != nil { - ip, err := s.DNSCache.LookupIP(ctx, host) + ip, err := s.DNSCache.LookupIP(ctx, hostStr) if err != nil { - return fmt.Errorf("lookup ip addr from cache (%q): %v", host, err) + return nil, err } ipAddrs = []net.IPAddr{{IP: ip}} } else { - ipAddrs, err = s.resolver().LookupIPAddr(ctx, host) + ipAddrs, err = s.resolver().LookupIPAddr(ctx, hostStr) if err != nil { - return fmt.Errorf("lookup ip addr (%q): %v", host, err) + return nil, fmt.Errorf("lookup ip addr (%q): %v", hostStr, err) } } + for _, ipAddr := range ipAddrs { ip4 := ipAddr.IP.To4() if ip4 != nil { @@ -228,29 +252,36 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error { } if addr.IP == nil { if s.OnlyIPv6 { - return fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs) + return nil, fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs) } - return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) + return nil, fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) } + return addr, nil +} - txID := stun.NewTxID() - req := stun.Request(txID) - s.addTX(txID, server) - _, err = s.Send(req, addr) +func (s *Stunner) sendPackets(ctx context.Context, server string) error { + addr, err := s.serverAddr(ctx, server) if err != nil { - return fmt.Errorf("send: %v", err) + return err } - return nil -} -var retryDurations = []time.Duration{ - 100 * time.Millisecond, - 100 * time.Millisecond, - 100 * time.Millisecond, - 200 * time.Millisecond, - 200 * time.Millisecond, - 400 * time.Millisecond, - 800 * time.Millisecond, - 1600 * time.Millisecond, - 3200 * time.Millisecond, + const maxSend = 2 + for i := 0; i < maxSend; i++ { + txID := stun.NewTxID() + req := stun.Request(txID) + s.addTX(txID, server) + _, err = s.Send(req, addr) + if err != nil { + return fmt.Errorf("send: %v", err) + } + + select { + case <-ctx.Done(): + // Ignore error. The caller deals with handling contexts. + // We only use it to dermine when to stop spraying STUN packets. + return nil + case <-time.After(time.Millisecond * time.Duration(50+rand.Intn(200))): + } + } + return nil } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index fd0b091d1..ce3903905 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -138,9 +138,8 @@ type Options struct { // Zero means to pick one automatically. Port uint16 - // STUN, if non-empty, specifies alternate STUN servers for testing. - // If empty, the production DERP servers are used. - STUN []string + // DERPs, if non-nil, is used instead of derpmap.Prod. + DERPs *derpmap.World // EndpointsFunc optionally provides a func to be called when // endpoints change. The called func does not own the slice. @@ -202,11 +201,11 @@ func Listen(opts Options) (*Conn, error) { derpRecvCh: make(chan derpReadResult), udpRecvCh: make(chan udpReadResult), derpTLSConfig: opts.derpTLSConfig, - derps: derpmap.Prod(), + derps: opts.DERPs, } c.linkState, _ = getLinkState() - if len(opts.STUN) > 0 { - c.derps = derpmap.NewTestWorld(opts.STUN...) + if c.derps == nil { + c.derps = derpmap.Prod() } c.netChecker = &netcheck.Client{ DERP: c.derps, diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 1e8efac27..188b79002 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -32,7 +32,6 @@ import ( ) func TestListen(t *testing.T) { - // TODO(crawshaw): when offline this test spends a while trying to connect to real derp servers. epCh := make(chan string, 16) epFunc := func(endpoints []string) { @@ -47,7 +46,7 @@ func TestListen(t *testing.T) { port := pickPort(t) conn, err := Listen(Options{ Port: port, - STUN: []string{stunAddr}, + DERPs: derpmap.NewTestWorld(stunAddr), EndpointsFunc: epFunc, Logf: t.Logf, }) @@ -157,7 +156,7 @@ func serveSTUN(t *testing.T) (addr string, cleanupFn func()) { } stunAddr := pc.LocalAddr().String() - stunAddr = strings.Replace(stunAddr, "0.0.0.0:", "localhost:", 1) + stunAddr = strings.Replace(stunAddr, "0.0.0.0:", "127.0.0.1:", 1) doneCh := make(chan struct{}) go runSTUN(t, pc, &stats, doneCh) @@ -343,8 +342,8 @@ func TestTwoDevicePing(t *testing.T) { epCh1 := make(chan []string, 16) conn1, err := Listen(Options{ - Logf: logger.WithPrefix(t.Logf, "conn1: "), - STUN: []string{stunAddr}, + Logf: logger.WithPrefix(t.Logf, "conn1: "), + DERPs: derps, EndpointsFunc: func(eps []string) { epCh1 <- eps }, @@ -353,13 +352,12 @@ func TestTwoDevicePing(t *testing.T) { if err != nil { t.Fatal(err) } - conn1.derps = derps defer conn1.Close() epCh2 := make(chan []string, 16) conn2, err := Listen(Options{ - Logf: logger.WithPrefix(t.Logf, "conn2: "), - STUN: []string{stunAddr}, + Logf: logger.WithPrefix(t.Logf, "conn2: "), + DERPs: derps, EndpointsFunc: func(eps []string) { epCh2 <- eps }, @@ -368,7 +366,6 @@ func TestTwoDevicePing(t *testing.T) { if err != nil { t.Fatal(err) } - conn2.derps = derps defer conn2.Close() ports := []uint16{conn1.LocalPort(), conn2.LocalPort()}