From 47091761a55905c2132552fdc540dfbbb990f5ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?nils=20m=C3=A5s=C3=A9n?= Date: Sat, 21 Oct 2023 19:35:41 +0200 Subject: [PATCH] feat(api): implement new api handler --- cmd/root.go | 245 ++++++++---------- internal/util/duration.go | 47 ++++ pkg/api/api.go | 115 +++++--- pkg/api/metrics/metrics.go | 25 +- pkg/api/metrics/metrics_test.go | 6 +- .../middleware_test.go} | 35 +-- pkg/api/middleware/require_token.go | 24 ++ pkg/api/prelude/context.go | 62 +++++ pkg/api/prelude/errors.go | 36 +++ pkg/api/prelude/handler_func.go | 39 +++ pkg/api/prelude/response.go | 41 +++ pkg/api/prelude/types.go | 5 + pkg/api/router.go | 34 +++ pkg/api/update/update.go | 72 ----- pkg/api/updates/updates.go | 22 ++ pkg/api/updates/updates_v1.go | 37 +++ pkg/metrics/metrics.go | 12 +- 17 files changed, 567 insertions(+), 290 deletions(-) create mode 100644 internal/util/duration.go rename pkg/api/{api_test.go => middleware/middleware_test.go} (57%) create mode 100644 pkg/api/middleware/require_token.go create mode 100644 pkg/api/prelude/context.go create mode 100644 pkg/api/prelude/errors.go create mode 100644 pkg/api/prelude/handler_func.go create mode 100644 pkg/api/prelude/response.go create mode 100644 pkg/api/prelude/types.go create mode 100644 pkg/api/router.go delete mode 100644 pkg/api/update/update.go create mode 100644 pkg/api/updates/updates.go create mode 100644 pkg/api/updates/updates_v1.go diff --git a/cmd/root.go b/cmd/root.go index 48961d2..e9ebd40 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,22 +1,19 @@ package cmd import ( - "errors" - "math" - "net/http" "os" "os/signal" - "strconv" "strings" + "sync" "syscall" "time" "github.com/containrrr/watchtower/internal/actions" "github.com/containrrr/watchtower/internal/flags" "github.com/containrrr/watchtower/internal/meta" + "github.com/containrrr/watchtower/internal/util" "github.com/containrrr/watchtower/pkg/api" - apiMetrics "github.com/containrrr/watchtower/pkg/api/metrics" - "github.com/containrrr/watchtower/pkg/api/update" + "github.com/containrrr/watchtower/pkg/api/updates" "github.com/containrrr/watchtower/pkg/container" "github.com/containrrr/watchtower/pkg/filters" "github.com/containrrr/watchtower/pkg/metrics" @@ -31,20 +28,16 @@ import ( var ( client container.Client scheduleSpec string - cleanup bool - noRestart bool - monitorOnly bool enableLabel bool disableContainers []string notifier t.Notifier - timeout time.Duration - lifecycleHooks bool - rollingRestart bool scope string - labelPrecedence bool + + up = t.UpdateParams{} ) var rootCmd = NewRootCommand() +var localLog = notifications.LocalLog // NewRootCommand creates the root command for watchtower func NewRootCommand() *cobra.Command { @@ -87,18 +80,18 @@ func PreRun(cmd *cobra.Command, _ []string) { scheduleSpec, _ = f.GetString("schedule") flags.GetSecretsFromFiles(cmd) - cleanup, noRestart, monitorOnly, timeout = flags.ReadFlags(cmd) + up.Cleanup, up.NoRestart, up.MonitorOnly, up.Timeout = flags.ReadFlags(cmd) - if timeout < 0 { + if up.Timeout < 0 { log.Fatal("Please specify a positive value for timeout value.") } enableLabel, _ = f.GetBool("label-enable") disableContainers, _ = f.GetStringSlice("disable-containers") - lifecycleHooks, _ = f.GetBool("enable-lifecycle-hooks") - rollingRestart, _ = f.GetBool("rolling-restart") + up.LifecycleHooks, _ = f.GetBool("enable-lifecycle-hooks") + up.RollingRestart, _ = f.GetBool("rolling-restart") scope, _ = f.GetString("scope") - labelPrecedence, _ = f.GetBool("label-take-precedence") + up.LabelPrecedence, _ = f.GetBool("label-take-precedence") if scope != "" { log.Debugf(`Using scope %q`, scope) @@ -110,25 +103,22 @@ func PreRun(cmd *cobra.Command, _ []string) { log.Fatal(err) } + var clientOpts = container.ClientOptions{} + noPull, _ := f.GetBool("no-pull") - includeStopped, _ := f.GetBool("include-stopped") - includeRestarting, _ := f.GetBool("include-restarting") - reviveStopped, _ := f.GetBool("revive-stopped") - removeVolumes, _ := f.GetBool("remove-volumes") + clientOpts.PullImages = !noPull + clientOpts.IncludeStopped, _ = f.GetBool("include-stopped") + clientOpts.IncludeRestarting, _ = f.GetBool("include-restarting") + clientOpts.ReviveStopped, _ = f.GetBool("revive-stopped") + clientOpts.RemoveVolumes, _ = f.GetBool("remove-volumes") warnOnHeadPullFailed, _ := f.GetString("warn-on-head-failure") + clientOpts.WarnOnHeadFailed = container.WarningStrategy(warnOnHeadPullFailed) - if monitorOnly && noPull { + if up.MonitorOnly && noPull { log.Warn("Using `WATCHTOWER_NO_PULL` and `WATCHTOWER_MONITOR_ONLY` simultaneously might lead to no action being taken at all. If this is intentional, you may safely ignore this message.") } - client = container.NewClient(container.ClientOptions{ - PullImages: !noPull, - IncludeStopped: includeStopped, - ReviveStopped: reviveStopped, - RemoveVolumes: removeVolumes, - IncludeRestarting: includeRestarting, - WarnOnHeadFailed: container.WarningStrategy(warnOnHeadPullFailed), - }) + client = container.NewClient(clientOpts) notifier = notifications.NewNotifier(cmd) notifier.AddLogHook() @@ -137,13 +127,16 @@ func PreRun(cmd *cobra.Command, _ []string) { // Run is the main execution flow of the command func Run(c *cobra.Command, names []string) { filter, filterDesc := filters.BuildFilter(names, disableContainers, enableLabel, scope) + up.Filter = filter runOnce, _ := c.PersistentFlags().GetBool("run-once") - enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-update") + enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-updates") enableMetricsAPI, _ := c.PersistentFlags().GetBool("http-api-metrics") unblockHTTPAPI, _ := c.PersistentFlags().GetBool("http-api-periodic-polls") apiToken, _ := c.PersistentFlags().GetString("http-api-token") healthCheck, _ := c.PersistentFlags().GetBool("health-check") + enableScheduler := !enableUpdateAPI || unblockHTTPAPI + if healthCheck { // health check should not have pid 1 if os.Getpid() == 1 { @@ -153,61 +146,97 @@ func Run(c *cobra.Command, names []string) { os.Exit(0) } - if rollingRestart && monitorOnly { + if up.RollingRestart && up.MonitorOnly { log.Fatal("Rolling restarts is not compatible with the global monitor only flag") } awaitDockerClient() - if err := actions.CheckForSanity(client, filter, rollingRestart); err != nil { + if err := actions.CheckForSanity(client, up.Filter, up.RollingRestart); err != nil { logNotifyExit(err) } if runOnce { writeStartupMessage(c, time.Time{}, filterDesc) - runUpdatesWithNotifications(filter) + runUpdatesWithNotifications(up) notifier.Close() os.Exit(0) return } - if err := actions.CheckForMultipleWatchtowerInstances(client, cleanup, scope); err != nil { + if err := actions.CheckForMultipleWatchtowerInstances(client, up.Cleanup, scope); err != nil { logNotifyExit(err) } - // The lock is shared between the scheduler and the HTTP API. It only allows one update to run at a time. - updateLock := make(chan bool, 1) - updateLock <- true + // The lock is shared between the scheduler and the HTTP API. It only allows one updates to run at a time. + updateLock := sync.Mutex{} httpAPI := api.New(apiToken) if enableUpdateAPI { - updateHandler := update.New(func(images []string) { - metric := runUpdatesWithNotifications(filters.FilterByImage(images, filter)) - metrics.RegisterScan(metric) - }, updateLock) - httpAPI.RegisterFunc(updateHandler.Path, updateHandler.Handle) - // If polling isn't enabled the scheduler is never started and - // we need to trigger the startup messages manually. - if !unblockHTTPAPI { - writeStartupMessage(c, time.Time{}, filterDesc) - } + httpAPI.EnableUpdates(func(paramsFunc updates.ModifyParamsFunc) t.Report { + apiUpdateParams := up + paramsFunc(&apiUpdateParams) + if up.MonitorOnly && !apiUpdateParams.MonitorOnly { + apiUpdateParams.MonitorOnly = true + localLog.Warn("Ignoring request to disable monitor only through API") + } + report := runUpdatesWithNotifications(apiUpdateParams) + metrics.RegisterScan(metrics.NewMetric(report)) + return report + }, &updateLock) } if enableMetricsAPI { - metricsHandler := apiMetrics.New() - httpAPI.RegisterHandler(metricsHandler.Path, metricsHandler.Handle) + httpAPI.EnableMetrics() } - if err := httpAPI.Start(enableUpdateAPI && !unblockHTTPAPI); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := httpAPI.Start(); err != nil { log.Error("failed to start API", err) } - if err := runUpgradesOnSchedule(c, filter, filterDesc, updateLock); err != nil { - log.Error(err) + var firstScan time.Time + var scheduler *cron.Cron + if enableScheduler { + var err error + scheduler, err = runUpgradesOnSchedule(up, &updateLock) + if err != nil { + log.Errorf("Failed to start scheduler: %v", err) + } else { + firstScan = scheduler.Entries()[0].Schedule.Next(time.Now()) + } } - os.Exit(1) + writeStartupMessage(c, firstScan, filterDesc) + + // Graceful shut-down on SIGINT/SIGTERM + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + signal.Notify(interrupt, syscall.SIGTERM) + + recievedSignal := <-interrupt + localLog.WithField("signal", recievedSignal).Infof("Got shutdown signal. Gracefully shutting down...") + if scheduler != nil { + scheduler.Stop() + } + + updateLock.Lock() + go func() { + time.Sleep(time.Second * 3) + updateLock.Unlock() + }() + + waitFor(httpAPI.Stop(), "Waiting for HTTP API requests to complete...") + waitFor(&updateLock, "Waiting for running updates to be finished...") + + localLog.Info("Shutdown completed") +} + +func waitFor(waitLock *sync.Mutex, delayMessage string) { + if !waitLock.TryLock() { + log.Info(delayMessage) + waitLock.Lock() + } } func logNotifyExit(err error) { @@ -221,48 +250,9 @@ func awaitDockerClient() { time.Sleep(1 * time.Second) } -func formatDuration(d time.Duration) string { - sb := strings.Builder{} - - hours := int64(d.Hours()) - minutes := int64(math.Mod(d.Minutes(), 60)) - seconds := int64(math.Mod(d.Seconds(), 60)) - - if hours == 1 { - sb.WriteString("1 hour") - } else if hours != 0 { - sb.WriteString(strconv.FormatInt(hours, 10)) - sb.WriteString(" hours") - } - - if hours != 0 && (seconds != 0 || minutes != 0) { - sb.WriteString(", ") - } - - if minutes == 1 { - sb.WriteString("1 minute") - } else if minutes != 0 { - sb.WriteString(strconv.FormatInt(minutes, 10)) - sb.WriteString(" minutes") - } - - if minutes != 0 && (seconds != 0) { - sb.WriteString(", ") - } - - if seconds == 1 { - sb.WriteString("1 second") - } else if seconds != 0 || (hours == 0 && minutes == 0) { - sb.WriteString(strconv.FormatInt(seconds, 10)) - sb.WriteString(" seconds") - } - - return sb.String() -} - func writeStartupMessage(c *cobra.Command, sched time.Time, filtering string) { noStartupMessage, _ := c.PersistentFlags().GetBool("no-startup-message") - enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-update") + enableUpdateAPI, _ := c.PersistentFlags().GetBool("http-api-updates") var startupLog *log.Entry if noStartupMessage { @@ -285,11 +275,11 @@ func writeStartupMessage(c *cobra.Command, sched time.Time, filtering string) { startupLog.Info(filtering) if !sched.IsZero() { - until := formatDuration(time.Until(sched)) + until := util.FormatDuration(time.Until(sched)) startupLog.Info("Scheduling first run: " + sched.Format("2006-01-02 15:04:05 -0700 MST")) startupLog.Info("Note that the first check will be performed in " + until) } else if runOnce, _ := c.PersistentFlags().GetBool("run-once"); runOnce { - startupLog.Info("Running a one time update.") + startupLog.Info("Running a one time updates.") } else { startupLog.Info("Periodic runs are not enabled.") } @@ -309,25 +299,19 @@ func writeStartupMessage(c *cobra.Command, sched time.Time, filtering string) { } } -func runUpgradesOnSchedule(c *cobra.Command, filter t.Filter, filtering string, lock chan bool) error { - if lock == nil { - lock = make(chan bool, 1) - lock <- true - } - +func runUpgradesOnSchedule(updateParams t.UpdateParams, updateLock *sync.Mutex) (*cron.Cron, error) { scheduler := cron.New() err := scheduler.AddFunc( scheduleSpec, func() { - select { - case v := <-lock: - defer func() { lock <- v }() - metric := runUpdatesWithNotifications(filter) - metrics.RegisterScan(metric) - default: + if updateLock.TryLock() { + defer updateLock.Unlock() + result := runUpdatesWithNotifications(updateParams) + metrics.RegisterScan(metrics.NewMetric(result)) + } else { // Update was skipped metrics.RegisterScan(nil) - log.Debug("Skipped another update already running.") + log.Debug("Skipped another updates already running.") } nextRuns := scheduler.Entries() @@ -337,47 +321,28 @@ func runUpgradesOnSchedule(c *cobra.Command, filter t.Filter, filtering string, }) if err != nil { - return err + return nil, err } - writeStartupMessage(c, scheduler.Entries()[0].Schedule.Next(time.Now()), filtering) - scheduler.Start() - // Graceful shut-down on SIGINT/SIGTERM - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt) - signal.Notify(interrupt, syscall.SIGTERM) - - <-interrupt - scheduler.Stop() - log.Info("Waiting for running update to be finished...") - <-lock - return nil + return scheduler, nil } -func runUpdatesWithNotifications(filter t.Filter) *metrics.Metric { +func runUpdatesWithNotifications(updateParams t.UpdateParams) t.Report { notifier.StartNotification() - updateParams := t.UpdateParams{ - Filter: filter, - Cleanup: cleanup, - NoRestart: noRestart, - Timeout: timeout, - MonitorOnly: monitorOnly, - LifecycleHooks: lifecycleHooks, - RollingRestart: rollingRestart, - LabelPrecedence: labelPrecedence, - } + result, err := actions.Update(client, updateParams) if err != nil { log.Error(err) } notifier.SendNotification(result) - metricResults := metrics.NewMetric(result) - notifications.LocalLog.WithFields(log.Fields{ - "Scanned": metricResults.Scanned, - "Updated": metricResults.Updated, - "Failed": metricResults.Failed, + + localLog.WithFields(log.Fields{ + "Scanned": len(result.Scanned()), + "Updated": len(result.Updated()), + "Failed": len(result.Failed()), }).Info("Session done") - return metricResults + + return result } diff --git a/internal/util/duration.go b/internal/util/duration.go new file mode 100644 index 0000000..e3ebe9f --- /dev/null +++ b/internal/util/duration.go @@ -0,0 +1,47 @@ +package util + +import ( + "math" + "strconv" + "strings" + "time" +) + +func FormatDuration(d time.Duration) string { + sb := strings.Builder{} + + hours := int64(d.Hours()) + minutes := int64(math.Mod(d.Minutes(), 60)) + seconds := int64(math.Mod(d.Seconds(), 60)) + + if hours == 1 { + sb.WriteString("1 hour") + } else if hours != 0 { + sb.WriteString(strconv.FormatInt(hours, 10)) + sb.WriteString(" hours") + } + + if hours != 0 && (seconds != 0 || minutes != 0) { + sb.WriteString(", ") + } + + if minutes == 1 { + sb.WriteString("1 minute") + } else if minutes != 0 { + sb.WriteString(strconv.FormatInt(minutes, 10)) + sb.WriteString(" minutes") + } + + if minutes != 0 && (seconds != 0) { + sb.WriteString(", ") + } + + if seconds == 1 { + sb.WriteString("1 second") + } else if seconds != 0 || (hours == 0 && minutes == 0) { + sb.WriteString(strconv.FormatInt(seconds, 10)) + sb.WriteString(" seconds") + } + + return sb.String() +} diff --git a/pkg/api/api.go b/pkg/api/api.go index 2ceaea8..36d9077 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -1,8 +1,14 @@ package api import ( - "fmt" + "context" + "errors" + "github.com/containrrr/watchtower/pkg/api/metrics" + "github.com/containrrr/watchtower/pkg/api/middleware" + "github.com/containrrr/watchtower/pkg/api/prelude" + "github.com/containrrr/watchtower/pkg/api/updates" "net/http" + "sync" log "github.com/sirupsen/logrus" ) @@ -11,46 +17,52 @@ const tokenMissingMsg = "api token is empty or has not been set. exiting" // API is the http server responsible for serving the HTTP API endpoints type API struct { - Token string - hasHandlers bool + Token string + hasHandlers bool + mux *http.ServeMux + server *http.Server + running *sync.Mutex + router router + authMiddleware prelude.Middleware + registered bool } // New is a factory function creating a new API instance func New(token string) *API { return &API{ - Token: token, - hasHandlers: false, + Token: token, + hasHandlers: false, + mux: http.NewServeMux(), + running: &sync.Mutex{}, + router: router{}, + authMiddleware: middleware.RequireToken(token), + registered: false, } } -// RequireToken is wrapper around http.HandleFunc that checks token validity -func (api *API) RequireToken(fn http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - want := fmt.Sprintf("Bearer %s", api.Token) - if auth != want { - w.WriteHeader(http.StatusUnauthorized) - return - } - log.Debug("Valid token found.") - fn(w, r) - } -} - -// RegisterFunc is a wrapper around http.HandleFunc that also sets the flag used to determine whether to launch the API -func (api *API) RegisterFunc(path string, fn http.HandlerFunc) { - api.hasHandlers = true - http.HandleFunc(path, api.RequireToken(fn)) +func (api *API) route(route string) methodHandlers { + return api.router.route(route) } -// RegisterHandler is a wrapper around http.Handler that also sets the flag used to determine whether to launch the API -func (api *API) RegisterHandler(path string, handler http.Handler) { - api.hasHandlers = true - http.Handle(path, api.RequireToken(handler.ServeHTTP)) +func (api *API) registerHandlers() { + if api.registered { + return + } + for path, route := range api.router { + if len(route) < 1 { + continue + } + api.hasHandlers = true + api.mux.Handle(path, api.authMiddleware(route.Handler)) + } + api.registered = true + return } // Start the API and serve over HTTP. Requires an API Token to be set. -func (api *API) Start(block bool) error { +func (api *API) Start() error { + + api.registerHandlers() if !api.hasHandlers { log.Debug("Watchtower HTTP API skipped.") @@ -61,16 +73,49 @@ func (api *API) Start(block bool) error { log.Fatal(tokenMissingMsg) } - if block { - runHTTPServer() - } else { + api.running.Lock() + go func() { + defer api.running.Unlock() + api.server = &http.Server{ + Addr: ":8080", + Handler: api.mux, + } + + if err := api.server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + log.Errorf("HTTP Server error: %v", err) + } + }() + + return nil +} + +// Stop tells the api server to shut down (if its running) and returns a sync.Mutex that is locked +// until the server has handled all remaining requests and shut down +func (api *API) Stop() *sync.Mutex { + + if api.server != nil { go func() { - runHTTPServer() + if err := api.server.Shutdown(context.Background()); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Errorf("Error stopping HTTP Server: %v", err) + } }() } - return nil + + return api.running +} + +// Handler is used to get a http.Handler for testing +func (api *API) Handler() http.Handler { + api.registerHandlers() + return api.mux +} + +// EnableUpdates registers the `updates` endpoints +func (api *API) EnableUpdates(f updates.InvokedFunc, updateLock *sync.Mutex) { + api.route("/v1/updates").post(updates.PostV1(f, updateLock)) } -func runHTTPServer() { - log.Fatal(http.ListenAndServe(":8080", nil)) +// EnableMetrics registers the `metrics` endpoints +func (api *API) EnableMetrics() { + api.route("/v1/metrics").get(metrics.GetV1()) } diff --git a/pkg/api/metrics/metrics.go b/pkg/api/metrics/metrics.go index 4faad4a..a36ecde 100644 --- a/pkg/api/metrics/metrics.go +++ b/pkg/api/metrics/metrics.go @@ -1,27 +1,14 @@ package metrics import ( + . "github.com/containrrr/watchtower/pkg/api/prelude" "github.com/containrrr/watchtower/pkg/metrics" - "net/http" - "github.com/prometheus/client_golang/prometheus/promhttp" ) -// Handler is an HTTP handle for serving metric data -type Handler struct { - Path string - Handle http.HandlerFunc - Metrics *metrics.Metrics -} - -// New is a factory function creating a new Metrics instance -func New() *Handler { - m := metrics.Default() - handler := promhttp.Handler() - - return &Handler{ - Path: "/v1/metrics", - Handle: handler.ServeHTTP, - Metrics: m, - } +// GetV1 creates a new metrics http handler +func GetV1() HandlerFunc { + // Initialize watchtower metrics + metrics.Init() + return WrapHandler(promhttp.Handler().ServeHTTP) } diff --git a/pkg/api/metrics/metrics_test.go b/pkg/api/metrics/metrics_test.go index 48b6dd7..1372d2b 100644 --- a/pkg/api/metrics/metrics_test.go +++ b/pkg/api/metrics/metrics_test.go @@ -12,7 +12,6 @@ import ( . "github.com/onsi/gomega" "github.com/containrrr/watchtower/pkg/api" - metricsAPI "github.com/containrrr/watchtower/pkg/api/metrics" "github.com/containrrr/watchtower/pkg/metrics" ) @@ -51,10 +50,9 @@ func getWithToken(handler http.Handler) map[string]string { var _ = Describe("the metrics API", func() { httpAPI := api.New(token) - m := metricsAPI.New() + httpAPI.EnableMetrics() - handleReq := httpAPI.RequireToken(m.Handle) - tryGetMetrics := func() map[string]string { return getWithToken(handleReq) } + tryGetMetrics := func() map[string]string { return getWithToken(httpAPI.Handler()) } It("should serve metrics", func() { diff --git a/pkg/api/api_test.go b/pkg/api/middleware/middleware_test.go similarity index 57% rename from pkg/api/api_test.go rename to pkg/api/middleware/middleware_test.go index 4e9110b..f863a1c 100644 --- a/pkg/api/api_test.go +++ b/pkg/api/middleware/middleware_test.go @@ -1,7 +1,7 @@ -package api +package middleware import ( - "io" + "github.com/containrrr/watchtower/pkg/api/prelude" "net/http" "net/http/httptest" "testing" @@ -11,55 +11,58 @@ import ( ) const ( - token = "123123123" + token = "123123123" ) func TestAPI(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "API Suite") + RunSpecs(t, "Middleware Suite") } var _ = Describe("API", func() { - api := New(token) + requireToken := RequireToken(token) Describe("RequireToken middleware", func() { It("should return 401 Unauthorized when token is not provided", func() { - handlerFunc := api.RequireToken(testHandler) - rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/hello", nil) - handlerFunc(rec, req) + requireToken(testHandler).ServeHTTP(rec, req) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + Expect(rec.Body).To(MatchJSON(`{ + "code": "MISSING_TOKEN", + "error": "No authentication token was supplied" + }`)) }) It("should return 401 Unauthorized when token is invalid", func() { - handlerFunc := api.RequireToken(testHandler) - rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/hello", nil) req.Header.Set("Authorization", "Bearer 123") - handlerFunc(rec, req) + requireToken(testHandler).ServeHTTP(rec, req) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + Expect(rec.Body).To(MatchJSON(`{ + "code": "INVALID_TOKEN", + "error": "The supplied token does not match the configured auth token" + }`)) }) It("should return 200 OK when token is valid", func() { - handlerFunc := api.RequireToken(testHandler) rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/hello", nil) - req.Header.Set("Authorization", "Bearer " + token) + req.Header.Set("Authorization", "Bearer "+token) - handlerFunc(rec, req) + requireToken(testHandler).ServeHTTP(rec, req) Expect(rec.Code).To(Equal(http.StatusOK)) }) }) }) -func testHandler(w http.ResponseWriter, req *http.Request) { - _, _ = io.WriteString(w, "Hello!") +func testHandler(_ *prelude.Context) prelude.Response { + return prelude.OK("Hello!") } diff --git a/pkg/api/middleware/require_token.go b/pkg/api/middleware/require_token.go new file mode 100644 index 0000000..bd24e66 --- /dev/null +++ b/pkg/api/middleware/require_token.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "fmt" + . "github.com/containrrr/watchtower/pkg/api/prelude" +) + +// RequireToken returns a prelude.Middleware that checks token validity +func RequireToken(token string) Middleware { + return func(next HandlerFunc) HandlerFunc { + want := fmt.Sprintf("Bearer %s", token) + return func(c *Context) Response { + auth := c.Request.Header.Get("Authorization") + if auth == "" { + return Error(ErrMissingToken) + } + + if auth != want { + return Error(ErrInvalidToken) + } + return next(c) + } + } +} diff --git a/pkg/api/prelude/context.go b/pkg/api/prelude/context.go new file mode 100644 index 0000000..077530d --- /dev/null +++ b/pkg/api/prelude/context.go @@ -0,0 +1,62 @@ +package prelude + +import ( + "bytes" + "fmt" + "github.com/sirupsen/logrus" + "net/http" +) + +type Context struct { + Request *http.Request + Log *logrus.Entry + writer http.ResponseWriter +} + +func newContext(w http.ResponseWriter, req *http.Request) *Context { + reqLog := localLog.WithField("endpoint", fmt.Sprintf("%v %v", req.Method, req.URL.Path)) + return &Context{ + Log: reqLog, + Request: req, + writer: w, + } +} + +func (c *Context) Headers() http.Header { + return c.writer.Header() +} + +type contextWrapper struct { + context *Context + body bytes.Buffer + statusCode int +} + +func (cw *contextWrapper) Header() http.Header { + return cw.context.writer.Header() +} + +func (cw *contextWrapper) Write(bytes []byte) (int, error) { + return cw.body.Write(bytes) +} + +func (cw *contextWrapper) WriteHeader(statusCode int) { + cw.statusCode = statusCode +} + +func WrapHandler(next http.HandlerFunc) HandlerFunc { + return func(c *Context) Response { + wrapper := contextWrapper{ + context: c, + body: bytes.Buffer{}, + } + + next(&wrapper, c.Request) + + return Response{ + Status: wrapper.statusCode, + Body: wrapper.body.Bytes(), + Raw: true, + } + } +} diff --git a/pkg/api/prelude/errors.go b/pkg/api/prelude/errors.go new file mode 100644 index 0000000..a30ddce --- /dev/null +++ b/pkg/api/prelude/errors.go @@ -0,0 +1,36 @@ +package prelude + +import "net/http" + +type errorResponse struct { + Error string `json:"error"` + Code ErrorCode `json:"code"` + Status int `json:"-"` +} + +const internalErrorPayload string = `{ "error": "API internal error, check logs", "code": "API_INTERNAL_ERROR" }` + +type ErrorCode string + +var ( + ErrUpdateRunning = errorResponse{ + Code: "UPDATE_RUNNING", + Error: "Update already running", + Status: http.StatusConflict, + } + ErrNotFound = errorResponse{ + Code: "NOT_FOUND", + Error: "Endpoint is not registered to a handler", + Status: http.StatusNotFound, + } + ErrInvalidToken = errorResponse{ + Code: "INVALID_TOKEN", + Error: "The supplied token does not match the configured auth token", + Status: http.StatusUnauthorized, + } + ErrMissingToken = errorResponse{ + Code: "MISSING_TOKEN", + Error: "No authentication token was supplied", + Status: http.StatusUnauthorized, + } +) diff --git a/pkg/api/prelude/handler_func.go b/pkg/api/prelude/handler_func.go new file mode 100644 index 0000000..272a800 --- /dev/null +++ b/pkg/api/prelude/handler_func.go @@ -0,0 +1,39 @@ +package prelude + +import ( + log "github.com/sirupsen/logrus" + "net/http" +) + +type HandlerFunc func(c *Context) Response + +func (hf HandlerFunc) ServeHTTP(w http.ResponseWriter, req *http.Request) { + + w.Header().Set("Content-Type", DefaultContentType) + context := newContext(w, req) + + reqLog := context.Log.WithFields(log.Fields{ + "query": req.URL.RawQuery, + }) + reqLog.Trace("Received API Request") + + res := hf(context) + + status := res.Status + + bytes, err := res.Bytes() + if err != nil { + context.Log.WithError(err).Errorf("Failed to create JSON payload for response") + bytes = []byte(internalErrorPayload) + status = http.StatusInternalServerError + // Reset the content-type in case the handler changed it + w.Header().Set("Content-Type", DefaultContentType) + } + + reqLog.WithField("status", status).Trace("Handled API Request") + + w.WriteHeader(status) + if _, err = w.Write(bytes); err != nil { + localLog.Errorf("Failed to write HTTP response: %v", err) + } +} diff --git a/pkg/api/prelude/response.go b/pkg/api/prelude/response.go new file mode 100644 index 0000000..ebd239f --- /dev/null +++ b/pkg/api/prelude/response.go @@ -0,0 +1,41 @@ +package prelude + +import ( + "encoding/json" + log "github.com/sirupsen/logrus" + "net/http" +) + +type Response struct { + Body any + Status int + Raw bool +} + +func (r *Response) Bytes() ([]byte, error) { + if bytes, raw := r.Body.([]byte); raw { + return bytes, nil + } + + if str, raw := r.Body.(string); raw { + return []byte(str), nil + } + + return json.MarshalIndent(r.Body, "", " ") +} + +var localLog = log.WithField("notify", "no") + +func OK(body any) Response { + return Response{ + Status: http.StatusOK, + Body: body, + } +} + +func Error(err errorResponse) Response { + return Response{ + Status: err.Status, + Body: err, + } +} diff --git a/pkg/api/prelude/types.go b/pkg/api/prelude/types.go new file mode 100644 index 0000000..471af18 --- /dev/null +++ b/pkg/api/prelude/types.go @@ -0,0 +1,5 @@ +package prelude + +type Middleware func(next HandlerFunc) HandlerFunc + +const DefaultContentType = "application/json" diff --git a/pkg/api/router.go b/pkg/api/router.go new file mode 100644 index 0000000..fb99087 --- /dev/null +++ b/pkg/api/router.go @@ -0,0 +1,34 @@ +package api + +import ( + . "github.com/containrrr/watchtower/pkg/api/prelude" + "net/http" +) + +type router map[string]methodHandlers + +type methodHandlers map[string]HandlerFunc + +func (mh methodHandlers) Handler(c *Context) Response { + handler, found := mh[c.Request.Method] + if !found { + return Error(ErrNotFound) + } + return handler(c) +} + +func (mh methodHandlers) post(handlerFunc HandlerFunc) { + mh[http.MethodPost] = handlerFunc +} +func (mh methodHandlers) get(handlerFunc HandlerFunc) { + mh[http.MethodGet] = handlerFunc +} + +func (r router) route(route string) methodHandlers { + routeMethods, found := r[route] + if !found { + routeMethods = methodHandlers{} + r[route] = routeMethods + } + return routeMethods +} diff --git a/pkg/api/update/update.go b/pkg/api/update/update.go deleted file mode 100644 index ba044ab..0000000 --- a/pkg/api/update/update.go +++ /dev/null @@ -1,72 +0,0 @@ -package update - -import ( - "io" - "net/http" - "os" - "strings" - - log "github.com/sirupsen/logrus" -) - -var ( - lock chan bool -) - -// New is a factory function creating a new Handler instance -func New(updateFn func(images []string), updateLock chan bool) *Handler { - if updateLock != nil { - lock = updateLock - } else { - lock = make(chan bool, 1) - lock <- true - } - - return &Handler{ - fn: updateFn, - Path: "/v1/update", - } -} - -// Handler is an API handler used for triggering container update scans -type Handler struct { - fn func(images []string) - Path string -} - -// Handle is the actual http.Handle function doing all the heavy lifting -func (handle *Handler) Handle(w http.ResponseWriter, r *http.Request) { - log.Info("Updates triggered by HTTP API request.") - - _, err := io.Copy(os.Stdout, r.Body) - if err != nil { - log.Println(err) - return - } - - var images []string - imageQueries, found := r.URL.Query()["image"] - if found { - for _, image := range imageQueries { - images = append(images, strings.Split(image, ",")...) - } - - } else { - images = nil - } - - if len(images) > 0 { - chanValue := <-lock - defer func() { lock <- chanValue }() - handle.fn(images) - } else { - select { - case chanValue := <-lock: - defer func() { lock <- chanValue }() - handle.fn(images) - default: - log.Debug("Skipped. Another update already running.") - } - } - -} diff --git a/pkg/api/updates/updates.go b/pkg/api/updates/updates.go new file mode 100644 index 0000000..b5fb72e --- /dev/null +++ b/pkg/api/updates/updates.go @@ -0,0 +1,22 @@ +package updates + +import ( + "github.com/containrrr/watchtower/pkg/types" + "net/url" + "strings" +) + +type ModifyParamsFunc func(up *types.UpdateParams) +type InvokedFunc func(ModifyParamsFunc) types.Report + +func parseImages(u *url.URL) []string { + var images []string + imageQueries, found := u.Query()["image"] + if found { + for _, image := range imageQueries { + images = append(images, strings.Split(image, ",")...) + } + + } + return images +} diff --git a/pkg/api/updates/updates_v1.go b/pkg/api/updates/updates_v1.go new file mode 100644 index 0000000..32a8e13 --- /dev/null +++ b/pkg/api/updates/updates_v1.go @@ -0,0 +1,37 @@ +package updates + +import ( + . "github.com/containrrr/watchtower/pkg/api/prelude" + "github.com/containrrr/watchtower/pkg/filters" + "github.com/containrrr/watchtower/pkg/types" + "sync" + + log "github.com/sirupsen/logrus" +) + +// PostV1 creates an API http.HandlerFunc for V1 of updates +func PostV1(updateFn InvokedFunc, updateLock *sync.Mutex) HandlerFunc { + return func(c *Context) Response { + log.Info("Updates triggered by HTTP API request.") + + images := parseImages(c.Request.URL) + + if !updateLock.TryLock() { + if len(images) > 0 { + // If images have been passed, wait until the current updates are done + updateLock.Lock() + } else { + // If a full update is running (no explicit image filter), skip this update + log.Debug("Skipped. Another updates already running.") + return OK(nil) // For backwards compatibility + } + } + + defer updateLock.Unlock() + _ = updateFn(func(up *types.UpdateParams) { + up.Filter = filters.FilterByImage(images, up.Filter) + }) + + return OK(nil) + } +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index b681733..a6f24ae 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -45,12 +45,11 @@ func (metrics *Metrics) Register(metric *Metric) { metrics.channel <- metric } -// Default creates a new metrics handler if none exists, otherwise returns the existing one -func Default() *Metrics { +// Init creates a new metrics handler if none exists +func Init() { if metrics != nil { - return metrics + return } - metrics = &Metrics{ scanned: promauto.NewGauge(prometheus.GaugeOpts{ Name: "watchtower_containers_scanned", @@ -76,6 +75,11 @@ func Default() *Metrics { } go metrics.HandleUpdate(metrics.channel) +} + +// Default creates a new metrics handler if none exists, otherwise returns the existing one +func Default() *Metrics { + Init() return metrics }