feat(api): implement new api handler

feat/extended-api
nils måsén 11 months ago
parent 72e437f173
commit 47091761a5

@ -1,22 +1,19 @@
package cmd package cmd
import ( import (
"errors"
"math"
"net/http"
"os" "os"
"os/signal" "os/signal"
"strconv"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
"github.com/containrrr/watchtower/internal/actions" "github.com/containrrr/watchtower/internal/actions"
"github.com/containrrr/watchtower/internal/flags" "github.com/containrrr/watchtower/internal/flags"
"github.com/containrrr/watchtower/internal/meta" "github.com/containrrr/watchtower/internal/meta"
"github.com/containrrr/watchtower/internal/util"
"github.com/containrrr/watchtower/pkg/api" "github.com/containrrr/watchtower/pkg/api"
apiMetrics "github.com/containrrr/watchtower/pkg/api/metrics" "github.com/containrrr/watchtower/pkg/api/updates"
"github.com/containrrr/watchtower/pkg/api/update"
"github.com/containrrr/watchtower/pkg/container" "github.com/containrrr/watchtower/pkg/container"
"github.com/containrrr/watchtower/pkg/filters" "github.com/containrrr/watchtower/pkg/filters"
"github.com/containrrr/watchtower/pkg/metrics" "github.com/containrrr/watchtower/pkg/metrics"
@ -31,20 +28,16 @@ import (
var ( var (
client container.Client client container.Client
scheduleSpec string scheduleSpec string
cleanup bool
noRestart bool
monitorOnly bool
enableLabel bool enableLabel bool
disableContainers []string disableContainers []string
notifier t.Notifier notifier t.Notifier
timeout time.Duration
lifecycleHooks bool
rollingRestart bool
scope string scope string
labelPrecedence bool
up = t.UpdateParams{}
) )
var rootCmd = NewRootCommand() var rootCmd = NewRootCommand()
var localLog = notifications.LocalLog
// NewRootCommand creates the root command for watchtower // NewRootCommand creates the root command for watchtower
func NewRootCommand() *cobra.Command { func NewRootCommand() *cobra.Command {
@ -87,18 +80,18 @@ func PreRun(cmd *cobra.Command, _ []string) {
scheduleSpec, _ = f.GetString("schedule") scheduleSpec, _ = f.GetString("schedule")
flags.GetSecretsFromFiles(cmd) flags.GetSecretsFromFiles(cmd)
cleanup, noRestart, monitorOnly, timeout = flags.ReadFlags(cmd) up.Cleanup, up.NoRestart, up.MonitorOnly, up.Timeout = flags.ReadFlags(cmd)
if timeout < 0 { if up.Timeout < 0 {
log.Fatal("Please specify a positive value for timeout value.") log.Fatal("Please specify a positive value for timeout value.")
} }
enableLabel, _ = f.GetBool("label-enable") enableLabel, _ = f.GetBool("label-enable")
disableContainers, _ = f.GetStringSlice("disable-containers") disableContainers, _ = f.GetStringSlice("disable-containers")
lifecycleHooks, _ = f.GetBool("enable-lifecycle-hooks") up.LifecycleHooks, _ = f.GetBool("enable-lifecycle-hooks")
rollingRestart, _ = f.GetBool("rolling-restart") up.RollingRestart, _ = f.GetBool("rolling-restart")
scope, _ = f.GetString("scope") scope, _ = f.GetString("scope")
labelPrecedence, _ = f.GetBool("label-take-precedence") up.LabelPrecedence, _ = f.GetBool("label-take-precedence")
if scope != "" { if scope != "" {
log.Debugf(`Using scope %q`, scope) log.Debugf(`Using scope %q`, scope)
@ -110,25 +103,22 @@ func PreRun(cmd *cobra.Command, _ []string) {
log.Fatal(err) log.Fatal(err)
} }
var clientOpts = container.ClientOptions{}
noPull, _ := f.GetBool("no-pull") noPull, _ := f.GetBool("no-pull")
includeStopped, _ := f.GetBool("include-stopped") clientOpts.PullImages = !noPull
includeRestarting, _ := f.GetBool("include-restarting") clientOpts.IncludeStopped, _ = f.GetBool("include-stopped")
reviveStopped, _ := f.GetBool("revive-stopped") clientOpts.IncludeRestarting, _ = f.GetBool("include-restarting")
removeVolumes, _ := f.GetBool("remove-volumes") clientOpts.ReviveStopped, _ = f.GetBool("revive-stopped")
clientOpts.RemoveVolumes, _ = f.GetBool("remove-volumes")
warnOnHeadPullFailed, _ := f.GetString("warn-on-head-failure") warnOnHeadPullFailed, _ := f.GetString("warn-on-head-failure")
clientOpts.WarnOnHeadFailed = container.WarningStrategy(warnOnHeadPullFailed)
if monitorOnly && noPull { if up.MonitorOnly && noPull {
log.Warn("Using `WATCHTOWER_NO_PULL` and `WATCHTOWER_MONITOR_ONLY` simultaneously might lead to no action being taken at all. If this is intentional, you may safely ignore this message.") log.Warn("Using `WATCHTOWER_NO_PULL` and `WATCHTOWER_MONITOR_ONLY` simultaneously might lead to no action being taken at all. If this is intentional, you may safely ignore this message.")
} }
client = container.NewClient(container.ClientOptions{ client = container.NewClient(clientOpts)
PullImages: !noPull,
IncludeStopped: includeStopped,
ReviveStopped: reviveStopped,
RemoveVolumes: removeVolumes,
IncludeRestarting: includeRestarting,
WarnOnHeadFailed: container.WarningStrategy(warnOnHeadPullFailed),
})
notifier = notifications.NewNotifier(cmd) notifier = notifications.NewNotifier(cmd)
notifier.AddLogHook() notifier.AddLogHook()
@ -137,13 +127,16 @@ func PreRun(cmd *cobra.Command, _ []string) {
// Run is the main execution flow of the command // Run is the main execution flow of the command
func Run(c *cobra.Command, names []string) { func Run(c *cobra.Command, names []string) {
filter, filterDesc := filters.BuildFilter(names, disableContainers, enableLabel, scope) filter, filterDesc := filters.BuildFilter(names, disableContainers, enableLabel, scope)
up.Filter = filter
runOnce, _ := c.PersistentFlags().GetBool("run-once") runOnce, _ := c.PersistentFlags().GetBool("run-once")
enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-update") enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-updates")
enableMetricsAPI, _ := c.PersistentFlags().GetBool("http-api-metrics") enableMetricsAPI, _ := c.PersistentFlags().GetBool("http-api-metrics")
unblockHTTPAPI, _ := c.PersistentFlags().GetBool("http-api-periodic-polls") unblockHTTPAPI, _ := c.PersistentFlags().GetBool("http-api-periodic-polls")
apiToken, _ := c.PersistentFlags().GetString("http-api-token") apiToken, _ := c.PersistentFlags().GetString("http-api-token")
healthCheck, _ := c.PersistentFlags().GetBool("health-check") healthCheck, _ := c.PersistentFlags().GetBool("health-check")
enableScheduler := !enableUpdateAPI || unblockHTTPAPI
if healthCheck { if healthCheck {
// health check should not have pid 1 // health check should not have pid 1
if os.Getpid() == 1 { if os.Getpid() == 1 {
@ -153,116 +146,113 @@ func Run(c *cobra.Command, names []string) {
os.Exit(0) os.Exit(0)
} }
if rollingRestart && monitorOnly { if up.RollingRestart && up.MonitorOnly {
log.Fatal("Rolling restarts is not compatible with the global monitor only flag") log.Fatal("Rolling restarts is not compatible with the global monitor only flag")
} }
awaitDockerClient() awaitDockerClient()
if err := actions.CheckForSanity(client, filter, rollingRestart); err != nil { if err := actions.CheckForSanity(client, up.Filter, up.RollingRestart); err != nil {
logNotifyExit(err) logNotifyExit(err)
} }
if runOnce { if runOnce {
writeStartupMessage(c, time.Time{}, filterDesc) writeStartupMessage(c, time.Time{}, filterDesc)
runUpdatesWithNotifications(filter) runUpdatesWithNotifications(up)
notifier.Close() notifier.Close()
os.Exit(0) os.Exit(0)
return return
} }
if err := actions.CheckForMultipleWatchtowerInstances(client, cleanup, scope); err != nil { if err := actions.CheckForMultipleWatchtowerInstances(client, up.Cleanup, scope); err != nil {
logNotifyExit(err) logNotifyExit(err)
} }
// The lock is shared between the scheduler and the HTTP API. It only allows one update to run at a time. // The lock is shared between the scheduler and the HTTP API. It only allows one updates to run at a time.
updateLock := make(chan bool, 1) updateLock := sync.Mutex{}
updateLock <- true
httpAPI := api.New(apiToken) httpAPI := api.New(apiToken)
if enableUpdateAPI { if enableUpdateAPI {
updateHandler := update.New(func(images []string) { httpAPI.EnableUpdates(func(paramsFunc updates.ModifyParamsFunc) t.Report {
metric := runUpdatesWithNotifications(filters.FilterByImage(images, filter)) apiUpdateParams := up
metrics.RegisterScan(metric) paramsFunc(&apiUpdateParams)
}, updateLock) if up.MonitorOnly && !apiUpdateParams.MonitorOnly {
httpAPI.RegisterFunc(updateHandler.Path, updateHandler.Handle) apiUpdateParams.MonitorOnly = true
// If polling isn't enabled the scheduler is never started and localLog.Warn("Ignoring request to disable monitor only through API")
// we need to trigger the startup messages manually.
if !unblockHTTPAPI {
writeStartupMessage(c, time.Time{}, filterDesc)
} }
report := runUpdatesWithNotifications(apiUpdateParams)
metrics.RegisterScan(metrics.NewMetric(report))
return report
}, &updateLock)
} }
if enableMetricsAPI { if enableMetricsAPI {
metricsHandler := apiMetrics.New() httpAPI.EnableMetrics()
httpAPI.RegisterHandler(metricsHandler.Path, metricsHandler.Handle)
} }
if err := httpAPI.Start(enableUpdateAPI && !unblockHTTPAPI); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := httpAPI.Start(); err != nil {
log.Error("failed to start API", err) log.Error("failed to start API", err)
} }
if err := runUpgradesOnSchedule(c, filter, filterDesc, updateLock); err != nil { var firstScan time.Time
log.Error(err) var scheduler *cron.Cron
if enableScheduler {
var err error
scheduler, err = runUpgradesOnSchedule(up, &updateLock)
if err != nil {
log.Errorf("Failed to start scheduler: %v", err)
} else {
firstScan = scheduler.Entries()[0].Schedule.Next(time.Now())
} }
os.Exit(1)
} }
func logNotifyExit(err error) { writeStartupMessage(c, firstScan, filterDesc)
log.Error(err)
notifier.Close()
os.Exit(1)
}
func awaitDockerClient() { // Graceful shut-down on SIGINT/SIGTERM
log.Debug("Sleeping for a second to ensure the docker api client has been properly initialized.") interrupt := make(chan os.Signal, 1)
time.Sleep(1 * time.Second) signal.Notify(interrupt, os.Interrupt)
} signal.Notify(interrupt, syscall.SIGTERM)
func formatDuration(d time.Duration) string { recievedSignal := <-interrupt
sb := strings.Builder{} localLog.WithField("signal", recievedSignal).Infof("Got shutdown signal. Gracefully shutting down...")
if scheduler != nil {
scheduler.Stop()
}
hours := int64(d.Hours()) updateLock.Lock()
minutes := int64(math.Mod(d.Minutes(), 60)) go func() {
seconds := int64(math.Mod(d.Seconds(), 60)) time.Sleep(time.Second * 3)
updateLock.Unlock()
}()
if hours == 1 { waitFor(httpAPI.Stop(), "Waiting for HTTP API requests to complete...")
sb.WriteString("1 hour") waitFor(&updateLock, "Waiting for running updates to be finished...")
} else if hours != 0 {
sb.WriteString(strconv.FormatInt(hours, 10))
sb.WriteString(" hours")
}
if hours != 0 && (seconds != 0 || minutes != 0) { localLog.Info("Shutdown completed")
sb.WriteString(", ")
} }
if minutes == 1 { func waitFor(waitLock *sync.Mutex, delayMessage string) {
sb.WriteString("1 minute") if !waitLock.TryLock() {
} else if minutes != 0 { log.Info(delayMessage)
sb.WriteString(strconv.FormatInt(minutes, 10)) waitLock.Lock()
sb.WriteString(" minutes")
} }
if minutes != 0 && (seconds != 0) {
sb.WriteString(", ")
} }
if seconds == 1 { func logNotifyExit(err error) {
sb.WriteString("1 second") log.Error(err)
} else if seconds != 0 || (hours == 0 && minutes == 0) { notifier.Close()
sb.WriteString(strconv.FormatInt(seconds, 10)) os.Exit(1)
sb.WriteString(" seconds")
} }
return sb.String() func awaitDockerClient() {
log.Debug("Sleeping for a second to ensure the docker api client has been properly initialized.")
time.Sleep(1 * time.Second)
} }
func writeStartupMessage(c *cobra.Command, sched time.Time, filtering string) { func writeStartupMessage(c *cobra.Command, sched time.Time, filtering string) {
noStartupMessage, _ := c.PersistentFlags().GetBool("no-startup-message") noStartupMessage, _ := c.PersistentFlags().GetBool("no-startup-message")
enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-update") enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-updates")
var startupLog *log.Entry var startupLog *log.Entry
if noStartupMessage { if noStartupMessage {
@ -285,11 +275,11 @@ func writeStartupMessage(c *cobra.Command, sched time.Time, filtering string) {
startupLog.Info(filtering) startupLog.Info(filtering)
if !sched.IsZero() { if !sched.IsZero() {
until := formatDuration(time.Until(sched)) until := util.FormatDuration(time.Until(sched))
startupLog.Info("Scheduling first run: " + sched.Format("2006-01-02 15:04:05 -0700 MST")) startupLog.Info("Scheduling first run: " + sched.Format("2006-01-02 15:04:05 -0700 MST"))
startupLog.Info("Note that the first check will be performed in " + until) startupLog.Info("Note that the first check will be performed in " + until)
} else if runOnce, _ := c.PersistentFlags().GetBool("run-once"); runOnce { } else if runOnce, _ := c.PersistentFlags().GetBool("run-once"); runOnce {
startupLog.Info("Running a one time update.") startupLog.Info("Running a one time updates.")
} else { } else {
startupLog.Info("Periodic runs are not enabled.") startupLog.Info("Periodic runs are not enabled.")
} }
@ -309,25 +299,19 @@ func writeStartupMessage(c *cobra.Command, sched time.Time, filtering string) {
} }
} }
func runUpgradesOnSchedule(c *cobra.Command, filter t.Filter, filtering string, lock chan bool) error { func runUpgradesOnSchedule(updateParams t.UpdateParams, updateLock *sync.Mutex) (*cron.Cron, error) {
if lock == nil {
lock = make(chan bool, 1)
lock <- true
}
scheduler := cron.New() scheduler := cron.New()
err := scheduler.AddFunc( err := scheduler.AddFunc(
scheduleSpec, scheduleSpec,
func() { func() {
select { if updateLock.TryLock() {
case v := <-lock: defer updateLock.Unlock()
defer func() { lock <- v }() result := runUpdatesWithNotifications(updateParams)
metric := runUpdatesWithNotifications(filter) metrics.RegisterScan(metrics.NewMetric(result))
metrics.RegisterScan(metric) } else {
default:
// Update was skipped // Update was skipped
metrics.RegisterScan(nil) metrics.RegisterScan(nil)
log.Debug("Skipped another update already running.") log.Debug("Skipped another updates already running.")
} }
nextRuns := scheduler.Entries() nextRuns := scheduler.Entries()
@ -337,47 +321,28 @@ func runUpgradesOnSchedule(c *cobra.Command, filter t.Filter, filtering string,
}) })
if err != nil { if err != nil {
return err return nil, err
} }
writeStartupMessage(c, scheduler.Entries()[0].Schedule.Next(time.Now()), filtering)
scheduler.Start() scheduler.Start()
// Graceful shut-down on SIGINT/SIGTERM return scheduler, nil
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
signal.Notify(interrupt, syscall.SIGTERM)
<-interrupt
scheduler.Stop()
log.Info("Waiting for running update to be finished...")
<-lock
return nil
} }
func runUpdatesWithNotifications(filter t.Filter) *metrics.Metric { func runUpdatesWithNotifications(updateParams t.UpdateParams) t.Report {
notifier.StartNotification() notifier.StartNotification()
updateParams := t.UpdateParams{
Filter: filter,
Cleanup: cleanup,
NoRestart: noRestart,
Timeout: timeout,
MonitorOnly: monitorOnly,
LifecycleHooks: lifecycleHooks,
RollingRestart: rollingRestart,
LabelPrecedence: labelPrecedence,
}
result, err := actions.Update(client, updateParams) result, err := actions.Update(client, updateParams)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
notifier.SendNotification(result) notifier.SendNotification(result)
metricResults := metrics.NewMetric(result)
notifications.LocalLog.WithFields(log.Fields{ localLog.WithFields(log.Fields{
"Scanned": metricResults.Scanned, "Scanned": len(result.Scanned()),
"Updated": metricResults.Updated, "Updated": len(result.Updated()),
"Failed": metricResults.Failed, "Failed": len(result.Failed()),
}).Info("Session done") }).Info("Session done")
return metricResults
return result
} }

@ -0,0 +1,47 @@
package util
import (
"math"
"strconv"
"strings"
"time"
)
func FormatDuration(d time.Duration) string {
sb := strings.Builder{}
hours := int64(d.Hours())
minutes := int64(math.Mod(d.Minutes(), 60))
seconds := int64(math.Mod(d.Seconds(), 60))
if hours == 1 {
sb.WriteString("1 hour")
} else if hours != 0 {
sb.WriteString(strconv.FormatInt(hours, 10))
sb.WriteString(" hours")
}
if hours != 0 && (seconds != 0 || minutes != 0) {
sb.WriteString(", ")
}
if minutes == 1 {
sb.WriteString("1 minute")
} else if minutes != 0 {
sb.WriteString(strconv.FormatInt(minutes, 10))
sb.WriteString(" minutes")
}
if minutes != 0 && (seconds != 0) {
sb.WriteString(", ")
}
if seconds == 1 {
sb.WriteString("1 second")
} else if seconds != 0 || (hours == 0 && minutes == 0) {
sb.WriteString(strconv.FormatInt(seconds, 10))
sb.WriteString(" seconds")
}
return sb.String()
}

@ -1,8 +1,14 @@
package api package api
import ( import (
"fmt" "context"
"errors"
"github.com/containrrr/watchtower/pkg/api/metrics"
"github.com/containrrr/watchtower/pkg/api/middleware"
"github.com/containrrr/watchtower/pkg/api/prelude"
"github.com/containrrr/watchtower/pkg/api/updates"
"net/http" "net/http"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -13,6 +19,12 @@ const tokenMissingMsg = "api token is empty or has not been set. exiting"
type API struct { type API struct {
Token string Token string
hasHandlers bool hasHandlers bool
mux *http.ServeMux
server *http.Server
running *sync.Mutex
router router
authMiddleware prelude.Middleware
registered bool
} }
// New is a factory function creating a new API instance // New is a factory function creating a new API instance
@ -20,37 +32,37 @@ func New(token string) *API {
return &API{ return &API{
Token: token, Token: token,
hasHandlers: false, hasHandlers: false,
mux: http.NewServeMux(),
running: &sync.Mutex{},
router: router{},
authMiddleware: middleware.RequireToken(token),
registered: false,
} }
} }
// RequireToken is wrapper around http.HandleFunc that checks token validity func (api *API) route(route string) methodHandlers {
func (api *API) RequireToken(fn http.HandlerFunc) http.HandlerFunc { return api.router.route(route)
return func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
want := fmt.Sprintf("Bearer %s", api.Token)
if auth != want {
w.WriteHeader(http.StatusUnauthorized)
return
} }
log.Debug("Valid token found.")
fn(w, r) func (api *API) registerHandlers() {
if api.registered {
return
} }
for path, route := range api.router {
if len(route) < 1 {
continue
} }
// RegisterFunc is a wrapper around http.HandleFunc that also sets the flag used to determine whether to launch the API
func (api *API) RegisterFunc(path string, fn http.HandlerFunc) {
api.hasHandlers = true api.hasHandlers = true
http.HandleFunc(path, api.RequireToken(fn)) api.mux.Handle(path, api.authMiddleware(route.Handler))
} }
api.registered = true
// RegisterHandler is a wrapper around http.Handler that also sets the flag used to determine whether to launch the API return
func (api *API) RegisterHandler(path string, handler http.Handler) {
api.hasHandlers = true
http.Handle(path, api.RequireToken(handler.ServeHTTP))
} }
// Start the API and serve over HTTP. Requires an API Token to be set. // Start the API and serve over HTTP. Requires an API Token to be set.
func (api *API) Start(block bool) error { func (api *API) Start() error {
api.registerHandlers()
if !api.hasHandlers { if !api.hasHandlers {
log.Debug("Watchtower HTTP API skipped.") log.Debug("Watchtower HTTP API skipped.")
@ -61,16 +73,49 @@ func (api *API) Start(block bool) error {
log.Fatal(tokenMissingMsg) log.Fatal(tokenMissingMsg)
} }
if block { api.running.Lock()
runHTTPServer()
} else {
go func() { go func() {
runHTTPServer() defer api.running.Unlock()
}() api.server = &http.Server{
Addr: ":8080",
Handler: api.mux,
}
if err := api.server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Errorf("HTTP Server error: %v", err)
} }
}()
return nil return nil
} }
func runHTTPServer() { // Stop tells the api server to shut down (if its running) and returns a sync.Mutex that is locked
log.Fatal(http.ListenAndServe(":8080", nil)) // until the server has handled all remaining requests and shut down
func (api *API) Stop() *sync.Mutex {
if api.server != nil {
go func() {
if err := api.server.Shutdown(context.Background()); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Errorf("Error stopping HTTP Server: %v", err)
}
}()
}
return api.running
}
// Handler is used to get a http.Handler for testing
func (api *API) Handler() http.Handler {
api.registerHandlers()
return api.mux
}
// EnableUpdates registers the `updates` endpoints
func (api *API) EnableUpdates(f updates.InvokedFunc, updateLock *sync.Mutex) {
api.route("/v1/updates").post(updates.PostV1(f, updateLock))
}
// EnableMetrics registers the `metrics` endpoints
func (api *API) EnableMetrics() {
api.route("/v1/metrics").get(metrics.GetV1())
} }

@ -1,27 +1,14 @@
package metrics package metrics
import ( import (
. "github.com/containrrr/watchtower/pkg/api/prelude"
"github.com/containrrr/watchtower/pkg/metrics" "github.com/containrrr/watchtower/pkg/metrics"
"net/http"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
) )
// Handler is an HTTP handle for serving metric data // GetV1 creates a new metrics http handler
type Handler struct { func GetV1() HandlerFunc {
Path string // Initialize watchtower metrics
Handle http.HandlerFunc metrics.Init()
Metrics *metrics.Metrics return WrapHandler(promhttp.Handler().ServeHTTP)
}
// New is a factory function creating a new Metrics instance
func New() *Handler {
m := metrics.Default()
handler := promhttp.Handler()
return &Handler{
Path: "/v1/metrics",
Handle: handler.ServeHTTP,
Metrics: m,
}
} }

@ -12,7 +12,6 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/containrrr/watchtower/pkg/api" "github.com/containrrr/watchtower/pkg/api"
metricsAPI "github.com/containrrr/watchtower/pkg/api/metrics"
"github.com/containrrr/watchtower/pkg/metrics" "github.com/containrrr/watchtower/pkg/metrics"
) )
@ -51,10 +50,9 @@ func getWithToken(handler http.Handler) map[string]string {
var _ = Describe("the metrics API", func() { var _ = Describe("the metrics API", func() {
httpAPI := api.New(token) httpAPI := api.New(token)
m := metricsAPI.New() httpAPI.EnableMetrics()
handleReq := httpAPI.RequireToken(m.Handle) tryGetMetrics := func() map[string]string { return getWithToken(httpAPI.Handler()) }
tryGetMetrics := func() map[string]string { return getWithToken(handleReq) }
It("should serve metrics", func() { It("should serve metrics", func() {

@ -1,7 +1,7 @@
package api package middleware
import ( import (
"io" "github.com/containrrr/watchtower/pkg/api/prelude"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -16,50 +16,53 @@ const (
func TestAPI(t *testing.T) { func TestAPI(t *testing.T) {
RegisterFailHandler(Fail) RegisterFailHandler(Fail)
RunSpecs(t, "API Suite") RunSpecs(t, "Middleware Suite")
} }
var _ = Describe("API", func() { var _ = Describe("API", func() {
api := New(token) requireToken := RequireToken(token)
Describe("RequireToken middleware", func() { Describe("RequireToken middleware", func() {
It("should return 401 Unauthorized when token is not provided", func() { It("should return 401 Unauthorized when token is not provided", func() {
handlerFunc := api.RequireToken(testHandler)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/hello", nil) req := httptest.NewRequest("GET", "/hello", nil)
handlerFunc(rec, req) requireToken(testHandler).ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(http.StatusUnauthorized)) Expect(rec.Code).To(Equal(http.StatusUnauthorized))
Expect(rec.Body).To(MatchJSON(`{
"code": "MISSING_TOKEN",
"error": "No authentication token was supplied"
}`))
}) })
It("should return 401 Unauthorized when token is invalid", func() { It("should return 401 Unauthorized when token is invalid", func() {
handlerFunc := api.RequireToken(testHandler)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/hello", nil) req := httptest.NewRequest("GET", "/hello", nil)
req.Header.Set("Authorization", "Bearer 123") req.Header.Set("Authorization", "Bearer 123")
handlerFunc(rec, req) requireToken(testHandler).ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(http.StatusUnauthorized)) Expect(rec.Code).To(Equal(http.StatusUnauthorized))
Expect(rec.Body).To(MatchJSON(`{
"code": "INVALID_TOKEN",
"error": "The supplied token does not match the configured auth token"
}`))
}) })
It("should return 200 OK when token is valid", func() { It("should return 200 OK when token is valid", func() {
handlerFunc := api.RequireToken(testHandler)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/hello", nil) req := httptest.NewRequest("GET", "/hello", nil)
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
handlerFunc(rec, req) requireToken(testHandler).ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(http.StatusOK)) Expect(rec.Code).To(Equal(http.StatusOK))
}) })
}) })
}) })
func testHandler(w http.ResponseWriter, req *http.Request) { func testHandler(_ *prelude.Context) prelude.Response {
_, _ = io.WriteString(w, "Hello!") return prelude.OK("Hello!")
} }

@ -0,0 +1,24 @@
package middleware
import (
"fmt"
. "github.com/containrrr/watchtower/pkg/api/prelude"
)
// RequireToken returns a prelude.Middleware that checks token validity
func RequireToken(token string) Middleware {
return func(next HandlerFunc) HandlerFunc {
want := fmt.Sprintf("Bearer %s", token)
return func(c *Context) Response {
auth := c.Request.Header.Get("Authorization")
if auth == "" {
return Error(ErrMissingToken)
}
if auth != want {
return Error(ErrInvalidToken)
}
return next(c)
}
}
}

@ -0,0 +1,62 @@
package prelude
import (
"bytes"
"fmt"
"github.com/sirupsen/logrus"
"net/http"
)
type Context struct {
Request *http.Request
Log *logrus.Entry
writer http.ResponseWriter
}
func newContext(w http.ResponseWriter, req *http.Request) *Context {
reqLog := localLog.WithField("endpoint", fmt.Sprintf("%v %v", req.Method, req.URL.Path))
return &Context{
Log: reqLog,
Request: req,
writer: w,
}
}
func (c *Context) Headers() http.Header {
return c.writer.Header()
}
type contextWrapper struct {
context *Context
body bytes.Buffer
statusCode int
}
func (cw *contextWrapper) Header() http.Header {
return cw.context.writer.Header()
}
func (cw *contextWrapper) Write(bytes []byte) (int, error) {
return cw.body.Write(bytes)
}
func (cw *contextWrapper) WriteHeader(statusCode int) {
cw.statusCode = statusCode
}
func WrapHandler(next http.HandlerFunc) HandlerFunc {
return func(c *Context) Response {
wrapper := contextWrapper{
context: c,
body: bytes.Buffer{},
}
next(&wrapper, c.Request)
return Response{
Status: wrapper.statusCode,
Body: wrapper.body.Bytes(),
Raw: true,
}
}
}

@ -0,0 +1,36 @@
package prelude
import "net/http"
type errorResponse struct {
Error string `json:"error"`
Code ErrorCode `json:"code"`
Status int `json:"-"`
}
const internalErrorPayload string = `{ "error": "API internal error, check logs", "code": "API_INTERNAL_ERROR" }`
type ErrorCode string
var (
ErrUpdateRunning = errorResponse{
Code: "UPDATE_RUNNING",
Error: "Update already running",
Status: http.StatusConflict,
}
ErrNotFound = errorResponse{
Code: "NOT_FOUND",
Error: "Endpoint is not registered to a handler",
Status: http.StatusNotFound,
}
ErrInvalidToken = errorResponse{
Code: "INVALID_TOKEN",
Error: "The supplied token does not match the configured auth token",
Status: http.StatusUnauthorized,
}
ErrMissingToken = errorResponse{
Code: "MISSING_TOKEN",
Error: "No authentication token was supplied",
Status: http.StatusUnauthorized,
}
)

@ -0,0 +1,39 @@
package prelude
import (
log "github.com/sirupsen/logrus"
"net/http"
)
type HandlerFunc func(c *Context) Response
func (hf HandlerFunc) ServeHTTP(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", DefaultContentType)
context := newContext(w, req)
reqLog := context.Log.WithFields(log.Fields{
"query": req.URL.RawQuery,
})
reqLog.Trace("Received API Request")
res := hf(context)
status := res.Status
bytes, err := res.Bytes()
if err != nil {
context.Log.WithError(err).Errorf("Failed to create JSON payload for response")
bytes = []byte(internalErrorPayload)
status = http.StatusInternalServerError
// Reset the content-type in case the handler changed it
w.Header().Set("Content-Type", DefaultContentType)
}
reqLog.WithField("status", status).Trace("Handled API Request")
w.WriteHeader(status)
if _, err = w.Write(bytes); err != nil {
localLog.Errorf("Failed to write HTTP response: %v", err)
}
}

@ -0,0 +1,41 @@
package prelude
import (
"encoding/json"
log "github.com/sirupsen/logrus"
"net/http"
)
type Response struct {
Body any
Status int
Raw bool
}
func (r *Response) Bytes() ([]byte, error) {
if bytes, raw := r.Body.([]byte); raw {
return bytes, nil
}
if str, raw := r.Body.(string); raw {
return []byte(str), nil
}
return json.MarshalIndent(r.Body, "", " ")
}
var localLog = log.WithField("notify", "no")
func OK(body any) Response {
return Response{
Status: http.StatusOK,
Body: body,
}
}
func Error(err errorResponse) Response {
return Response{
Status: err.Status,
Body: err,
}
}

@ -0,0 +1,5 @@
package prelude
type Middleware func(next HandlerFunc) HandlerFunc
const DefaultContentType = "application/json"

@ -0,0 +1,34 @@
package api
import (
. "github.com/containrrr/watchtower/pkg/api/prelude"
"net/http"
)
type router map[string]methodHandlers
type methodHandlers map[string]HandlerFunc
func (mh methodHandlers) Handler(c *Context) Response {
handler, found := mh[c.Request.Method]
if !found {
return Error(ErrNotFound)
}
return handler(c)
}
func (mh methodHandlers) post(handlerFunc HandlerFunc) {
mh[http.MethodPost] = handlerFunc
}
func (mh methodHandlers) get(handlerFunc HandlerFunc) {
mh[http.MethodGet] = handlerFunc
}
func (r router) route(route string) methodHandlers {
routeMethods, found := r[route]
if !found {
routeMethods = methodHandlers{}
r[route] = routeMethods
}
return routeMethods
}

@ -1,72 +0,0 @@
package update
import (
"io"
"net/http"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
var (
lock chan bool
)
// New is a factory function creating a new Handler instance
func New(updateFn func(images []string), updateLock chan bool) *Handler {
if updateLock != nil {
lock = updateLock
} else {
lock = make(chan bool, 1)
lock <- true
}
return &Handler{
fn: updateFn,
Path: "/v1/update",
}
}
// Handler is an API handler used for triggering container update scans
type Handler struct {
fn func(images []string)
Path string
}
// Handle is the actual http.Handle function doing all the heavy lifting
func (handle *Handler) Handle(w http.ResponseWriter, r *http.Request) {
log.Info("Updates triggered by HTTP API request.")
_, err := io.Copy(os.Stdout, r.Body)
if err != nil {
log.Println(err)
return
}
var images []string
imageQueries, found := r.URL.Query()["image"]
if found {
for _, image := range imageQueries {
images = append(images, strings.Split(image, ",")...)
}
} else {
images = nil
}
if len(images) > 0 {
chanValue := <-lock
defer func() { lock <- chanValue }()
handle.fn(images)
} else {
select {
case chanValue := <-lock:
defer func() { lock <- chanValue }()
handle.fn(images)
default:
log.Debug("Skipped. Another update already running.")
}
}
}

@ -0,0 +1,22 @@
package updates
import (
"github.com/containrrr/watchtower/pkg/types"
"net/url"
"strings"
)
type ModifyParamsFunc func(up *types.UpdateParams)
type InvokedFunc func(ModifyParamsFunc) types.Report
func parseImages(u *url.URL) []string {
var images []string
imageQueries, found := u.Query()["image"]
if found {
for _, image := range imageQueries {
images = append(images, strings.Split(image, ",")...)
}
}
return images
}

@ -0,0 +1,37 @@
package updates
import (
. "github.com/containrrr/watchtower/pkg/api/prelude"
"github.com/containrrr/watchtower/pkg/filters"
"github.com/containrrr/watchtower/pkg/types"
"sync"
log "github.com/sirupsen/logrus"
)
// PostV1 creates an API http.HandlerFunc for V1 of updates
func PostV1(updateFn InvokedFunc, updateLock *sync.Mutex) HandlerFunc {
return func(c *Context) Response {
log.Info("Updates triggered by HTTP API request.")
images := parseImages(c.Request.URL)
if !updateLock.TryLock() {
if len(images) > 0 {
// If images have been passed, wait until the current updates are done
updateLock.Lock()
} else {
// If a full update is running (no explicit image filter), skip this update
log.Debug("Skipped. Another updates already running.")
return OK(nil) // For backwards compatibility
}
}
defer updateLock.Unlock()
_ = updateFn(func(up *types.UpdateParams) {
up.Filter = filters.FilterByImage(images, up.Filter)
})
return OK(nil)
}
}

@ -45,12 +45,11 @@ func (metrics *Metrics) Register(metric *Metric) {
metrics.channel <- metric metrics.channel <- metric
} }
// Default creates a new metrics handler if none exists, otherwise returns the existing one // Init creates a new metrics handler if none exists
func Default() *Metrics { func Init() {
if metrics != nil { if metrics != nil {
return metrics return
} }
metrics = &Metrics{ metrics = &Metrics{
scanned: promauto.NewGauge(prometheus.GaugeOpts{ scanned: promauto.NewGauge(prometheus.GaugeOpts{
Name: "watchtower_containers_scanned", Name: "watchtower_containers_scanned",
@ -76,6 +75,11 @@ func Default() *Metrics {
} }
go metrics.HandleUpdate(metrics.channel) go metrics.HandleUpdate(metrics.channel)
}
// Default creates a new metrics handler if none exists, otherwise returns the existing one
func Default() *Metrics {
Init()
return metrics return metrics
} }

Loading…
Cancel
Save