cmd/containerboot: fix unclean shutdown (#10035)

* cmd/containerboot: shut down cleanly on SIGTERM

Make sure that tailscaled watcher returns when
SIGTERM is received and also that it shuts down
before tailscaled exits.

Updates tailscale/tailscale#10090

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
pull/10293/head
Irbe Krumina 1 year ago committed by GitHub
parent 7238586652
commit 664ebb14d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -69,6 +69,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time" "time"
@ -181,10 +182,16 @@ func main() {
} }
} }
client, daemonPid, err := startTailscaled(bootCtx, cfg) client, daemonProcess, err := startTailscaled(bootCtx, cfg)
if err != nil { if err != nil {
log.Fatalf("failed to bring up tailscale: %v", err) log.Fatalf("failed to bring up tailscale: %v", err)
} }
killTailscaled := func() {
if err := daemonProcess.Signal(unix.SIGTERM); err != nil {
log.Fatalf("error shutting tailscaled down: %v", err)
}
}
defer killTailscaled()
w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState) w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState)
if err != nil { if err != nil {
@ -252,7 +259,7 @@ authLoop:
w.Close() w.Close()
ctx, cancel := context.WithCancel(context.Background()) // no deadline now that we're in steady state ctx, cancel := contextWithExitSignalWatch()
defer cancel() defer cancel()
if cfg.AuthOnce { if cfg.AuthOnce {
@ -306,84 +313,111 @@ authLoop:
log.Fatalf("error creating new netfilter runner: %v", err) log.Fatalf("error creating new netfilter runner: %v", err)
} }
} }
notifyChan := make(chan ipn.Notify)
errChan := make(chan error)
go func() {
for {
n, err := w.Next()
if err != nil {
errChan <- err
break
} else {
notifyChan <- n
}
}
}()
var wg sync.WaitGroup
runLoop:
for { for {
n, err := w.Next() select {
if err != nil { case <-ctx.Done():
// Although killTailscaled() is deferred earlier, if we
// have started the reaper defined below, we need to
// kill tailscaled and let reaper clean up child
// processes.
killTailscaled()
break runLoop
case err := <-errChan:
log.Fatalf("failed to read from tailscaled: %v", err) log.Fatalf("failed to read from tailscaled: %v", err)
} case n := <-notifyChan:
if n.State != nil && *n.State != ipn.Running {
if n.State != nil && *n.State != ipn.Running { // Something's gone wrong and we've left the authenticated state.
// Something's gone wrong and we've left the authenticated state. // Our container image never recovered gracefully from this, and the
// Our container image never recovered gracefully from this, and the // control flow required to make it work now is hard. So, just crash
// control flow required to make it work now is hard. So, just crash // the container and rely on the container runtime to restart us,
// the container and rely on the container runtime to restart us, // whereupon we'll go through initial auth again.
// whereupon we'll go through initial auth again. log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State)
log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State)
}
if n.NetMap != nil {
addrs := n.NetMap.SelfNode.Addresses().AsSlice()
newCurrentIPs := deephash.Hash(&addrs)
ipsHaveChanged := newCurrentIPs != currentIPs
if cfg.ProxyTo != "" && len(addrs) > 0 && ipsHaveChanged {
log.Printf("Installing proxy rules")
if err := installIngressForwardingRule(ctx, cfg.ProxyTo, addrs, nfr); err != nil {
log.Fatalf("installing ingress proxy rules: %v", err)
}
} }
if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) > 0 { if n.NetMap != nil {
cd := n.NetMap.DNS.CertDomains[0] addrs := n.NetMap.SelfNode.Addresses().AsSlice()
prev := certDomain.Swap(ptr.To(cd)) newCurrentIPs := deephash.Hash(&addrs)
if prev == nil || *prev != cd { ipsHaveChanged := newCurrentIPs != currentIPs
select { if cfg.ProxyTo != "" && len(addrs) > 0 && ipsHaveChanged {
case certDomainChanged <- true: log.Printf("Installing proxy rules")
default: if err := installIngressForwardingRule(ctx, cfg.ProxyTo, addrs, nfr); err != nil {
log.Fatalf("installing ingress proxy rules: %v", err)
} }
} }
} if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) > 0 {
if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) > 0 { cd := n.NetMap.DNS.CertDomains[0]
if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { prev := certDomain.Swap(ptr.To(cd))
log.Fatalf("installing egress proxy rules: %v", err) if prev == nil || *prev != cd {
select {
case certDomainChanged <- true:
default:
}
}
} }
} if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) > 0 {
currentIPs = newCurrentIPs if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil {
log.Fatalf("installing egress proxy rules: %v", err)
}
}
currentIPs = newCurrentIPs
deviceInfo := []any{n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name()} deviceInfo := []any{n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name()}
if cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" && deephash.Update(&currentDeviceInfo, &deviceInfo) { if cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" && deephash.Update(&currentDeviceInfo, &deviceInfo) {
if err := storeDeviceInfo(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { if err := storeDeviceInfo(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil {
log.Fatalf("storing device ID in kube secret: %v", err) log.Fatalf("storing device ID in kube secret: %v", err)
}
} }
} }
} if !startupTasksDone {
if !startupTasksDone { if (!wantProxy || currentIPs != deephash.Sum{}) && (!wantDeviceInfo || currentDeviceInfo != deephash.Sum{}) {
if (!wantProxy || currentIPs != deephash.Sum{}) && (!wantDeviceInfo || currentDeviceInfo != deephash.Sum{}) { // This log message is used in tests to detect when all
// This log message is used in tests to detect when all // post-auth configuration is done.
// post-auth configuration is done. log.Println("Startup complete, waiting for shutdown signal")
log.Println("Startup complete, waiting for shutdown signal") startupTasksDone = true
startupTasksDone = true
// // Reap all processes, since we are PID1 and need to collect zombies. We can
// Reap all processes, since we are PID1 and need to collect zombies. We can // // only start doing this once we've stopped shelling out to things
// only start doing this once we've stopped shelling out to things // // `tailscale up`, otherwise this goroutine can reap the CLI subprocesses
// `tailscale up`, otherwise this goroutine can reap the CLI subprocesses // // and wedge bringup.
// and wedge bringup. reaper := func() {
go func() { defer wg.Done()
for { for {
var status unix.WaitStatus var status unix.WaitStatus
pid, err := unix.Wait4(-1, &status, 0, nil) pid, err := unix.Wait4(-1, &status, 0, nil)
if errors.Is(err, unix.EINTR) { if errors.Is(err, unix.EINTR) {
continue continue
} }
if err != nil { if err != nil {
log.Fatalf("Waiting for exited processes: %v", err) log.Fatalf("Waiting for exited processes: %v", err)
} }
if pid == daemonPid { if pid == daemonProcess.Pid {
log.Printf("Tailscaled exited") log.Printf("Tailscaled exited")
os.Exit(0) os.Exit(0)
}
} }
} }
}() wg.Add(1)
go reaper()
}
} }
} }
} }
wg.Wait()
} }
// watchServeConfigChanges watches path for changes, and when it sees one, reads // watchServeConfigChanges watches path for changes, and when it sees one, reads
@ -460,10 +494,8 @@ func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) {
return &sc, nil return &sc, nil
} }
func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, int, error) { func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, *os.Process, error) {
args := tailscaledArgs(cfg) args := tailscaledArgs(cfg)
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, unix.SIGTERM, unix.SIGINT)
// tailscaled runs without context, since it needs to persist // tailscaled runs without context, since it needs to persist
// beyond the startup timeout in ctx. // beyond the startup timeout in ctx.
cmd := exec.Command("tailscaled", args...) cmd := exec.Command("tailscaled", args...)
@ -474,13 +506,8 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient
} }
log.Printf("Starting tailscaled") log.Printf("Starting tailscaled")
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
return nil, 0, fmt.Errorf("starting tailscaled failed: %v", err) return nil, nil, fmt.Errorf("starting tailscaled failed: %v", err)
} }
go func() {
<-sigCh
log.Printf("Received SIGTERM from container runtime, shutting down tailscaled")
cmd.Process.Signal(unix.SIGTERM)
}()
// Wait for the socket file to appear, otherwise API ops will racily fail. // Wait for the socket file to appear, otherwise API ops will racily fail.
log.Printf("Waiting for tailscaled socket") log.Printf("Waiting for tailscaled socket")
@ -503,7 +530,7 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient
UseSocketOnly: true, UseSocketOnly: true,
} }
return tsClient, cmd.Process.Pid, nil return tsClient, cmd.Process, nil
} }
// tailscaledArgs uses cfg to construct the argv for tailscaled. // tailscaledArgs uses cfg to construct the argv for tailscaled.
@ -801,3 +828,25 @@ func defaultBool(name string, defVal bool) bool {
} }
return ret return ret
} }
// contextWithExitSignalWatch watches for SIGTERM/SIGINT signals. It returns a
// context that gets cancelled when a signal is received and a cancel function
// that can be called to free the resources when the watch should be stopped.
func contextWithExitSignalWatch() (context.Context, func()) {
closeChan := make(chan string)
ctx, cancel := context.WithCancel(context.Background())
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
select {
case <-signalChan:
cancel()
case <-closeChan:
return
}
}()
f := func() {
closeChan <- "goodbye"
}
return ctx, f
}

Loading…
Cancel
Save