derp: support client->server ping (and server->client pong)

In prep for a future change to have client ping derp connections
when their state is questionable, rather than aggressively tearing
them down and doing a heavy reconnect when their state is unknown.

We already support ping/pong in the other direction (servers probing
clients) so we already had the two frame types, but I'd never finished
this direction.

Updates #3619

Change-Id: I024b815d9db1bc57c20f82f80f95fb55fc9e2fcc
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/3625/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent bc537adb1a
commit 434af15a04

@ -261,10 +261,18 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e
func (c *Client) writeTimeoutFired() { c.nc.Close() } func (c *Client) writeTimeoutFired() { c.nc.Close() }
func (c *Client) SendPing(data [8]byte) error {
return c.sendPingOrPong(framePing, data)
}
func (c *Client) SendPong(data [8]byte) error { func (c *Client) SendPong(data [8]byte) error {
return c.sendPingOrPong(framePong, data)
}
func (c *Client) sendPingOrPong(typ frameType, data [8]byte) error {
c.wmu.Lock() c.wmu.Lock()
defer c.wmu.Unlock() defer c.wmu.Unlock()
if err := writeFrameHeader(c.bw, framePong, 8); err != nil { if err := writeFrameHeader(c.bw, typ, 8); err != nil {
return err return err
} }
if _, err := c.bw.Write(data[:]); err != nil { if _, err := c.bw.Write(data[:]); err != nil {
@ -375,6 +383,12 @@ type PingMessage [8]byte
func (PingMessage) msg() {} func (PingMessage) msg() {}
// PongMessage is a reply to a PingMessage from a client or server
// with the payload sent previously in a PingMessage.
type PongMessage [8]byte
func (PongMessage) msg() {}
// KeepAliveMessage is a one-way empty message from server to client, just to // KeepAliveMessage is a one-way empty message from server to client, just to
// keep the connection alive. It's like a PingMessage, but doesn't solicit // keep the connection alive. It's like a PingMessage, but doesn't solicit
// a reply from the client. // a reply from the client.
@ -536,6 +550,15 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
copy(pm[:], b[:]) copy(pm[:], b[:])
return pm, nil return pm, nil
case framePong:
var pm PongMessage
if n < 8 {
c.logf("[unexpected] dropping short ping frame")
continue
}
copy(pm[:], b[:])
return pm, nil
case frameHealth: case frameHealth:
return HealthMessage{Problem: string(b[:])}, nil return HealthMessage{Problem: string(b[:])}, nil

@ -662,6 +662,7 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connN
connectedAt: time.Now(), connectedAt: time.Now(),
sendQueue: make(chan pkt, perClientSendQueueDepth), sendQueue: make(chan pkt, perClientSendQueueDepth),
discoSendQueue: make(chan pkt, perClientSendQueueDepth), discoSendQueue: make(chan pkt, perClientSendQueueDepth),
sendPongCh: make(chan [8]byte, 1),
peerGone: make(chan key.NodePublic), peerGone: make(chan key.NodePublic),
canMesh: clientInfo.MeshKey != "" && clientInfo.MeshKey == s.meshKey, canMesh: clientInfo.MeshKey != "" && clientInfo.MeshKey == s.meshKey,
} }
@ -729,6 +730,8 @@ func (c *sclient) run(ctx context.Context) error {
err = c.handleFrameWatchConns(ft, fl) err = c.handleFrameWatchConns(ft, fl)
case frameClosePeer: case frameClosePeer:
err = c.handleFrameClosePeer(ft, fl) err = c.handleFrameClosePeer(ft, fl)
case framePing:
err = c.handleFramePing(ft, fl)
default: default:
err = c.handleUnknownFrame(ft, fl) err = c.handleUnknownFrame(ft, fl)
} }
@ -766,6 +769,32 @@ func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error {
return nil return nil
} }
func (c *sclient) handleFramePing(ft frameType, fl uint32) error {
var m PingMessage
if fl < uint32(len(m)) {
return fmt.Errorf("short ping: %v", fl)
}
if fl > 1000 {
// unreasonably extra large. We leave some extra
// space for future extensibility, but not too much.
return fmt.Errorf("ping body too large: %v", fl)
}
_, err := io.ReadFull(c.br, m[:])
if err != nil {
return err
}
if extra := int64(fl) - int64(len(m)); extra > 0 {
_, err = io.CopyN(ioutil.Discard, c.br, extra)
}
select {
case c.sendPongCh <- [8]byte(m):
default:
// They're pinging too fast. Ignore.
// TODO(bradfitz): add a rate limiter too.
}
return err
}
func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error { func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error {
if fl != keyLen { if fl != keyLen {
return fmt.Errorf("handleFrameClosePeer wrong size") return fmt.Errorf("handleFrameClosePeer wrong size")
@ -1202,6 +1231,7 @@ type sclient struct {
remoteIPPort netaddr.IPPort // zero if remoteAddr is not ip:port. remoteIPPort netaddr.IPPort // zero if remoteAddr is not ip:port.
sendQueue chan pkt // packets queued to this client; never closed sendQueue chan pkt // packets queued to this client; never closed
discoSendQueue chan pkt // important packets queued to this client; never closed discoSendQueue chan pkt // important packets queued to this client; never closed
sendPongCh chan [8]byte // pong replies to send to the client; never closed
peerGone chan key.NodePublic // write request that a previous sender has disconnected (not used by mesh peers) peerGone chan key.NodePublic // write request that a previous sender has disconnected (not used by mesh peers)
meshUpdate chan struct{} // write request to write peerStateChange meshUpdate chan struct{} // write request to write peerStateChange
canMesh bool // clientInfo had correct mesh token for inter-region routing canMesh bool // clientInfo had correct mesh token for inter-region routing
@ -1342,6 +1372,9 @@ func (c *sclient) sendLoop(ctx context.Context) error {
werr = c.sendPacket(msg.src, msg.bs) werr = c.sendPacket(msg.src, msg.bs)
c.recordQueueTime(msg.enqueuedAt) c.recordQueueTime(msg.enqueuedAt)
continue continue
case msg := <-c.sendPongCh:
werr = c.sendPong(msg)
continue
case <-keepAliveTick.C: case <-keepAliveTick.C:
werr = c.sendKeepAlive() werr = c.sendKeepAlive()
continue continue
@ -1368,6 +1401,9 @@ func (c *sclient) sendLoop(ctx context.Context) error {
case msg := <-c.discoSendQueue: case msg := <-c.discoSendQueue:
werr = c.sendPacket(msg.src, msg.bs) werr = c.sendPacket(msg.src, msg.bs)
c.recordQueueTime(msg.enqueuedAt) c.recordQueueTime(msg.enqueuedAt)
case msg := <-c.sendPongCh:
werr = c.sendPong(msg)
continue
case <-keepAliveTick.C: case <-keepAliveTick.C:
werr = c.sendKeepAlive() werr = c.sendKeepAlive()
} }
@ -1384,6 +1420,16 @@ func (c *sclient) sendKeepAlive() error {
return writeFrameHeader(c.bw.bw(), frameKeepAlive, 0) return writeFrameHeader(c.bw.bw(), frameKeepAlive, 0)
} }
// sendPong sends a pong reply, without flushing.
func (c *sclient) sendPong(data [8]byte) error {
c.setWriteDeadline()
if err := writeFrameHeader(c.bw.bw(), framePong, uint32(len(data))); err != nil {
return err
}
_, err := c.bw.Write(data[:])
return err
}
// sendPeerGone sends a peerGone frame, without flushing. // sendPeerGone sends a peerGone frame, without flushing.
func (c *sclient) sendPeerGone(peer key.NodePublic) error { func (c *sclient) sendPeerGone(peer key.NodePublic) error {
c.s.peerGoneFrames.Add(1) c.s.peerGoneFrames.Add(1)

@ -812,6 +812,14 @@ func TestClientRecv(t *testing.T) {
}, },
want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8}, want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8},
}, },
{
name: "pong",
input: []byte{
byte(framePong), 0, 0, 0, 8,
1, 2, 3, 4, 5, 6, 7, 8,
},
want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8},
},
{ {
name: "health_bad", name: "health_bad",
input: []byte{ input: []byte{
@ -858,6 +866,23 @@ func TestClientRecv(t *testing.T) {
} }
} }
func TestClientSendPing(t *testing.T) {
var buf bytes.Buffer
c := &Client{
bw: bufio.NewWriter(&buf),
}
if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
t.Fatal(err)
}
want := []byte{
byte(framePing), 0, 0, 0, 8,
1, 2, 3, 4, 5, 6, 7, 8,
}
if !bytes.Equal(buf.Bytes(), want) {
t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
}
}
func TestClientSendPong(t *testing.T) { func TestClientSendPong(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
c := &Client{ c := &Client{
@ -873,7 +898,6 @@ func TestClientSendPong(t *testing.T) {
if !bytes.Equal(buf.Bytes(), want) { if !bytes.Equal(buf.Bytes(), want) {
t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want)
} }
} }
func TestServerDupClients(t *testing.T) { func TestServerDupClients(t *testing.T) {
@ -1316,3 +1340,30 @@ func TestClientSendRateLimiting(t *testing.T) {
t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K) t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K)
} }
} }
func TestServerRepliesToPing(t *testing.T) {
ts := newTestServer(t)
defer ts.close(t)
tc := newRegularClient(t, ts, "alice")
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 42}
if err := tc.c.SendPing(data); err != nil {
t.Fatal(err)
}
for {
m, err := tc.c.recvTimeout(time.Second)
if err != nil {
t.Fatal(err)
}
switch m := m.(type) {
case PongMessage:
if ([8]byte(m)) != data {
t.Fatalf("got pong %2x; want %2x", [8]byte(m), data)
}
return
}
}
}

@ -698,6 +698,20 @@ func (c *Client) Send(dstKey key.NodePublic, b []byte) error {
return err return err
} }
// SendPing sends a ping message, without any implicit connect or reconnect.
func (c *Client) SendPing(data [8]byte) error {
c.mu.Lock()
closed, client := c.closed, c.client
c.mu.Unlock()
if closed {
return ErrClientClosed
}
if client == nil {
return errors.New("client not connected")
}
return client.SendPing(data)
}
func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error { func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error {
client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket") client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket")
if err != nil { if err != nil {

Loading…
Cancel
Save