diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 4a232d37e..32adcdb9a 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -401,3 +401,49 @@ func TestBreakWatcherConn(t *testing.T) { timer.Reset(5 * time.Second) } } + +func noopAdd(key.NodePublic, netip.AddrPort) {} +func noopRemove(key.NodePublic) {} + +func TestRunWatchConnectionLoopServeConnect(t *testing.T) { + defer func() { testHookWatchLookConnectResult = nil }() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + priv := key.NewNode() + serverURL, s := newTestServer(t, priv) + defer s.Close() + + pub := priv.Public() + + watcher := newWatcherClient(t, priv, serverURL) + defer watcher.Close() + + // Test connecting to ourselves, and that we get hung up on. + testHookWatchLookConnectResult = func(err error, wasSelfConnect bool) bool { + t.Helper() + if err != nil { + t.Fatalf("error connecting to server: %v", err) + } + if !wasSelfConnect { + t.Error("wanted self-connect; wasn't") + } + return false + } + watcher.RunWatchConnectionLoop(ctx, pub, t.Logf, noopAdd, noopRemove) + + // Test connecting to the server with a zero value for ignoreServerKey, + // so we should always connect. + testHookWatchLookConnectResult = func(err error, wasSelfConnect bool) bool { + t.Helper() + if err != nil { + t.Fatalf("error connecting to server: %v", err) + } + if wasSelfConnect { + t.Error("wanted normal connect; got self connect") + } + return false + } + watcher.RunWatchConnectionLoop(ctx, key.NodePublic{}, t.Logf, noopAdd, noopRemove) +} diff --git a/derp/derphttp/mesh_client.go b/derp/derphttp/mesh_client.go index 9e9e518e1..07aaa3c89 100644 --- a/derp/derphttp/mesh_client.go +++ b/derp/derphttp/mesh_client.go @@ -16,6 +16,10 @@ import ( var retryInterval = 5 * time.Second +// testHookWatchLookConnectResult, if non-nil for tests, is called by RunWatchConnectionLoop +// with the connect result. If it returns false, the loop ends. +var testHookWatchLookConnectResult func(connectError error, wasSelfConnect bool) (keepRunning bool) + // RunWatchConnectionLoop loops until ctx is done, sending // WatchConnectionChanges and subscribing to connection changes. // @@ -112,7 +116,21 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key } for ctx.Err() == nil { - if c.ServerPublicKey() == ignoreServerKey { + // Make sure we're connected before calling s.ServerPublicKey. + _, _, err := c.connect(ctx, "RunWatchConnectionLoop") + if err != nil { + if f := testHookWatchLookConnectResult; f != nil && !f(err, false) { + return + } + logf("mesh connect: %v", err) + sleep(retryInterval) + continue + } + selfConnect := c.ServerPublicKey() == ignoreServerKey + if f := testHookWatchLookConnectResult; f != nil && !f(err, selfConnect) { + return + } + if selfConnect { logf("detected self-connect; ignoring host") return }