diff --git a/derp/derp_client.go b/derp/derp_client.go index 50377da85..2a97ed299 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -261,10 +261,18 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e func (c *Client) writeTimeoutFired() { c.nc.Close() } +func (c *Client) SendPing(data [8]byte) error { + return c.sendPingOrPong(framePing, data) +} + func (c *Client) SendPong(data [8]byte) error { + return c.sendPingOrPong(framePong, data) +} + +func (c *Client) sendPingOrPong(typ frameType, data [8]byte) error { c.wmu.Lock() defer c.wmu.Unlock() - if err := writeFrameHeader(c.bw, framePong, 8); err != nil { + if err := writeFrameHeader(c.bw, typ, 8); err != nil { return err } if _, err := c.bw.Write(data[:]); err != nil { @@ -375,6 +383,12 @@ type PingMessage [8]byte func (PingMessage) msg() {} +// PongMessage is a reply to a PingMessage from a client or server +// with the payload sent previously in a PingMessage. +type PongMessage [8]byte + +func (PongMessage) msg() {} + // KeepAliveMessage is a one-way empty message from server to client, just to // keep the connection alive. It's like a PingMessage, but doesn't solicit // a reply from the client. @@ -536,6 +550,15 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro copy(pm[:], b[:]) return pm, nil + case framePong: + var pm PongMessage + if n < 8 { + c.logf("[unexpected] dropping short ping frame") + continue + } + copy(pm[:], b[:]) + return pm, nil + case frameHealth: return HealthMessage{Problem: string(b[:])}, nil diff --git a/derp/derp_server.go b/derp/derp_server.go index e759c5ea4..57d0315eb 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -662,6 +662,7 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connN connectedAt: time.Now(), sendQueue: make(chan pkt, perClientSendQueueDepth), discoSendQueue: make(chan pkt, perClientSendQueueDepth), + sendPongCh: make(chan [8]byte, 1), peerGone: make(chan key.NodePublic), canMesh: clientInfo.MeshKey != "" && clientInfo.MeshKey == s.meshKey, } @@ -729,6 +730,8 @@ func (c *sclient) run(ctx context.Context) error { err = c.handleFrameWatchConns(ft, fl) case frameClosePeer: err = c.handleFrameClosePeer(ft, fl) + case framePing: + err = c.handleFramePing(ft, fl) default: err = c.handleUnknownFrame(ft, fl) } @@ -766,6 +769,32 @@ func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error { return nil } +func (c *sclient) handleFramePing(ft frameType, fl uint32) error { + var m PingMessage + if fl < uint32(len(m)) { + return fmt.Errorf("short ping: %v", fl) + } + if fl > 1000 { + // unreasonably extra large. We leave some extra + // space for future extensibility, but not too much. + return fmt.Errorf("ping body too large: %v", fl) + } + _, err := io.ReadFull(c.br, m[:]) + if err != nil { + return err + } + if extra := int64(fl) - int64(len(m)); extra > 0 { + _, err = io.CopyN(ioutil.Discard, c.br, extra) + } + select { + case c.sendPongCh <- [8]byte(m): + default: + // They're pinging too fast. Ignore. + // TODO(bradfitz): add a rate limiter too. + } + return err +} + func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error { if fl != keyLen { return fmt.Errorf("handleFrameClosePeer wrong size") @@ -1202,6 +1231,7 @@ type sclient struct { remoteIPPort netaddr.IPPort // zero if remoteAddr is not ip:port. sendQueue chan pkt // packets queued to this client; never closed discoSendQueue chan pkt // important packets queued to this client; never closed + sendPongCh chan [8]byte // pong replies to send to the client; never closed peerGone chan key.NodePublic // write request that a previous sender has disconnected (not used by mesh peers) meshUpdate chan struct{} // write request to write peerStateChange canMesh bool // clientInfo had correct mesh token for inter-region routing @@ -1342,6 +1372,9 @@ func (c *sclient) sendLoop(ctx context.Context) error { werr = c.sendPacket(msg.src, msg.bs) c.recordQueueTime(msg.enqueuedAt) continue + case msg := <-c.sendPongCh: + werr = c.sendPong(msg) + continue case <-keepAliveTick.C: werr = c.sendKeepAlive() continue @@ -1368,6 +1401,9 @@ func (c *sclient) sendLoop(ctx context.Context) error { case msg := <-c.discoSendQueue: werr = c.sendPacket(msg.src, msg.bs) c.recordQueueTime(msg.enqueuedAt) + case msg := <-c.sendPongCh: + werr = c.sendPong(msg) + continue case <-keepAliveTick.C: werr = c.sendKeepAlive() } @@ -1384,6 +1420,16 @@ func (c *sclient) sendKeepAlive() error { return writeFrameHeader(c.bw.bw(), frameKeepAlive, 0) } +// sendPong sends a pong reply, without flushing. +func (c *sclient) sendPong(data [8]byte) error { + c.setWriteDeadline() + if err := writeFrameHeader(c.bw.bw(), framePong, uint32(len(data))); err != nil { + return err + } + _, err := c.bw.Write(data[:]) + return err +} + // sendPeerGone sends a peerGone frame, without flushing. func (c *sclient) sendPeerGone(peer key.NodePublic) error { c.s.peerGoneFrames.Add(1) diff --git a/derp/derp_test.go b/derp/derp_test.go index fa580ae14..b16add71a 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -812,6 +812,14 @@ func TestClientRecv(t *testing.T) { }, want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8}, }, + { + name: "pong", + input: []byte{ + byte(framePong), 0, 0, 0, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + }, + want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8}, + }, { name: "health_bad", input: []byte{ @@ -858,6 +866,23 @@ func TestClientRecv(t *testing.T) { } } +func TestClientSendPing(t *testing.T) { + var buf bytes.Buffer + c := &Client{ + bw: bufio.NewWriter(&buf), + } + if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { + t.Fatal(err) + } + want := []byte{ + byte(framePing), 0, 0, 0, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + } + if !bytes.Equal(buf.Bytes(), want) { + t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) + } +} + func TestClientSendPong(t *testing.T) { var buf bytes.Buffer c := &Client{ @@ -873,7 +898,6 @@ func TestClientSendPong(t *testing.T) { if !bytes.Equal(buf.Bytes(), want) { t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) } - } func TestServerDupClients(t *testing.T) { @@ -1316,3 +1340,30 @@ func TestClientSendRateLimiting(t *testing.T) { t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K) } } + +func TestServerRepliesToPing(t *testing.T) { + ts := newTestServer(t) + defer ts.close(t) + + tc := newRegularClient(t, ts, "alice") + + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 42} + + if err := tc.c.SendPing(data); err != nil { + t.Fatal(err) + } + + for { + m, err := tc.c.recvTimeout(time.Second) + if err != nil { + t.Fatal(err) + } + switch m := m.(type) { + case PongMessage: + if ([8]byte(m)) != data { + t.Fatalf("got pong %2x; want %2x", [8]byte(m), data) + } + return + } + } +} diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index f94117a12..500162932 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -698,6 +698,20 @@ func (c *Client) Send(dstKey key.NodePublic, b []byte) error { return err } +// SendPing sends a ping message, without any implicit connect or reconnect. +func (c *Client) SendPing(data [8]byte) error { + c.mu.Lock() + closed, client := c.closed, c.client + c.mu.Unlock() + if closed { + return ErrClientClosed + } + if client == nil { + return errors.New("client not connected") + } + return client.SendPing(data) +} + func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error { client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket") if err != nil {