diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 4855da2c9..e8a2039c7 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -45,15 +45,13 @@ import ( // A Conn routes UDP packets and actively manages a list of its endpoints. // It implements wireguard/device.Bind. type Conn struct { - pconn *RebindingUDPConn - pconnPort uint16 - startEpUpdate chan string // send (with reason string) to trigger endpoint update - epFunc func(endpoints []string) - logf logger.Logf - sendLogLimit *rate.Limiter - derps *derpmap.World - netChecker *netcheck.Client - goroutines sync.WaitGroup + pconn *RebindingUDPConn + pconnPort uint16 + epFunc func(endpoints []string) + logf logger.Logf + sendLogLimit *rate.Limiter + derps *derpmap.World + netChecker *netcheck.Client // bufferedIPv4From and bufferedIPv4Packet are owned by // ReceiveIPv4, and used when both a DERP and IPv4 packet arrive @@ -64,6 +62,21 @@ type Conn struct { connCtx context.Context // closed on Conn.Close connCtxCancel func() // closes connCtx + // stunReceiveFunc holds the current STUN packet processing func. + // Its Loaded value is always non-nil. + stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr) + + udpRecvCh chan udpReadResult + derpRecvCh chan derpReadResult + + mu sync.Mutex // guards all following fields + + closed bool + + endpointsUpdateActive bool + wantEndpointsUpdate string // if non-empty, the reason another endpoint update is wanted + lastEndpoints []string + // addrsByUDP is a map of every remote ip:port to a priority // list of endpoint addresses for a peer. // The priority list is provided by wgengine configuration. @@ -75,22 +88,15 @@ type Conn struct { // 10.0.0.1:1 -> [10.0.0.1:1, 10.0.0.2:2] // 10.0.0.2:2 -> [10.0.0.1:1, 10.0.0.2:2] // 10.0.0.3:3 -> [10.0.0.3:3] - addrsMu sync.Mutex - addrsByUDP map[udpAddr]*AddrSet // TODO: clean up this map sometime? - addrsByKey map[key.Public]*AddrSet // TODO: clean up this map sometime? 
+ addrsByUDP map[udpAddr]*AddrSet // TODO: clean up this map sometime? - // stunReceiveFunc holds the current STUN packet processing func. - // Its Loaded value is always non-nil. - stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr) + // addrsByKey maps from public keys (as seen by incoming DERP + // packets) to their AddrSet (the same values as in addrsByUDP). + addrsByKey map[key.Public]*AddrSet // TODO: clean up this map sometime? - netInfoMu sync.Mutex netInfoFunc func(*tailcfg.NetInfo) // nil until set netInfoLast *tailcfg.NetInfo - udpRecvCh chan udpReadResult - derpRecvCh chan derpReadResult - - derpMu sync.Mutex wantDerp bool privateKey key.Private myDerp int // nearest DERP server; 0 means none/unknown @@ -189,7 +195,6 @@ func Listen(opts Options) (*Conn, error) { pconn: new(RebindingUDPConn), pconnPort: opts.Port, sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), - startEpUpdate: make(chan string, 1), connCtx: connCtx, connCtxCancel: connCtxCancel, epFunc: opts.endpointsFunc(), @@ -216,11 +221,13 @@ func Listen(opts Options) (*Conn, error) { c.pconn.Reset(packetConn.(*net.UDPConn)) c.ReSTUN("initial") - c.goroutines.Add(1) - go func() { - defer c.goroutines.Done() - c.epUpdate(connCtx) - }() + // We assume that LinkChange notifications are plumbed through well + // on our mobile clients, so don't do the timer thing to save radio/battery/CPU/etc. + if !version.IsMobile() { + go c.periodicReSTUN() + } + go c.periodicDerpCleanup() + return c, nil } @@ -231,75 +238,49 @@ func (c *Conn) ignoreSTUNPackets() { c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) } -// epUpdate runs in its own goroutine until ctx is shut down. +// runs in its own goroutine until ctx is shut down. // Whenever c.startEpUpdate receives a value, it starts an // STUN endpoint lookup. 
-func (c *Conn) epUpdate(ctx context.Context) { - var lastEndpoints []string - var lastCancel func() - var lastDone chan struct{} - - var regularUpdate <-chan time.Time - if !version.IsMobile() { - // We assume that LinkChange notifications are plumbed through well - // on our mobile clients, so don't do the timer thing to save radio/battery/CPU/etc. - ticker := time.NewTicker(28 * time.Second) // just under 30s, a likely UDP NAT timeout - defer ticker.Stop() - regularUpdate = ticker.C - } - - for { - var why string - select { - case <-ctx.Done(): - if lastCancel != nil { - lastCancel() - <-lastDone - } - return - case why = <-c.startEpUpdate: - case <-regularUpdate: - why = "timer" +// +// c.mu must NOT be held. +func (c *Conn) updateEndpoints(why string) { + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + why := c.wantEndpointsUpdate + c.wantEndpointsUpdate = "" + if why != "" && !c.closed { + go c.updateEndpoints(why) + } else { + c.endpointsUpdateActive = false } - if lastCancel != nil { - select { - case <-lastDone: - default: - c.logf("magicsock.Conn.epUpdate: starting new endpoint update (for %s) while previous running; cancelling previous...", why) - lastCancel() - <-lastDone - } - } - c.logf("magicsock.Conn.epUpdate: starting endpoint update (%s)", why) - var epCtx context.Context - epCtx, lastCancel = context.WithCancel(ctx) - lastDone = make(chan struct{}) + }() + c.logf("magicsock.Conn: starting endpoint update (%s)", why) - go func() { - defer close(lastDone) + endpoints, err := c.determineEndpoints(c.connCtx) + if err != nil { + c.logf("magicsock.Conn: endpoint update (%s) failed: %v", why, err) + // TODO(crawshaw): are there any conditions under which + // we should trigger a retry based on the error here? 
+ return + } - c.cleanStaleDerp() + if c.setEndpoints(endpoints) { + c.epFunc(endpoints) + } +} - netReport, err := c.updateNetInfo(epCtx) - if err != nil { - c.logf("magicsock.Conn: updateNetInfo failed: %v", err) - return - } - endpoints, err := c.determineEndpoints(epCtx, netReport) - if err != nil { - c.logf("magicsock.Conn: endpoint update failed: %v", err) - // TODO(crawshaw): are there any conditions under which - // we should trigger a retry based on the error here? - return - } - if stringsEqual(endpoints, lastEndpoints) { - return - } - lastEndpoints = endpoints - c.epFunc(endpoints) - }() +// setEndpoints records the new endpoints, reporting whether they're changed. +// It takes ownership of the slice. +func (c *Conn) setEndpoints(endpoints []string) (changed bool) { + c.mu.Lock() + defer c.mu.Unlock() + if stringsEqual(endpoints, c.lastEndpoints) { + return false } + c.lastEndpoints = endpoints + return true } func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { @@ -351,9 +332,11 @@ var processStartUnixNano = time.Now().UnixNano() // connect to. This is only used if netcheck couldn't find the // nearest one (for instance, if UDP is blocked and thus STUN latency // checks aren't working). +// +// c.mu must NOT be held. func (c *Conn) pickDERPFallback() int { - c.derpMu.Lock() - defer c.derpMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() if c.myDerp != 0 { // If we already had one in the past, stay on it. @@ -376,9 +359,11 @@ func (c *Conn) pickDERPFallback() int { // since the last state. // // callNetInfoCallback takes ownership of ni. +// +// c.mu must NOT be held. 
func (c *Conn) callNetInfoCallback(ni *tailcfg.NetInfo) { - c.netInfoMu.Lock() - defer c.netInfoMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() if ni.BasicallyEqual(c.netInfoLast) { return } @@ -393,19 +378,20 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) { if fn == nil { panic("nil NetInfoCallback") } - c.netInfoMu.Lock() + c.mu.Lock() last := c.netInfoLast c.netInfoFunc = fn - c.netInfoMu.Unlock() + c.mu.Unlock() if last != nil { fn(last) } } +// c.mu must NOT be held. func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { - c.derpMu.Lock() - defer c.derpMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() if !c.wantDerp { c.myDerp = 0 return false @@ -427,8 +413,16 @@ func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { } // determineEndpoints returns the machine's endpoint addresses. It -// does a STUN lookup to determine its public address. -func (c *Conn) determineEndpoints(ctx context.Context, nr *netcheck.Report) (ipPorts []string, err error) { +// does a STUN lookup (via netcheck) to determine its public address. +// +// c.mu must NOT be held. +func (c *Conn) determineEndpoints(ctx context.Context) (ipPorts []string, err error) { + nr, err := c.updateNetInfo(ctx) + if err != nil { + c.logf("magicsock.Conn.determineEndpoints: updateNetInfo: %v", err) + return + } + already := make(map[string]bool) // endpoint -> true var eps []string // unique endpoints @@ -696,9 +690,9 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest { if !addr.IP.Equal(derpMagicIP) { return nil } - c.derpMu.Lock() - defer c.derpMu.Unlock() - if !c.wantDerp { + c.mu.Lock() + defer c.mu.Unlock() + if !c.wantDerp || c.closed { return nil } if c.privateKey.IsZero() { @@ -717,7 +711,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest { } // Note that derphttp.NewClient does not dial the server - // so it is safe to do under the derpMu lock. + // so it is safe to do under the mu lock. 
dc, err := derphttp.NewClient(c.privateKey, "https://"+derpSrv.HostHTTPS+"/derp", c.logf) if err != nil { c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, derpSrv.HostHTTPS, err) @@ -881,8 +875,8 @@ func (c *Conn) findAddrSet(addr *net.UDPAddr) *AddrSet { copy(epAddr.ip.Addr[:], addr.IP.To16()) epAddr.port = uint16(addr.Port) - c.addrsMu.Lock() - defer c.addrsMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() return c.addrsByUDP[epAddr] } @@ -971,9 +965,9 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr return 0, nil, nil, err } - c.addrsMu.Lock() + c.mu.Lock() addrSet = c.addrsByKey[dm.src] - c.addrsMu.Unlock() + c.mu.Unlock() if addrSet == nil { key := wgcfg.Key(dm.src) @@ -1021,8 +1015,8 @@ func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error // If the private key changes, any DERP connections are torn down & // recreated when needed. func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { - c.derpMu.Lock() - defer c.derpMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() oldKey, newKey := c.privateKey, key.Private(privateKey) if newKey == oldKey { @@ -1043,8 +1037,8 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { // SetDERPEnabled controls whether DERP is used. // New connections have it enabled by default. func (c *Conn) SetDERPEnabled(wantDerp bool) { - c.derpMu.Lock() - defer c.derpMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() c.wantDerp = wantDerp if !wantDerp { @@ -1052,14 +1046,14 @@ func (c *Conn) SetDERPEnabled(wantDerp bool) { } } -// c.derpMu must be held. +// c.mu must be held. func (c *Conn) closeAllDerpLocked() { for i := range c.activeDerp { c.closeDerpLocked(i) } } -// c.derpMu must be held. +// c.mu must be held. 
func (c *Conn) closeDerpLocked(node int) { if ad, ok := c.activeDerp[node]; ok { c.logf("closing connection to derp%v", node) @@ -1070,8 +1064,8 @@ func (c *Conn) closeDerpLocked(node int) { } func (c *Conn) cleanStaleDerp() { - c.derpMu.Lock() - defer c.derpMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() const inactivityTime = 60 * time.Second tooOld := time.Now().Add(-inactivityTime) for i, ad := range c.activeDerp { @@ -1088,30 +1082,62 @@ func (c *Conn) cleanStaleDerp() { func (c *Conn) SetMark(value uint32) error { return nil } func (c *Conn) LastMark() uint32 { return 0 } +// Close closes the connection. +// +// Only the first close does anything. Any later closes return nil. func (c *Conn) Close() error { - // TODO: make this safe for concurrent Close? it's safe now only if Close calls are serialized. - select { - case <-c.donec(): + c.mu.Lock() + if c.closed { + c.mu.Unlock() return nil - default: } - c.connCtxCancel() + defer c.mu.Unlock() - c.derpMu.Lock() + c.closed = true + c.connCtxCancel() c.closeAllDerpLocked() - c.derpMu.Unlock() + return c.pconn.Close() +} - err := c.pconn.Close() - c.goroutines.Wait() - return err +func (c *Conn) periodicReSTUN() { + ticker := time.NewTicker(28 * time.Second) // just under 30s, a likely UDP NAT timeout + defer ticker.Stop() + for { + select { + case <-c.donec(): + return + case <-ticker.C: + c.ReSTUN("periodic") + } + } +} + +func (c *Conn) periodicDerpCleanup() { + ticker := time.NewTicker(15 * time.Second) // arbitrary + defer ticker.Stop() + for { + select { + case <-c.donec(): + return + case <-ticker.C: + c.cleanStaleDerp() + } + } } // ReSTUN triggers an address discovery. // The provided why string is for debug logging only. 
func (c *Conn) ReSTUN(why string) { - select { - case c.startEpUpdate <- why: - case <-c.donec(): + c.mu.Lock() + defer c.mu.Unlock() + if c.endpointsUpdateActive { + if c.wantEndpointsUpdate != why { + c.logf("magicsock.Conn.ReSTUN: endpoint update active, need another later (%q)", why) + c.wantEndpointsUpdate = why + } + } else { + c.endpointsUpdateActive = true + go c.updateEndpoints(why) } } @@ -1388,7 +1414,7 @@ func (c *Conn) CreateEndpoint(key [32]byte, addrs string) (conn.Endpoint, error) } } - c.addrsMu.Lock() + c.mu.Lock() for _, addr := range a.addrs { if addr.IP.Equal(derpMagicIP) { continue @@ -1400,7 +1426,7 @@ func (c *Conn) CreateEndpoint(key [32]byte, addrs string) (conn.Endpoint, error) c.addrsByUDP[epAddr] = a } c.addrsByKey[key] = a - c.addrsMu.Unlock() + c.mu.Unlock() return a, nil }