diff --git a/derp/derp_server.go b/derp/derp_server.go index 23155de1b..8123715e0 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -222,6 +222,9 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error // At this point we trust the client so we don't time out. nc.SetDeadline(time.Time{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := &sclient{ s: s, key: clientKey, @@ -229,6 +232,7 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error br: br, bw: bw, logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)), + done: ctx.Done(), remoteAddr: remoteAddr, connectedAt: time.Now(), sendQueue: make(chan pkt, perClientSendQueueDepth), @@ -248,8 +252,6 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error } func (c *sclient) run() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() defer func() { // Atomically close+remove send queue, so racing writers don't // send to closed channel. @@ -259,7 +261,7 @@ func (c *sclient) run() error { c.mu.Unlock() }() - go c.sender(ctx) + go c.sender() for { ft, fl, err := readFrameHeader(c.br) @@ -270,9 +272,9 @@ func (c *sclient) run() error { case frameNotePreferred: err = c.handleFrameNotePreferred(ft, fl) case frameSendPacket: - err = c.handleFrameSendPacket(ctx, ft, fl) + err = c.handleFrameSendPacket(ft, fl) default: - err = c.handleUnknownFrame(ctx, ft, fl) + err = c.handleUnknownFrame(ft, fl) } if err != nil { return err @@ -280,7 +282,7 @@ func (c *sclient) run() error { } } -func (c *sclient) handleUnknownFrame(ctx context.Context, ft frameType, fl uint32) error { +func (c *sclient) handleUnknownFrame(ft frameType, fl uint32) error { _, err := io.CopyN(ioutil.Discard, c.br, int64(fl)) return err } @@ -297,10 +299,10 @@ func (c *sclient) handleFrameNotePreferred(ft frameType, fl uint32) error { return nil } -func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl uint32) error { +func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { s := c.s - dstKey, contents, err := s.recvPacket(ctx, c.br, fl) + dstKey, contents, err := s.recvPacket(c.br, fl) if err != nil { return fmt.Errorf("client %x: recvPacket: %v", c.key, err) } @@ -446,7 +448,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *cl return clientKey, info, nil } -func (s *Server) recvPacket(ctx context.Context, br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) { +func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) { if frameLen < keyLen { return key.Public{}, nil, errors.New("short send packet frame") } @@ -476,7 +478,8 @@ type sclient struct { key key.Public info clientInfo logf logger.Logf - remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String() + done <-chan struct{} // closed when connection closes + remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String() // Owned by run, not thread-safe. br *bufio.Reader @@ -521,16 +524,16 @@ func (c *sclient) setPreferred(v bool) { } } -func (c *sclient) sender(ctx context.Context) { +func (c *sclient) sender() { // If the sender shuts down unilaterally due to an error, close so // that the receive loop unblocks and cleans up the rest. defer c.nc.Close() - if err := c.sendLoop(ctx); err != nil { + if err := c.sendLoop(); err != nil { c.logf("sender failed: %v", err) } } -func (c *sclient) sendLoop(ctx context.Context) error { +func (c *sclient) sendLoop() error { c.mu.RLock() queue := c.sendQueue c.mu.RUnlock() @@ -566,7 +569,7 @@ func (c *sclient) sendLoop(ctx context.Context) error { for { select { - case <-ctx.Done(): + case <-c.done: return nil case pkt, ok := <-queue: