wgengine/magicsock: drop donec channel, rename epUpdateCtx to serve its purpose

pull/122/head
Brad Fitzpatrick 5 years ago
parent a399ef3dc7
commit af7a01d6f0

@ -41,11 +41,10 @@ type Conn struct {
startEpUpdate chan struct{} // send to trigger endpoint update startEpUpdate chan struct{} // send to trigger endpoint update
epFunc func(endpoints []string) epFunc func(endpoints []string)
logf func(format string, args ...interface{}) logf func(format string, args ...interface{})
donec chan struct{} // closed on Conn.Close
sendLogLimit *rate.Limiter sendLogLimit *rate.Limiter
epUpdateCtx context.Context // endpoint updater context connCtx context.Context // closed on Conn.Close
epUpdateCancel func() // the func to cancel epUpdateCtx connCtxCancel func() // closes connCtx
// addrsByUDP is a map of every remote ip:port to a priority // addrsByUDP is a map of every remote ip:port to a priority
// list of endpoint addresses for a peer. // list of endpoint addresses for a peer.
@ -135,29 +134,30 @@ func Listen(opts Options) (*Conn, error) {
return nil, fmt.Errorf("magicsock.Listen: %v", err) return nil, fmt.Errorf("magicsock.Listen: %v", err)
} }
epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background()) connCtx, connCtxCancel := context.WithCancel(context.Background())
c := &Conn{ c := &Conn{
pconn: new(RebindingUDPConn), pconn: new(RebindingUDPConn),
pconnPort: opts.Port, pconnPort: opts.Port,
donec: make(chan struct{}), sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1),
sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), stunServers: append([]string{}, opts.STUN...),
stunServers: append([]string{}, opts.STUN...), startEpUpdate: make(chan struct{}, 1),
startEpUpdate: make(chan struct{}, 1), connCtx: connCtx,
epUpdateCtx: epUpdateCtx, connCtxCancel: connCtxCancel,
epUpdateCancel: epUpdateCancel, epFunc: opts.endpointsFunc(),
epFunc: opts.endpointsFunc(), logf: log.Printf,
logf: log.Printf, addrsByUDP: make(map[udpAddr]*AddrSet),
addrsByUDP: make(map[udpAddr]*AddrSet), derpRecvCh: make(chan derpReadResult),
derpRecvCh: make(chan derpReadResult), udpRecvCh: make(chan udpReadResult),
udpRecvCh: make(chan udpReadResult),
} }
c.ignoreSTUNPackets() c.ignoreSTUNPackets()
c.pconn.Reset(packetConn.(*net.UDPConn)) c.pconn.Reset(packetConn.(*net.UDPConn))
c.reSTUN() c.reSTUN()
go c.epUpdate(epUpdateCtx) go c.epUpdate(connCtx)
return c, nil return c, nil
} }
func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() }
// ignoreSTUNPackets sets a STUN packet processing func that does nothing. // ignoreSTUNPackets sets a STUN packet processing func that does nothing.
func (c *Conn) ignoreSTUNPackets() { func (c *Conn) ignoreSTUNPackets() {
c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {})
@ -497,11 +497,11 @@ func (c *Conn) sendAddr(addr *net.UDPAddr, pubKey key.Public, b []byte) error {
if ch := c.derpWriteChanOfAddr(addr); ch != nil { if ch := c.derpWriteChanOfAddr(addr); ch != nil {
errc := make(chan error, 1) errc := make(chan error, 1)
select { select {
case <-c.donec: case <-c.donec():
return errConnClosed return errConnClosed
case ch <- derpWriteRequest{addr, pubKey, b, errc}: case ch <- derpWriteRequest{addr, pubKey, b, errc}:
select { select {
case <-c.donec: case <-c.donec():
return errConnClosed return errConnClosed
case err := <-errc: case err := <-errc:
return err // usually nil return err // usually nil
@ -595,7 +595,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
} }
if err != nil { if err != nil {
select { select {
case <-c.donec: case <-c.donec():
return return
case <-ctx.Done(): case <-ctx.Done():
return return
@ -617,7 +617,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
log.Printf("got derp %v packet: %q", derpFakeAddr, buf[:bufValid]) log.Printf("got derp %v packet: %q", derpFakeAddr, buf[:bufValid])
} }
select { select {
case <-c.donec: case <-c.donec():
return return
case c.derpRecvCh <- derpReadResult{derpFakeAddr, bufValid, copyFn}: case c.derpRecvCh <- derpReadResult{derpFakeAddr, bufValid, copyFn}:
<-didCopy <-didCopy
@ -639,7 +639,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-c.donec: case <-c.donec():
return return
case wr := <-ch: case wr := <-ch:
err := dc.Send(wr.pubKey, wr.b) err := dc.Send(wr.pubKey, wr.b)
@ -648,7 +648,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
} }
select { select {
case wr.errc <- err: case wr.errc <- err:
case <-c.donec: case <-c.donec():
return return
} }
} }
@ -685,7 +685,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
if err != nil { if err != nil {
select { select {
case c.udpRecvCh <- udpReadResult{err: err}: case c.udpRecvCh <- udpReadResult{err: err}:
case <-c.donec: case <-c.donec():
} }
return return
} }
@ -698,7 +698,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
addr.IP = addr.IP.To4() addr.IP = addr.IP.To4()
select { select {
case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: case c.udpRecvCh <- udpReadResult{n: n, addr: addr}:
case <-c.donec: case <-c.donec():
} }
return return
} }
@ -719,7 +719,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
// The main point of this receive, though, is to make sure that the goroutine // The main point of this receive, though, is to make sure that the goroutine
// is done with our b []byte buf. // is done with our b []byte buf.
c.pconn.SetReadDeadline(time.Time{}) c.pconn.SetReadDeadline(time.Time{})
case <-c.donec: case <-c.donec():
return 0, nil, nil, errors.New("Conn closed") return 0, nil, nil, errors.New("Conn closed")
} }
n, addr = dm.n, dm.derpAddr n, addr = dm.n, dm.derpAddr
@ -753,6 +753,13 @@ func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error
return 0, nil, nil, syscall.EAFNOSUPPORT return 0, nil, nil, syscall.EAFNOSUPPORT
} }
// SetPrivateKey sets the connection's private key.
//
// This is only used to be able prove our identity when connecting to
// DERP servers.
//
// If the private key changes, any DERP connections are torn down &
// recreated when needed.
func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
c.derpMu.Lock() c.derpMu.Lock()
defer c.derpMu.Unlock() defer c.derpMu.Unlock()
@ -768,6 +775,13 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
} }
// Key changed. Close any DERP connections. // Key changed. Close any DERP connections.
c.closeAllDerpLocked()
return nil
}
// c.derpMu must be held.
func (c *Conn) closeAllDerpLocked() {
for _, c := range c.derpConn { for _, c := range c.derpConn {
go c.Close() go c.Close()
} }
@ -777,30 +791,31 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
c.derpConn = nil c.derpConn = nil
c.derpCancel = nil c.derpCancel = nil
c.derpWriteCh = nil c.derpWriteCh = nil
return nil
} }
func (c *Conn) SetMark(value uint32) error { return nil } func (c *Conn) SetMark(value uint32) error { return nil }
func (c *Conn) LastMark() uint32 { return 0 } func (c *Conn) LastMark() uint32 { return 0 }
func (c *Conn) Close() error { func (c *Conn) Close() error {
// TODO: make this safe for concurrent Close? it's safe now only if Close calls are serialized.
select { select {
case <-c.donec: case <-c.donec():
return nil return nil
default: default:
} }
close(c.donec) c.connCtxCancel()
c.epUpdateCancel()
for _, dc := range c.derpConn { c.derpMu.Lock()
dc.Close() c.closeAllDerpLocked()
} c.derpMu.Unlock()
return c.pconn.Close() return c.pconn.Close()
} }
func (c *Conn) reSTUN() { func (c *Conn) reSTUN() {
select { select {
case c.startEpUpdate <- struct{}{}: case c.startEpUpdate <- struct{}{}:
case <-c.epUpdateCtx.Done(): case <-c.donec():
} }
} }

Loading…
Cancel
Save