diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 500162932..f2e64368c 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -13,6 +13,7 @@ package derphttp import ( "bufio" "context" + "crypto/rand" "crypto/tls" "crypto/x509" "errors" @@ -72,6 +73,7 @@ type Client struct { client *derp.Client connGen int // incremented once per new connection; valid values are >0 serverPubKey key.NodePublic + pingOut map[derp.PingMessage]chan<- bool // chan to send to on pong } // NewRegionClient returns a new DERP-over-HTTP client. It connects lazily. @@ -698,7 +700,67 @@ func (c *Client) Send(dstKey key.NodePublic, b []byte) error { return err } -// SendPing sends a ping message, without any implicit connect or reconnect. +func (c *Client) registerPing(m derp.PingMessage, ch chan<- bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.pingOut == nil { + c.pingOut = map[derp.PingMessage]chan<- bool{} + } + c.pingOut[m] = ch +} + +func (c *Client) unregisterPing(m derp.PingMessage) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.pingOut, m) +} + +func (c *Client) handledPong(m derp.PongMessage) bool { + c.mu.Lock() + defer c.mu.Unlock() + k := derp.PingMessage(m) + if ch, ok := c.pingOut[k]; ok { + ch <- true + delete(c.pingOut, k) + return true + } + return false +} + +// Ping sends a ping to the peer and waits for it either to be +// acknowledged (in which case Ping returns nil) or waits for ctx to +// be over and returns an error. It will wait at most 5 seconds +// before returning an error. +// +// Another goroutine must be in a loop calling Recv or +// RecvDetail or ping responses won't be handled. +func (c *Client) Ping(ctx context.Context) error { + maxDL := time.Now().Add(5 * time.Second) + if dl, ok := ctx.Deadline(); !ok || dl.After(maxDL) { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, maxDL) + defer cancel() + } + var data derp.PingMessage + rand.Read(data[:]) + gotPing := make(chan bool, 1) + c.registerPing(data, gotPing) + defer c.unregisterPing(data) + if err := c.SendPing(data); err != nil { + return err + } + select { + case <-gotPing: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// SendPing writes a ping message, without any implicit connect or +// reconnect. This is a lower-level interface that writes a frame +// without any implicit handling of the response pong, if any. For a +// higher-level interface, use Ping. func (c *Client) SendPing(data [8]byte) error { c.mu.Lock() closed, client := c.closed, c.client @@ -819,14 +881,22 @@ func (c *Client) RecvDetail() (m derp.ReceivedMessage, connGen int, err error) { if err != nil { return nil, 0, err } - m, err = client.Recv() - if err != nil { - c.closeForReconnect(client) - if c.isClosed() { - err = ErrClientClosed + for { + m, err = client.Recv() + switch m := m.(type) { + case derp.PongMessage: + if c.handledPong(m) { + continue + } + } + if err != nil { + c.closeForReconnect(client) + if c.isClosed() { + err = ErrClientClosed + } } + return m, connGen, err } - return m, connGen, err } func (c *Client) isClosed() bool { diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 40cddc4da..bc6c008b2 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -154,3 +154,55 @@ func waitConnect(t testing.TB, c *Client) { t.Fatalf("client first Recv was unexpected type %T", v) } } + +func TestPing(t *testing.T) { + serverPrivateKey := key.NewNode() + s := derp.NewServer(serverPrivateKey, t.Logf) + defer s.Close() + + httpsrv := &http.Server{ + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), + Handler: Handler(s), + } + + ln, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + serverURL := "http://" + ln.Addr().String() + t.Logf("server URL: %s", serverURL) + + go func() { + if err := httpsrv.Serve(ln); err != nil { + if err == http.ErrServerClosed { + return + } + panic(err) + } + }() + + c, err := NewClient(key.NewNode(), serverURL, t.Logf) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + defer c.Close() + if err := c.Connect(context.Background()); err != nil { + t.Fatalf("client Connect: %v", err) + } + + errc := make(chan error, 1) + go func() { + for { + m, err := c.Recv() + if err != nil { + errc <- err + return + } + t.Logf("Recv: %T", m) + } + }() + err = c.Ping(context.Background()) + if err != nil { + t.Fatalf("Ping: %v", err) + } +}