diff --git a/cmd/tailscale/cli/serve_dev.go b/cmd/tailscale/cli/serve_dev.go index 7e98baa3e..b48c8292b 100644 --- a/cmd/tailscale/cli/serve_dev.go +++ b/cmd/tailscale/cli/serve_dev.go @@ -8,6 +8,7 @@ import ( "errors" "flag" "fmt" + "io" "log" "net" "net/url" @@ -289,7 +290,7 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { for { _, err = watcher.Next() if err != nil { - if errors.Is(err, context.Canceled) { + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { return nil } return err diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 97440bc81..d7fad76dd 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -128,6 +128,13 @@ func RegisterNewSSHServer(fn newSSHServerFunc) { newSSHServer = fn } +// watchSession represents a WatchNotifications channel +// and sessionID as required to close targeted buses. +type watchSession struct { + ch chan *ipn.Notify + sessionID string +} + // LocalBackend is the glue between the major pieces of the Tailscale // network software: the cloud control plane (via controlclient), the // network data plane (via wgengine), and the user-facing UIs and CLIs @@ -233,7 +240,7 @@ type LocalBackend struct { loginFlags controlclient.LoginFlags incomingFiles map[*incomingFile]bool fileWaiters set.HandleSet[context.CancelFunc] // of wake-up funcs - notifyWatchers set.HandleSet[chan *ipn.Notify] + notifyWatchers set.HandleSet[*watchSession] lastStatusTime time.Time // status.AsOf value of the last processed status update // directFileRoot, if non-empty, means to write received files // directly to this directory, without staging them in an @@ -2058,7 +2065,7 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa } } - handle := b.notifyWatchers.Add(ch) + handle := b.notifyWatchers.Add(&watchSession{ch, sessionID}) b.mu.Unlock() defer func() { @@ -2103,8 +2110,8 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa select { case <-ctx.Done(): return - case n := <-ch: - if !fn(n) { + case n, ok := <-ch: + if !ok || !fn(n) { return } } @@ -2174,9 +2181,9 @@ func (b *LocalBackend) send(n ipn.Notify) { n.FilesWaiting = &empty.Message{} } - for _, ch := range b.notifyWatchers { + for _, sess := range b.notifyWatchers { select { - case ch <- &n: + case sess.ch <- &n: default: // Drop the notification if the channel is full. } diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 7d9d9c17f..757f2254b 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -752,9 +752,9 @@ func TestWatchNotificationsCallbacks(t *testing.T) { } // Send a notification. Range over notifyWatchers to get the channel // because WatchNotifications doesn't expose the handle for it. - for _, c := range b.notifyWatchers { + for _, sess := range b.notifyWatchers { select { - case c <- n: + case sess.ch <- n: default: t.Fatalf("could not send notification") } diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index c28ca0239..386e0f89e 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -247,16 +247,17 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string // If etag is present, check that it has // not changed from the last config. + prevConfig := b.serveConfig if etag != "" { // Note that we marshal b.serveConfig // and not use b.lastServeConfJSON as that might // be a Go nil value, which produces a different // checksum from a JSON "null" value. - previousCfg, err := json.Marshal(b.serveConfig) + prevBytes, err := json.Marshal(prevConfig) if err != nil { return fmt.Errorf("error encoding previous config: %w", err) } - sum := sha256.Sum256(previousCfg) + sum := sha256.Sum256(prevBytes) previousEtag := hex.EncodeToString(sum[:]) if etag != previousEtag { return ErrETagMismatch @@ -279,6 +280,26 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string } b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs()) + + // clean up and close all previously open foreground sessions + // if the current ServeConfig has overwritten them. + if prevConfig.Valid() { + has := func(string) bool { return false } + if b.serveConfig.Valid() { + has = b.serveConfig.Foreground().Has + } + prevConfig.Foreground().Range(func(k string, v ipn.ServeConfigView) (cont bool) { + if !has(k) { + for _, sess := range b.notifyWatchers { + if sess.sessionID == k { + close(sess.ch) + } + } + } + return true + }) + } + return nil } diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index ec4e4e8de..242dd9ca2 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -20,6 +20,7 @@ import ( "path/filepath" "strings" "testing" + "time" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" @@ -184,6 +185,105 @@ func getEtag(t *testing.T, b any) string { return hex.EncodeToString(sum[:]) } +// TestServeConfigForeground tests the inter-dependency +// between a ServeConfig and a WatchIPNBus: +// 1. Creating a WatchIPNBus returns a sessionID, that +// 2. ServeConfig sets it as the key of the Foreground field. +// 3. ServeConfig expects the WatchIPNBus to clean up the Foreground +// config when the session is done. +// 4. WatchIPNBus expects the ServeConfig to send a signal (close the channel) +// if an incoming SetServeConfig removes previous foregrounds. +func TestServeConfigForeground(t *testing.T) { + b := newTestBackend(t) + + ch1 := make(chan string, 1) + go func() { + defer close(ch1) + b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) { + if roNotify.SessionID != "" { + ch1 <- roNotify.SessionID + } + return true + }) + }() + + ch2 := make(chan string, 1) + go func() { + b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) { + if roNotify.SessionID != "" { + ch2 <- roNotify.SessionID + return true + } + ch2 <- "again" // let channel know fn was called again + return true + }) + }() + + var session1 string + select { + case session1 = <-ch1: + case <-time.After(time.Second): + t.Fatal("timed out waiting on watch notifications session id") + } + + var session2 string + select { + case session2 = <-ch2: + case <-time.After(time.Second): + t.Fatal("timed out waiting on watch notifications session id") + } + + err := b.SetServeConfig(&ipn.ServeConfig{ + Foreground: map[string]*ipn.ServeConfig{ + session1: {TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {TCPForward: "http://localhost:3000"}}, + }, + session2: {TCP: map[uint16]*ipn.TCPPortHandler{ + 999: {TCPForward: "http://localhost:4000"}}, + }, + }, + }, "") + if err != nil { + t.Fatal(err) + } + + // Setting a new serve config should shut down WatchNotifications + // whose session IDs are no longer found: session1 goes, session2 stays. + err = b.SetServeConfig(&ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 5000: {TCPForward: "http://localhost:5000"}, + }, + Foreground: map[string]*ipn.ServeConfig{ + session2: {TCP: map[uint16]*ipn.TCPPortHandler{ + 999: {TCPForward: "http://localhost:4000"}}, + }, + }, + }, "") + if err != nil { + t.Fatal(err) + } + + select { + case _, ok := <-ch1: + if ok { + t.Fatal("expected channel to be closed") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting on watch notifications closing") + } + + // check that the second session is still running + b.send(ipn.Notify{}) + select { + case _, ok := <-ch2: + if !ok { + t.Fatal("expected second session to remain open") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting on second session") + } +} + func TestServeConfigETag(t *testing.T) { b := newTestBackend(t)