diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index efe88c2b3..560a58676 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -86,7 +86,8 @@ type Client struct { addrFamSelAtomic syncs.AtomicValue[AddressFamilySelector] mu sync.Mutex - started bool // true upon first connect, never transitions to false + atomicState syncs.AtomicValue[ConnectedState] // hold mu to write + started bool // true upon first connect, never transitions to false preferred bool canAckPings bool closed bool @@ -99,6 +100,14 @@ type Client struct { clock tstime.Clock } +// ConnectedState describes the state of a derphttp Client. +type ConnectedState struct { + Connected bool + Connecting bool + Closed bool + LocalAddr netip.AddrPort // if Connected +} + func (c *Client) String() string { return fmt.Sprintf("", c.ServerPublicKey().ShortString(), c.url) } @@ -307,6 +316,12 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien if c.client != nil { return c.client, c.connGen, nil } + c.atomicState.Store(ConnectedState{Connecting: true}) + defer func() { + if err != nil { + c.atomicState.Store(ConnectedState{Connecting: false}) + } + }() // timeout is the fallback maximum time (if ctx doesn't limit // it further) to do all of: DNS + TCP + TLS + HTTP Upgrade + @@ -524,6 +539,12 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien c.netConn = tcpConn c.tlsState = tlsState c.connGen++ + + localAddr, _ := c.client.LocalAddr() + c.atomicState.Store(ConnectedState{ + Connected: true, + LocalAddr: localAddr, + }) return c.client, c.connGen, nil } @@ -906,16 +927,15 @@ func (c *Client) SendPing(data [8]byte) error { // LocalAddr reports c's local TCP address, without any implicit // connect or reconnect. func (c *Client) LocalAddr() (netip.AddrPort, error) { - c.mu.Lock() - closed, client := c.closed, c.client - c.mu.Unlock() - if closed { + st := c.atomicState.Load() + if st.Closed { return netip.AddrPort{}, ErrClientClosed } - if client == nil { + la := st.LocalAddr + if !st.Connected && !la.IsValid() { return netip.AddrPort{}, errors.New("client not connected") } - return client.LocalAddr() + return la, nil } func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error { @@ -1049,6 +1069,7 @@ func (c *Client) Close() error { if c.netConn != nil { c.netConn.Close() } + c.atomicState.Store(ConnectedState{Closed: true}) return nil } diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 32adcdb9a..dc5acf49f 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "crypto/tls" + "fmt" "net" "net/http" "net/netip" @@ -447,3 +448,16 @@ func TestRunWatchConnectionLoopServeConnect(t *testing.T) { } watcher.RunWatchConnectionLoop(ctx, key.NodePublic{}, t.Logf, noopAdd, noopRemove) } + +// verify that the LocalAddr method doesn't acquire the mutex. +// See https://github.com/tailscale/tailscale/issues/11519 +func TestLocalAddrNoMutex(t *testing.T) { + var c Client + c.mu.Lock() + defer c.mu.Unlock() // not needed in test but for symmetry + + _, err := c.LocalAddr() + if got, want := fmt.Sprint(err), "client not connected"; got != want { + t.Errorf("got error %q; want %q", got, want) + } +}