diff --git a/cmd/root.go b/cmd/root.go index d4839e6..6546933 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -138,6 +138,7 @@ func Run(c *cobra.Command, names []string) { log.Info("Running a one time update.") } runUpdatesWithNotifications(filter) + notifier.Close() os.Exit(0) return } diff --git a/pkg/notifications/email.go b/pkg/notifications/email.go index c4ee56b..bba2b82 100644 --- a/pkg/notifications/email.go +++ b/pkg/notifications/email.go @@ -153,3 +153,5 @@ func (e *emailTypeNotifier) Fire(entry *log.Entry) error { } return nil } + +func (e *emailTypeNotifier) Close() {} \ No newline at end of file diff --git a/pkg/notifications/gotify.go b/pkg/notifications/gotify.go index a065ac0..789f778 100644 --- a/pkg/notifications/gotify.go +++ b/pkg/notifications/gotify.go @@ -59,6 +59,8 @@ func (n *gotifyTypeNotifier) StartNotification() {} func (n *gotifyTypeNotifier) SendNotification() {} +func (n *gotifyTypeNotifier) Close() {} + func (n *gotifyTypeNotifier) Levels() []log.Level { return n.logLevels } diff --git a/pkg/notifications/msteams.go b/pkg/notifications/msteams.go index b356814..ab33966 100644 --- a/pkg/notifications/msteams.go +++ b/pkg/notifications/msteams.go @@ -47,6 +47,8 @@ func (n *msTeamsTypeNotifier) StartNotification() {} func (n *msTeamsTypeNotifier) SendNotification() {} +func (n *msTeamsTypeNotifier) Close() {} + func (n *msTeamsTypeNotifier) Levels() []log.Level { return n.levels } diff --git a/pkg/notifications/notifier.go b/pkg/notifications/notifier.go index 6595b22..dedb21a 100644 --- a/pkg/notifications/notifier.go +++ b/pkg/notifications/notifier.go @@ -66,3 +66,10 @@ func (n *Notifier) SendNotification() { t.SendNotification() } } + +// Close closes all notifiers. +func (n *Notifier) Close() { + for _, t := range n.types { + t.Close() + } +} diff --git a/pkg/notifications/shoutrrr.go b/pkg/notifications/shoutrrr.go index 9a7cd62..cd6359b 100644 --- a/pkg/notifications/shoutrrr.go +++ b/pkg/notifications/shoutrrr.go @@ -3,11 +3,11 @@ package notifications import ( "bytes" "fmt" + "github.com/containrrr/shoutrrr/pkg/types" "text/template" "strings" "github.com/containrrr/shoutrrr" - "github.com/containrrr/shoutrrr/pkg/router" t "github.com/containrrr/watchtower/pkg/types" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -18,13 +18,19 @@ const ( shoutrrrType = "shoutrrr" ) +type router interface { + Send(message string, params *types.Params) []error +} + // Implements Notifier, logrus.Hook type shoutrrrTypeNotifier struct { Urls []string - Router *router.ServiceRouter + Router router entries []*log.Entry logLevels []log.Level template *template.Template + messages chan string + done chan bool } func newShoutrrrNotifier(c *cobra.Command, acceptedLogLevels []log.Level) t.Notifier { @@ -41,13 +47,33 @@ func newShoutrrrNotifier(c *cobra.Command, acceptedLogLevels []log.Level) t.Noti Router: r, logLevels: acceptedLogLevels, template: getShoutrrrTemplate(c), + messages: make(chan string, 1), + done: make(chan bool), } log.AddHook(n) + // Do the sending in a separate goroutine so we don't block the main process. + go sendNotifications(n) + return n } +func sendNotifications(n *shoutrrrTypeNotifier) { + for msg := range n.messages { + errs := n.Router.Send(msg, nil) + + for i, err := range errs { + if err != nil { + // Use fmt so it doesn't trigger another notification. + fmt.Println("Failed to send notification via shoutrrr (url="+n.Urls[i]+"): ", err) + } + } + } + + n.done <- true +} + func (e *shoutrrrTypeNotifier) buildMessage(entries []*log.Entry) string { var body bytes.Buffer if err := e.template.Execute(&body, entries); err != nil { @@ -58,20 +84,8 @@ func (e *shoutrrrTypeNotifier) buildMessage(entries []*log.Entry) string { } func (e *shoutrrrTypeNotifier) sendEntries(entries []*log.Entry) { - msg := e.buildMessage(entries) - - // Do the sending in a separate goroutine so we don't block the main process. - go func() { - errs := e.Router.Send(msg, nil) - - for i, err := range errs { - if err != nil { - // Use fmt so it doesn't trigger another notification. - fmt.Println("Failed to send notification via shoutrrr (url="+e.Urls[i]+"): ", err) - } - } - }() + e.messages <- msg } func (e *shoutrrrTypeNotifier) StartNotification() { @@ -89,6 +103,15 @@ func (e *shoutrrrTypeNotifier) SendNotification() { e.entries = nil } +func (e *shoutrrrTypeNotifier) Close() { + close(e.messages) + + // Use fmt so it doesn't trigger another notification. + fmt.Println("Waiting for the notification goroutine to finish") + + _ = <-e.done +} + func (e *shoutrrrTypeNotifier) Levels() []log.Level { return e.logLevels } diff --git a/pkg/notifications/shoutrrr_test.go b/pkg/notifications/shoutrrr_test.go index 5db7473..8bba2d3 100644 --- a/pkg/notifications/shoutrrr_test.go +++ b/pkg/notifications/shoutrrr_test.go @@ -1,6 +1,7 @@ package notifications import ( + "github.com/containrrr/shoutrrr/pkg/types" "testing" "text/template" @@ -102,3 +103,69 @@ func TestShoutrrrInvalidTemplateUsesTemplate(t *testing.T) { require.Equal(t, sd, s) } + +type blockingRouter struct { + unlock chan bool + sent chan bool +} + +func (b blockingRouter) Send(message string, params *types.Params) []error { + _ = <-b.unlock + b.sent <- true + return nil +} + +func TestSlowNotificationNotSent(t *testing.T) { + _, blockingRouter := sendNotificationsWithBlockingRouter() + + notifSent := false + select { + case notifSent = <-blockingRouter.sent: + default: + } + + require.Equal(t, false, notifSent) +} + +func TestSlowNotificationSent(t *testing.T) { + shoutrrr, blockingRouter := sendNotificationsWithBlockingRouter() + + blockingRouter.unlock <- true + shoutrrr.Close() + + notifSent := false + select { + case notifSent = <-blockingRouter.sent: + default: + } + require.Equal(t, true, notifSent) +} + +func sendNotificationsWithBlockingRouter() (*shoutrrrTypeNotifier, *blockingRouter) { + cmd := new(cobra.Command) + + router := &blockingRouter{ + unlock: make(chan bool, 1), + sent: make(chan bool, 1), + } + + shoutrrr := &shoutrrrTypeNotifier{ + template: getShoutrrrTemplate(cmd), + messages: make(chan string, 1), + done: make(chan bool), + Router: router, + } + + entry := &log.Entry{ + Message: "foo bar", + } + + go sendNotifications(shoutrrr) + + shoutrrr.StartNotification() + shoutrrr.Fire(entry) + + shoutrrr.SendNotification() + + return shoutrrr, router +} diff --git a/pkg/notifications/slack.go b/pkg/notifications/slack.go index 42b7915..5f96390 100644 --- a/pkg/notifications/slack.go +++ b/pkg/notifications/slack.go @@ -42,3 +42,5 @@ func newSlackNotifier(c *cobra.Command, acceptedLogLevels []log.Level) t.Notifie func (s *slackTypeNotifier) StartNotification() {} func (s *slackTypeNotifier) SendNotification() {} + +func (s *slackTypeNotifier) Close() {} diff --git a/pkg/types/notifier.go b/pkg/types/notifier.go index c8d07d0..27dc483 100644 --- a/pkg/types/notifier.go +++ b/pkg/types/notifier.go @@ -4,4 +4,5 @@ package types type Notifier interface { StartNotification() SendNotification() + Close() }