From af7a01d6f0b5b9fbb7c9022940bc4f515148425b Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 2 Mar 2020 09:31:25 -0800 Subject: [PATCH] wgengine/magicsock: drop donec channel, rename epUpdateCtx to serve its purpose --- wgengine/magicsock/magicsock.go | 85 +++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 064746b95..4dbc26976 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -41,11 +41,10 @@ type Conn struct { startEpUpdate chan struct{} // send to trigger endpoint update epFunc func(endpoints []string) logf func(format string, args ...interface{}) - donec chan struct{} // closed on Conn.Close sendLogLimit *rate.Limiter - epUpdateCtx context.Context // endpoint updater context - epUpdateCancel func() // the func to cancel epUpdateCtx + connCtx context.Context // closed on Conn.Close + connCtxCancel func() // closes connCtx // addrsByUDP is a map of every remote ip:port to a priority // list of endpoint addresses for a peer. @@ -135,29 +134,30 @@ func Listen(opts Options) (*Conn, error) { return nil, fmt.Errorf("magicsock.Listen: %v", err) } - epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background()) + connCtx, connCtxCancel := context.WithCancel(context.Background()) c := &Conn{ - pconn: new(RebindingUDPConn), - pconnPort: opts.Port, - donec: make(chan struct{}), - sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), - stunServers: append([]string{}, opts.STUN...), - startEpUpdate: make(chan struct{}, 1), - epUpdateCtx: epUpdateCtx, - epUpdateCancel: epUpdateCancel, - epFunc: opts.endpointsFunc(), - logf: log.Printf, - addrsByUDP: make(map[udpAddr]*AddrSet), - derpRecvCh: make(chan derpReadResult), - udpRecvCh: make(chan udpReadResult), + pconn: new(RebindingUDPConn), + pconnPort: opts.Port, + sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), + stunServers: append([]string{}, opts.STUN...), + startEpUpdate: make(chan struct{}, 1), + connCtx: connCtx, + connCtxCancel: connCtxCancel, + epFunc: opts.endpointsFunc(), + logf: log.Printf, + addrsByUDP: make(map[udpAddr]*AddrSet), + derpRecvCh: make(chan derpReadResult), + udpRecvCh: make(chan udpReadResult), } c.ignoreSTUNPackets() c.pconn.Reset(packetConn.(*net.UDPConn)) c.reSTUN() - go c.epUpdate(epUpdateCtx) + go c.epUpdate(connCtx) return c, nil } +func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() } + // ignoreSTUNPackets sets a STUN packet processing func that does nothing. func (c *Conn) ignoreSTUNPackets() { c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) @@ -497,11 +497,11 @@ func (c *Conn) sendAddr(addr *net.UDPAddr, pubKey key.Public, b []byte) error { if ch := c.derpWriteChanOfAddr(addr); ch != nil { errc := make(chan error, 1) select { - case <-c.donec: + case <-c.donec(): return errConnClosed case ch <- derpWriteRequest{addr, pubKey, b, errc}: select { - case <-c.donec: + case <-c.donec(): return errConnClosed case err := <-errc: return err // usually nil @@ -595,7 +595,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc } if err != nil { select { - case <-c.donec: + case <-c.donec(): return case <-ctx.Done(): return @@ -617,7 +617,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc log.Printf("got derp %v packet: %q", derpFakeAddr, buf[:bufValid]) } select { - case <-c.donec: + case <-c.donec(): return case c.derpRecvCh <- derpReadResult{derpFakeAddr, bufValid, copyFn}: <-didCopy @@ -639,7 +639,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc select { case <-ctx.Done(): return - case <-c.donec: + case <-c.donec(): return case wr := <-ch: err := dc.Send(wr.pubKey, wr.b) @@ -648,7 +648,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc } select { case wr.errc <- err: - case <-c.donec: + case <-c.donec(): return } } @@ -685,7 +685,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr if err != nil { select { case c.udpRecvCh <- udpReadResult{err: err}: - case <-c.donec: + case <-c.donec(): } return } @@ -698,7 +698,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr addr.IP = addr.IP.To4() select { case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: - case <-c.donec: + case <-c.donec(): } return } @@ -719,7 +719,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr // The main point of this receive, though, is to make sure that the goroutine // is done with our b []byte buf. c.pconn.SetReadDeadline(time.Time{}) - case <-c.donec: + case <-c.donec(): return 0, nil, nil, errors.New("Conn closed") } n, addr = dm.n, dm.derpAddr @@ -753,6 +753,13 @@ func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error return 0, nil, nil, syscall.EAFNOSUPPORT } +// SetPrivateKey sets the connection's private key. +// +// This is only used to be able prove our identity when connecting to +// DERP servers. +// +// If the private key changes, any DERP connections are torn down & +// recreated when needed. func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { c.derpMu.Lock() defer c.derpMu.Unlock() @@ -768,6 +775,13 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { } // Key changed. Close any DERP connections. + c.closeAllDerpLocked() + + return nil +} + +// c.derpMu must be held. +func (c *Conn) closeAllDerpLocked() { for _, c := range c.derpConn { go c.Close() } @@ -777,30 +791,31 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { c.derpConn = nil c.derpCancel = nil c.derpWriteCh = nil - return nil } func (c *Conn) SetMark(value uint32) error { return nil } func (c *Conn) LastMark() uint32 { return 0 } func (c *Conn) Close() error { + // TODO: make this safe for concurrent Close? it's safe now only if Close calls are serialized. select { - case <-c.donec: + case <-c.donec(): return nil default: } - close(c.donec) - c.epUpdateCancel() - for _, dc := range c.derpConn { - dc.Close() - } + c.connCtxCancel() + + c.derpMu.Lock() + c.closeAllDerpLocked() + c.derpMu.Unlock() + return c.pconn.Close() } func (c *Conn) reSTUN() { select { case c.startEpUpdate <- struct{}{}: - case <-c.epUpdateCtx.Done(): + case <-c.donec(): } }