fix(lifecycle): cleanup lifecycle

- removes unwieldy SkipUpdate return value in favor of errors.Is
- generalizes the code for all four phases
- allows timeout to be defined for all phases
- enables explicit unit in timeout label values (in addition to implicit minutes)
refactor-update
nils måsén 5 months ago
parent 76f9cea516
commit 023c1a7d93

@ -5,6 +5,7 @@ import (
"fmt"
"time"
c "github.com/containrrr/watchtower/pkg/container"
t "github.com/containrrr/watchtower/pkg/types"
)
@ -72,16 +73,16 @@ func (client MockClient) GetContainer(_ t.ContainerID) (t.Container, error) {
}
// ExecuteCommand is a mock method
func (client MockClient) ExecuteCommand(_ t.ContainerID, command string, _ int) (SkipUpdate bool, err error) {
func (client MockClient) ExecuteCommand(_ t.ContainerID, command string, _ time.Duration) error {
switch command {
case "/PreUpdateReturn0.sh":
return false, nil
return nil
case "/PreUpdateReturn1.sh":
return false, fmt.Errorf("command exited with code 1")
return fmt.Errorf("command exited with code 1")
case "/PreUpdateReturn75.sh":
return true, nil
return c.ErrorLifecycleSkip
default:
return false, nil
return nil
}
}

@ -0,0 +1,15 @@
package util
import (
"strconv"
"time"
)
// ParseDuration parses the input string as a duration, treating a plain number as implicitly using the specified unit
func ParseDuration(input string, unitlessUnit time.Duration) (time.Duration, error) {
if unitless, err := strconv.Atoi(input); err == nil {
return unitlessUnit * time.Duration(unitless), nil
}
return time.ParseDuration(input)
}

@ -31,7 +31,7 @@ type Client interface {
StartContainer(t.Container) (t.ContainerID, error)
RenameContainer(t.Container, string) error
IsContainerStale(t.Container, t.UpdateParams) (stale bool, latestImage t.ImageID, err error)
ExecuteCommand(containerID t.ContainerID, command string, timeout int) (SkipUpdate bool, err error)
ExecuteCommand(containerID t.ContainerID, command string, timeout time.Duration) error
RemoveImageByID(t.ImageID) error
WarnOnHeadPullFailed(container t.Container) bool
}
@ -439,7 +439,7 @@ func (client dockerClient) RemoveImageByID(id t.ImageID) error {
return err
}
func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command string, timeout int) (SkipUpdate bool, err error) {
func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command string, timeout time.Duration) error {
bg := context.Background()
clog := log.WithField("containerID", containerID)
@ -452,7 +452,7 @@ func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command str
exec, err := client.api.ContainerExecCreate(bg, string(containerID), execConfig)
if err != nil {
return false, err
return err
}
response, attachErr := client.api.ContainerExecAttach(bg, exec.ID, types.ExecStartCheck{
@ -467,7 +467,7 @@ func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command str
execStartCheck := types.ExecStartCheck{Detach: false, Tty: true}
err = client.api.ContainerExecStart(bg, exec.ID, execStartCheck)
if err != nil {
return false, err
return err
}
var output string
@ -484,24 +484,16 @@ func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command str
// Inspect the exec to get the exit code and print a message if the
// exit code is not success.
skipUpdate, err := client.waitForExecOrTimeout(bg, exec.ID, output, timeout)
if err != nil {
return true, err
}
return skipUpdate, nil
return client.waitForExecOrTimeout(bg, exec.ID, output, timeout)
}
func (client dockerClient) waitForExecOrTimeout(bg context.Context, ID string, execOutput string, timeout int) (SkipUpdate bool, err error) {
func (client dockerClient) waitForExecOrTimeout(ctx context.Context, ID string, execOutput string, timeout time.Duration) error {
const ExTempFail = 75
var ctx context.Context
var cancel context.CancelFunc
if timeout > 0 {
ctx, cancel = context.WithTimeout(bg, time.Duration(timeout)*time.Minute)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
} else {
ctx = bg
}
for {
@ -516,7 +508,7 @@ func (client dockerClient) waitForExecOrTimeout(bg context.Context, ID string, e
}).Debug("Awaiting timeout or completion")
if err != nil {
return false, err
return err
}
if execInspect.Running {
time.Sleep(1 * time.Second)
@ -527,15 +519,15 @@ func (client dockerClient) waitForExecOrTimeout(bg context.Context, ID string, e
}
if execInspect.ExitCode == ExTempFail {
return true, nil
return ErrorLifecycleSkip
}
if execInspect.ExitCode > 0 {
return false, fmt.Errorf("command exited with code %v %s", execInspect.ExitCode, execOutput)
return fmt.Errorf("command exited with code %v", execInspect.ExitCode)
}
break
}
return false, nil
return nil
}
func (client dockerClient) waitForStopOrTimeout(c t.Container, waitTime time.Duration) error {

@ -308,7 +308,7 @@ var _ = Describe("the client", func() {
),
)
_, err := client.ExecuteCommand(containerID, cmd, 1)
err := client.ExecuteCommand(containerID, cmd, 1)
Expect(err).NotTo(HaveOccurred())
// Note: Since Execute requires opening up a raw TCP stream to the daemon for the output, this will fail
// when using the mock API server. Regardless of the outcome, the log should include the container ID

@ -219,44 +219,6 @@ func (c Container) IsWatchtower() bool {
return ContainsWatchtowerLabel(c.containerInfo.Config.Labels)
}
// PreUpdateTimeout checks whether a container has a specific timeout set
// for how long the pre-update command is allowed to run. This value is expressed
// either as an integer, in minutes, or as 0 which will allow the command/script
// to run indefinitely. Users should be cautious with the 0 option, as that
// could result in watchtower waiting forever.
func (c Container) PreUpdateTimeout() int {
var minutes int
var err error
val := c.getLabelValueOrEmpty(preUpdateTimeoutLabel)
minutes, err = strconv.Atoi(val)
if err != nil || val == "" {
return 1
}
return minutes
}
// PostUpdateTimeout checks whether a container has a specific timeout set
// for how long the post-update command is allowed to run. This value is expressed
// either as an integer, in minutes, or as 0 which will allow the command/script
// to run indefinitely. Users should be cautious with the 0 option, as that
// could result in watchtower waiting forever.
func (c Container) PostUpdateTimeout() int {
var minutes int
var err error
val := c.getLabelValueOrEmpty(postUpdateTimeoutLabel)
minutes, err = strconv.Atoi(val)
if err != nil || val == "" {
return 1
}
return minutes
}
// StopSignal returns the custom stop signal (if any) that is encoded in the
// container's metadata. If the container has not specified a custom stop
// signal, the empty string "" is returned.

@ -6,3 +6,6 @@ var errorNoImageInfo = errors.New("no available image info")
var errorNoContainerInfo = errors.New("no available container info")
var errorInvalidConfig = errors.New("container configuration missing or invalid")
var errorLabelNotFound = errors.New("label was not found in container")
// ErrorLifecycleSkip is returned by a lifecycle hook when the exit code of the command indicated that it ought to be skipped
var ErrorLifecycleSkip = errors.New("skipping container as the pre-update command returned exit code 75 (EX_TEMPFAIL)")

@ -1,67 +1,92 @@
package container
import "strconv"
import (
"errors"
"fmt"
"strconv"
"time"
"github.com/containrrr/watchtower/internal/util"
wt "github.com/containrrr/watchtower/pkg/types"
"github.com/sirupsen/logrus"
)
const (
watchtowerLabel = "com.centurylinklabs.watchtower"
signalLabel = "com.centurylinklabs.watchtower.stop-signal"
enableLabel = "com.centurylinklabs.watchtower.enable"
monitorOnlyLabel = "com.centurylinklabs.watchtower.monitor-only"
noPullLabel = "com.centurylinklabs.watchtower.no-pull"
dependsOnLabel = "com.centurylinklabs.watchtower.depends-on"
zodiacLabel = "com.centurylinklabs.zodiac.original-image"
scope = "com.centurylinklabs.watchtower.scope"
preCheckLabel = "com.centurylinklabs.watchtower.lifecycle.pre-check"
postCheckLabel = "com.centurylinklabs.watchtower.lifecycle.post-check"
preUpdateLabel = "com.centurylinklabs.watchtower.lifecycle.pre-update"
postUpdateLabel = "com.centurylinklabs.watchtower.lifecycle.post-update"
preUpdateTimeoutLabel = "com.centurylinklabs.watchtower.lifecycle.pre-update-timeout"
postUpdateTimeoutLabel = "com.centurylinklabs.watchtower.lifecycle.post-update-timeout"
namespace = "com.centurylinklabs.watchtower"
watchtowerLabel = namespace
signalLabel = namespace + ".stop-signal"
enableLabel = namespace + ".enable"
monitorOnlyLabel = namespace + ".monitor-only"
noPullLabel = namespace + ".no-pull"
dependsOnLabel = namespace + ".depends-on"
zodiacLabel = "com.centurylinklabs.zodiac.original-image"
scope = namespace + ".scope"
)
// GetLifecyclePreCheckCommand returns the pre-check command set in the container metadata or an empty string
func (c Container) GetLifecyclePreCheckCommand() string {
return c.getLabelValueOrEmpty(preCheckLabel)
// ContainsWatchtowerLabel takes a map of labels and values and tells
// the consumer whether it contains a valid watchtower instance label
func ContainsWatchtowerLabel(labels map[string]string) bool {
val, ok := labels[watchtowerLabel]
return ok && val == "true"
}
// GetLifecyclePostCheckCommand returns the post-check command set in the container metadata or an empty string
func (c Container) GetLifecyclePostCheckCommand() string {
return c.getLabelValueOrEmpty(postCheckLabel)
}
// GetLifecycleCommand returns the lifecycle command set in the container metadata or an empty string
func (c *Container) GetLifecycleCommand(phase wt.LifecyclePhase) string {
label := fmt.Sprintf("%v.lifecycle.%v", namespace, phase)
value, found := c.getLabelValue(label)
// GetLifecyclePreUpdateCommand returns the pre-update command set in the container metadata or an empty string
func (c Container) GetLifecyclePreUpdateCommand() string {
return c.getLabelValueOrEmpty(preUpdateLabel)
}
if !found {
return ""
}
// GetLifecyclePostUpdateCommand returns the post-update command set in the container metadata or an empty string
func (c Container) GetLifecyclePostUpdateCommand() string {
return c.getLabelValueOrEmpty(postUpdateLabel)
return value
}
// ContainsWatchtowerLabel takes a map of labels and values and tells
// the consumer whether it contains a valid watchtower instance label
func ContainsWatchtowerLabel(labels map[string]string) bool {
val, ok := labels[watchtowerLabel]
return ok && val == "true"
// GetLifecycleTimeout checks whether a container has a specific timeout set
// for how long the lifecycle command is allowed to run. This value is expressed
// either as a duration, an integer (minutes implied), or as 0 which will allow the command/script
// to run indefinitely. Users should be cautious with the 0 option, as that
// could result in watchtower waiting forever.
func (c *Container) GetLifecycleTimeout(phase wt.LifecyclePhase) time.Duration {
label := fmt.Sprintf("%v.lifecycle.%v-timeout", namespace, phase)
timeout, err := c.getDurationLabelValue(label, time.Minute)
if err != nil {
timeout = time.Minute
if !errors.Is(err, errorLabelNotFound) {
logrus.WithError(err).Errorf("could not parse timeout label value for %v lifecycle", phase)
}
}
return timeout
}
func (c Container) getLabelValueOrEmpty(label string) string {
func (c *Container) getLabelValueOrEmpty(label string) string {
if val, ok := c.containerInfo.Config.Labels[label]; ok {
return val
}
return ""
}
func (c Container) getLabelValue(label string) (string, bool) {
func (c *Container) getLabelValue(label string) (string, bool) {
val, ok := c.containerInfo.Config.Labels[label]
return val, ok
}
func (c Container) getBoolLabelValue(label string) (bool, error) {
func (c *Container) getBoolLabelValue(label string) (bool, error) {
if strVal, ok := c.containerInfo.Config.Labels[label]; ok {
value, err := strconv.ParseBool(strVal)
return value, err
}
return false, errorLabelNotFound
}
func (c *Container) getDurationLabelValue(label string, unitlessUnit time.Duration) (time.Duration, error) {
value, found := c.getLabelValue(label)
if !found || len(value) < 1 {
return 0, errorLabelNotFound
}
return util.ParseDuration(value, unitlessUnit)
}

@ -6,101 +6,60 @@ import (
log "github.com/sirupsen/logrus"
)
// ExecutePreChecks tries to run the pre-check lifecycle hook for all containers included by the current filter.
func ExecutePreChecks(client container.Client, params types.UpdateParams) {
containers, err := client.ListContainers(params.Filter)
if err != nil {
return
}
for _, currentContainer := range containers {
ExecutePreCheckCommand(client, currentContainer)
}
}
// ExecutePostChecks tries to run the post-check lifecycle hook for all containers included by the current filter.
func ExecutePostChecks(client container.Client, params types.UpdateParams) {
containers, err := client.ListContainers(params.Filter)
if err != nil {
return
}
for _, currentContainer := range containers {
ExecutePostCheckCommand(client, currentContainer)
}
}
type ExecCommandFunc func(client container.Client, container types.Container)
// ExecutePreCheckCommand tries to run the pre-check lifecycle hook for a single container.
func ExecutePreCheckCommand(client container.Client, container types.Container) {
clog := log.WithField("container", container.Name())
command := container.GetLifecyclePreCheckCommand()
if len(command) == 0 {
clog.Debug("No pre-check command supplied. Skipping")
return
}
clog.Debug("Executing pre-check command.")
_, err := client.ExecuteCommand(container.ID(), command, 1)
err := ExecuteLifeCyclePhaseCommand(types.PreCheck, client, container)
if err != nil {
clog.Error(err)
log.WithField("container", container.Name()).Error(err)
}
}
// ExecutePostCheckCommand tries to run the post-check lifecycle hook for a single container.
func ExecutePostCheckCommand(client container.Client, container types.Container) {
clog := log.WithField("container", container.Name())
command := container.GetLifecyclePostCheckCommand()
if len(command) == 0 {
clog.Debug("No post-check command supplied. Skipping")
return
}
clog.Debug("Executing post-check command.")
_, err := client.ExecuteCommand(container.ID(), command, 1)
err := ExecuteLifeCyclePhaseCommand(types.PostCheck, client, container)
if err != nil {
clog.Error(err)
log.WithField("container", container.Name()).Error(err)
}
}
// ExecutePreUpdateCommand tries to run the pre-update lifecycle hook for a single container.
func ExecutePreUpdateCommand(client container.Client, container types.Container) (SkipUpdate bool, err error) {
timeout := container.PreUpdateTimeout()
command := container.GetLifecyclePreUpdateCommand()
clog := log.WithField("container", container.Name())
if len(command) == 0 {
clog.Debug("No pre-update command supplied. Skipping")
return false, nil
}
if !container.IsRunning() || container.IsRestarting() {
clog.Debug("Container is not running. Skipping pre-update command.")
return false, nil
}
clog.Debug("Executing pre-update command.")
return client.ExecuteCommand(container.ID(), command, timeout)
func ExecutePreUpdateCommand(client container.Client, container types.Container) error {
return ExecuteLifeCyclePhaseCommand(types.PreUpdate, client, container)
}
// ExecutePostUpdateCommand tries to run the post-update lifecycle hook for a single container.
func ExecutePostUpdateCommand(client container.Client, newContainerID types.ContainerID) {
newContainer, err := client.GetContainer(newContainerID)
timeout := newContainer.PostUpdateTimeout()
if err != nil {
log.WithField("containerID", newContainerID.ShortID()).Error(err)
return
}
clog := log.WithField("container", newContainer.Name())
command := newContainer.GetLifecyclePostUpdateCommand()
if len(command) == 0 {
clog.Debug("No post-update command supplied. Skipping")
return
err = ExecuteLifeCyclePhaseCommand(types.PostUpdate, client, newContainer)
if err != nil {
log.WithField("container", newContainer.Name()).Error(err)
}
}
clog.Debug("Executing post-update command.")
_, err = client.ExecuteCommand(newContainerID, command, timeout)
// ExecuteLifeCyclePhaseCommand tries to run the corresponding lifecycle hook for a single container.
func ExecuteLifeCyclePhaseCommand(phase types.LifecyclePhase, client container.Client, container types.Container) error {
if err != nil {
clog.Error(err)
timeout := container.GetLifecycleTimeout(phase)
command := container.GetLifecycleCommand(phase)
clog := log.WithField("container", container.Name())
if len(command) == 0 {
clog.Debugf("No %v command supplied. Skipping", phase)
return nil
}
if !container.IsRunning() || container.IsRestarting() {
clog.Debugf("Container is not running. Skipping %v command.", phase)
return nil
}
clog.Debugf("Executing %v command.", phase)
return client.ExecuteCommand(container.ID(), command, timeout)
}

@ -2,6 +2,7 @@ package types
import (
"strings"
"time"
"github.com/docker/docker/api/types"
dc "github.com/docker/docker/api/types/container"
@ -60,18 +61,14 @@ type Container interface {
StopSignal() string
HasImageInfo() bool
ImageInfo() *types.ImageInspect
GetLifecyclePreCheckCommand() string
GetLifecyclePostCheckCommand() string
GetLifecyclePreUpdateCommand() string
GetLifecyclePostUpdateCommand() string
GetLifecycleCommand(LifecyclePhase) string
GetLifecycleTimeout(LifecyclePhase) time.Duration
VerifyConfiguration() error
SetStale(bool)
IsStale() bool
IsNoPull(UpdateParams) bool
SetLinkedToRestarting(bool)
IsLinkedToRestarting() bool
PreUpdateTimeout() int
PostUpdateTimeout() int
IsRestarting() bool
GetCreateConfig() *dc.Config
GetCreateHostConfig() *dc.HostConfig

@ -0,0 +1,27 @@
package types
import "fmt"
type LifecyclePhase int
const (
PreCheck LifecyclePhase = iota
PreUpdate
PostUpdate
PostCheck
)
func (p LifecyclePhase) String() string {
switch p {
case PreCheck:
return "pre-check"
case PreUpdate:
return "pre-update"
case PostUpdate:
return "post-update"
case PostCheck:
return "post-check"
default:
return fmt.Sprintf("invalid(%d)", p)
}
}
Loading…
Cancel
Save