diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index bb3d44377..a98dda10b 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -56,6 +56,11 @@ type Client struct { MeshKey string // optional; for trusted clients IsProber bool // optional; for probers to optional declare themselves as such + // BaseContext, if non-nil, returns the base context to use for dialing a + // new derp server. If nil, context.Background is used. + // In either case, additional timeouts may be added to the base context. + BaseContext func() context.Context + privateKey key.NodePrivate logf logger.Logf netMon *netmon.Monitor // optional; nil means interfaces will be looked up on-demand @@ -144,6 +149,19 @@ func (c *Client) Connect(ctx context.Context) error { return err } +// newContext returns a new context for setting up a new DERP connection. +// It uses either c.BaseContext or returns context.Background. +func (c *Client) newContext() context.Context { + if c.BaseContext != nil { + ctx := c.BaseContext() + if ctx == nil { + panic("BaseContext returned nil") + } + return ctx + } + return context.Background() +} + // TLSConnectionState returns the last TLS connection state, if any. // The client must already be connected. func (c *Client) TLSConnectionState() (_ *tls.ConnectionState, ok bool) { @@ -776,7 +794,7 @@ func (c *Client) dialNodeUsingProxy(ctx context.Context, n *tailcfg.DERPNode, pr } func (c *Client) Send(dstKey key.NodePublic, b []byte) error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.Send") + client, _, err := c.connect(c.newContext(), "derphttp.Client.Send") if err != nil { return err } @@ -876,7 +894,7 @@ func (c *Client) LocalAddr() (netip.AddrPort, error) { } func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket") + client, _, err := c.connect(c.newContext(), "derphttp.Client.ForwardPacket") if err != nil { return err } @@ -942,7 +960,7 @@ func (c *Client) NotePreferred(v bool) { // // Only trusted connections (using MeshKey) are allowed to use this. func (c *Client) WatchConnectionChanges() error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.WatchConnectionChanges") + client, _, err := c.connect(c.newContext(), "derphttp.Client.WatchConnectionChanges") if err != nil { return err } @@ -957,7 +975,7 @@ func (c *Client) WatchConnectionChanges() error { // // Only trusted connections (using MeshKey) are allowed to use this. func (c *Client) ClosePeer(target key.NodePublic) error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.ClosePeer") + client, _, err := c.connect(c.newContext(), "derphttp.Client.ClosePeer") if err != nil { return err } @@ -978,7 +996,7 @@ func (c *Client) Recv() (derp.ReceivedMessage, error) { // RecvDetail is like Recv, but additional returns the connection generation on each message. // The connGen value is incremented every time the derphttp.Client reconnects to the server. func (c *Client) RecvDetail() (m derp.ReceivedMessage, connGen int, err error) { - client, connGen, err := c.connect(context.TODO(), "derphttp.Client.Recv") + client, connGen, err := c.connect(c.newContext(), "derphttp.Client.Recv") if err != nil { return nil, 0, err }