derp: add sclient.done channel, simplify some context passing

This is mostly prep for a few future CLs, making sure we always have a
close-on-dead done channel available to select on when doing other
channel operations.
pull/208/head
Brad Fitzpatrick 4 years ago
parent ea90780066
commit 1453aecb44

@ -222,6 +222,9 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
// At this point we trust the client so we don't time out.
nc.SetDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := &sclient{
s: s,
key: clientKey,
@ -229,6 +232,7 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
br: br,
bw: bw,
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)),
done: ctx.Done(),
remoteAddr: remoteAddr,
connectedAt: time.Now(),
sendQueue: make(chan pkt, perClientSendQueueDepth),
@ -248,8 +252,6 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
}
func (c *sclient) run() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer func() {
// Atomically close+remove send queue, so racing writers don't
// send to closed channel.
@ -259,7 +261,7 @@ func (c *sclient) run() error {
c.mu.Unlock()
}()
go c.sender(ctx)
go c.sender()
for {
ft, fl, err := readFrameHeader(c.br)
@ -270,9 +272,9 @@ func (c *sclient) run() error {
case frameNotePreferred:
err = c.handleFrameNotePreferred(ft, fl)
case frameSendPacket:
err = c.handleFrameSendPacket(ctx, ft, fl)
err = c.handleFrameSendPacket(ft, fl)
default:
err = c.handleUnknownFrame(ctx, ft, fl)
err = c.handleUnknownFrame(ft, fl)
}
if err != nil {
return err
@ -280,7 +282,7 @@ func (c *sclient) run() error {
}
}
func (c *sclient) handleUnknownFrame(ctx context.Context, ft frameType, fl uint32) error {
func (c *sclient) handleUnknownFrame(ft frameType, fl uint32) error {
_, err := io.CopyN(ioutil.Discard, c.br, int64(fl))
return err
}
@ -297,10 +299,10 @@ func (c *sclient) handleFrameNotePreferred(ft frameType, fl uint32) error {
return nil
}
func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl uint32) error {
func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
s := c.s
dstKey, contents, err := s.recvPacket(ctx, c.br, fl)
dstKey, contents, err := s.recvPacket(c.br, fl)
if err != nil {
return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
}
@ -446,7 +448,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *cl
return clientKey, info, nil
}
func (s *Server) recvPacket(ctx context.Context, br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) {
func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) {
if frameLen < keyLen {
return key.Public{}, nil, errors.New("short send packet frame")
}
@ -476,7 +478,8 @@ type sclient struct {
key key.Public
info clientInfo
logf logger.Logf
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
done <-chan struct{} // closed when connection closes
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
// Owned by run, not thread-safe.
br *bufio.Reader
@ -521,16 +524,16 @@ func (c *sclient) setPreferred(v bool) {
}
}
func (c *sclient) sender(ctx context.Context) {
func (c *sclient) sender() {
// If the sender shuts down unilaterally due to an error, close so
// that the receive loop unblocks and cleans up the rest.
defer c.nc.Close()
if err := c.sendLoop(ctx); err != nil {
if err := c.sendLoop(); err != nil {
c.logf("sender failed: %v", err)
}
}
func (c *sclient) sendLoop(ctx context.Context) error {
func (c *sclient) sendLoop() error {
c.mu.RLock()
queue := c.sendQueue
c.mu.RUnlock()
@ -566,7 +569,7 @@ func (c *sclient) sendLoop(ctx context.Context) error {
for {
select {
case <-ctx.Done():
case <-c.done:
return nil
case pkt, ok := <-queue:

Loading…
Cancel
Save