diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index e5984a5b5..d7a340996 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -69,6 +69,7 @@ import ( "reflect" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -181,10 +182,16 @@ func main() { } } - client, daemonPid, err := startTailscaled(bootCtx, cfg) + client, daemonProcess, err := startTailscaled(bootCtx, cfg) if err != nil { log.Fatalf("failed to bring up tailscale: %v", err) } + killTailscaled := func() { + if err := daemonProcess.Signal(unix.SIGTERM); err != nil { + log.Fatalf("error shutting tailscaled down: %v", err) + } + } + defer killTailscaled() w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState) if err != nil { @@ -252,7 +259,7 @@ authLoop: w.Close() - ctx, cancel := context.WithCancel(context.Background()) // no deadline now that we're in steady state + ctx, cancel := contextWithExitSignalWatch() defer cancel() if cfg.AuthOnce { @@ -306,84 +313,111 @@ authLoop: log.Fatalf("error creating new netfilter runner: %v", err) } } + notifyChan := make(chan ipn.Notify) + errChan := make(chan error) + go func() { + for { + n, err := w.Next() + if err != nil { + errChan <- err + break + } else { + notifyChan <- n + } + } + }() + var wg sync.WaitGroup +runLoop: for { - n, err := w.Next() - if err != nil { + select { + case <-ctx.Done(): + // Although killTailscaled() is deferred earlier, if we + // have started the reaper defined below, we need to + // kill tailscaled and let reaper clean up child + // processes. + killTailscaled() + break runLoop + case err := <-errChan: log.Fatalf("failed to read from tailscaled: %v", err) - } - - if n.State != nil && *n.State != ipn.Running { - // Something's gone wrong and we've left the authenticated state. - // Our container image never recovered gracefully from this, and the - // control flow required to make it work now is hard. So, just crash - // the container and rely on the container runtime to restart us, - // whereupon we'll go through initial auth again. - log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State) - } - if n.NetMap != nil { - addrs := n.NetMap.SelfNode.Addresses().AsSlice() - newCurrentIPs := deephash.Hash(&addrs) - ipsHaveChanged := newCurrentIPs != currentIPs - if cfg.ProxyTo != "" && len(addrs) > 0 && ipsHaveChanged { - log.Printf("Installing proxy rules") - if err := installIngressForwardingRule(ctx, cfg.ProxyTo, addrs, nfr); err != nil { - log.Fatalf("installing ingress proxy rules: %v", err) - } + case n := <-notifyChan: + if n.State != nil && *n.State != ipn.Running { + // Something's gone wrong and we've left the authenticated state. + // Our container image never recovered gracefully from this, and the + // control flow required to make it work now is hard. So, just crash + // the container and rely on the container runtime to restart us, + // whereupon we'll go through initial auth again. + log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State) } - if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) > 0 { - cd := n.NetMap.DNS.CertDomains[0] - prev := certDomain.Swap(ptr.To(cd)) - if prev == nil || *prev != cd { - select { - case certDomainChanged <- true: - default: + if n.NetMap != nil { + addrs := n.NetMap.SelfNode.Addresses().AsSlice() + newCurrentIPs := deephash.Hash(&addrs) + ipsHaveChanged := newCurrentIPs != currentIPs + if cfg.ProxyTo != "" && len(addrs) > 0 && ipsHaveChanged { + log.Printf("Installing proxy rules") + if err := installIngressForwardingRule(ctx, cfg.ProxyTo, addrs, nfr); err != nil { + log.Fatalf("installing ingress proxy rules: %v", err) } } - } - if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) > 0 { - if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { - log.Fatalf("installing egress proxy rules: %v", err) + if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) > 0 { + cd := n.NetMap.DNS.CertDomains[0] + prev := certDomain.Swap(ptr.To(cd)) + if prev == nil || *prev != cd { + select { + case certDomainChanged <- true: + default: + } + } } - } - currentIPs = newCurrentIPs + if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) > 0 { + if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { + log.Fatalf("installing egress proxy rules: %v", err) + } + } + currentIPs = newCurrentIPs - deviceInfo := []any{n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name()} - if cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" && deephash.Update(¤tDeviceInfo, &deviceInfo) { - if err := storeDeviceInfo(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { - log.Fatalf("storing device ID in kube secret: %v", err) + deviceInfo := []any{n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name()} + if cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" && deephash.Update(¤tDeviceInfo, &deviceInfo) { + if err := storeDeviceInfo(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { + log.Fatalf("storing device ID in kube secret: %v", err) + } } } - } - if !startupTasksDone { - if (!wantProxy || currentIPs != deephash.Sum{}) && (!wantDeviceInfo || currentDeviceInfo != deephash.Sum{}) { - // This log message is used in tests to detect when all - // post-auth configuration is done. - log.Println("Startup complete, waiting for shutdown signal") - startupTasksDone = true - - // Reap all processes, since we are PID1 and need to collect zombies. We can - // only start doing this once we've stopped shelling out to things - // `tailscale up`, otherwise this goroutine can reap the CLI subprocesses - // and wedge bringup. - go func() { - for { - var status unix.WaitStatus - pid, err := unix.Wait4(-1, &status, 0, nil) - if errors.Is(err, unix.EINTR) { - continue - } - if err != nil { - log.Fatalf("Waiting for exited processes: %v", err) - } - if pid == daemonPid { - log.Printf("Tailscaled exited") - os.Exit(0) + if !startupTasksDone { + if (!wantProxy || currentIPs != deephash.Sum{}) && (!wantDeviceInfo || currentDeviceInfo != deephash.Sum{}) { + // This log message is used in tests to detect when all + // post-auth configuration is done. + log.Println("Startup complete, waiting for shutdown signal") + startupTasksDone = true + + // // Reap all processes, since we are PID1 and need to collect zombies. We can + // // only start doing this once we've stopped shelling out to things + // // `tailscale up`, otherwise this goroutine can reap the CLI subprocesses + // // and wedge bringup. + reaper := func() { + defer wg.Done() + for { + var status unix.WaitStatus + pid, err := unix.Wait4(-1, &status, 0, nil) + if errors.Is(err, unix.EINTR) { + continue + } + if err != nil { + log.Fatalf("Waiting for exited processes: %v", err) + } + if pid == daemonProcess.Pid { + log.Printf("Tailscaled exited") + os.Exit(0) + } } + } - }() + wg.Add(1) + go reaper() + } } } } + wg.Wait() } // watchServeConfigChanges watches path for changes, and when it sees one, reads @@ -460,10 +494,8 @@ func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) { return &sc, nil } -func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, int, error) { +func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, *os.Process, error) { args := tailscaledArgs(cfg) - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, unix.SIGTERM, unix.SIGINT) // tailscaled runs without context, since it needs to persist // beyond the startup timeout in ctx. cmd := exec.Command("tailscaled", args...) @@ -474,13 +506,8 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient } log.Printf("Starting tailscaled") if err := cmd.Start(); err != nil { - return nil, 0, fmt.Errorf("starting tailscaled failed: %v", err) + return nil, nil, fmt.Errorf("starting tailscaled failed: %v", err) } - go func() { - <-sigCh - log.Printf("Received SIGTERM from container runtime, shutting down tailscaled") - cmd.Process.Signal(unix.SIGTERM) - }() // Wait for the socket file to appear, otherwise API ops will racily fail. log.Printf("Waiting for tailscaled socket") @@ -503,7 +530,7 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient UseSocketOnly: true, } - return tsClient, cmd.Process.Pid, nil + return tsClient, cmd.Process, nil } // tailscaledArgs uses cfg to construct the argv for tailscaled. @@ -801,3 +828,25 @@ func defaultBool(name string, defVal bool) bool { } return ret } + +// contextWithExitSignalWatch watches for SIGTERM/SIGINT signals. It returns a +// context that gets cancelled when a signal is received and a cancel function +// that can be called to free the resources when the watch should be stopped. +func contextWithExitSignalWatch() (context.Context, func()) { + closeChan := make(chan string) + ctx, cancel := context.WithCancel(context.Background()) + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + select { + case <-signalChan: + cancel() + case <-closeChan: + return + } + }() + f := func() { + closeChan <- "goodbye" + } + return ctx, f +}