diff --git a/container/client.go b/container/client.go index 7fdbf9a..a9722e3 100644 --- a/container/client.go +++ b/container/client.go @@ -1,6 +1,7 @@ package container import ( + "crypto/tls" "time" log "github.com/Sirupsen/logrus" @@ -21,8 +22,8 @@ type Client interface { IsContainerStale(Container) (bool, error) } -func NewClient(dockerHost string, pullImages bool) Client { - docker, err := dockerclient.NewDockerClient(dockerHost, nil) +func NewClient(dockerHost string, tlsConfig *tls.Config, pullImages bool) Client { + docker, err := dockerclient.NewDockerClient(dockerHost, tlsConfig) if err != nil { log.Fatalf("Error instantiating Docker client: %s", err) diff --git a/main.go b/main.go index c5d0800..6151133 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,13 @@ package main // import "github.com/CenturyLinkLabs/watchtower" import ( + "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -15,10 +19,22 @@ import ( ) var ( - wg sync.WaitGroup + wg sync.WaitGroup + client container.Client + pollInterval time.Duration ) +func init() { + log.SetLevel(log.InfoLevel) +} + func main() { + rootCertPath := "/etc/ssl/docker" + + if os.Getenv("DOCKER_CERT_PATH") != "" { + rootCertPath = os.Getenv("DOCKER_CERT_PATH") + } + app := cli.NewApp() app.Name = "watchtower" app.Usage = "Automatically update running Docker containers" @@ -27,27 +43,87 @@ func main() { app.Flags = []cli.Flag{ cli.StringFlag{ Name: "host, H", + Usage: "daemon socket to connect to", Value: "unix:///var/run/docker.sock", - Usage: "Docker daemon socket to connect to", EnvVar: "DOCKER_HOST", }, cli.IntFlag{ Name: "interval, i", - Value: 300, Usage: "poll interval (in seconds)", + Value: 300, }, cli.BoolFlag{ Name: "no-pull", Usage: "do not pull new images", }, + cli.BoolFlag{ + Name: "tls", + Usage: "use TLS; implied by --tlsverify", + }, + cli.BoolFlag{ + Name: "tlsverify", + Usage: "use TLS and verify the remote", + EnvVar: "DOCKER_TLS_VERIFY", + }, + cli.StringFlag{ + Name: "tlscacert", + Usage: "trust certs signed only by this CA", + Value: fmt.Sprintf("%s/ca.pem", rootCertPath), + }, + cli.StringFlag{ + Name: "tlscert", + Usage: "client certificate for TLS authentication", + Value: fmt.Sprintf("%s/cert.pem", rootCertPath), + }, + cli.StringFlag{ + Name: "tlskey", + Usage: "client key for TLS authentication", + Value: fmt.Sprintf("%s/key.pem", rootCertPath), + }, cli.BoolFlag{ Name: "debug", Usage: "enable debug mode with verbose logging", }, } + if err := app.Run(os.Args); err != nil { + log.Fatal(err) + } +} + +func before(c *cli.Context) error { + if c.GlobalBool("debug") { + log.SetLevel(log.DebugLevel) + } + + pollInterval = time.Duration(c.Int("interval")) * time.Second + + // Set-up container client + tls, err := tlsConfig(c) + if err != nil { + return err + } + + client = container.NewClient(c.GlobalString("host"), tls, !c.GlobalBool("no-pull")) + handleSignals() - app.Run(os.Args) + return nil +} + +func start(*cli.Context) { + if err := actions.CheckPrereqs(client); err != nil { + log.Fatal(err) + } + + for { + wg.Add(1) + if err := actions.Update(client); err != nil { + fmt.Println(err) + } + wg.Done() + + time.Sleep(pollInterval) + } } func handleSignals() { @@ -63,34 +139,56 @@ func handleSignals() { }() } -func before(c *cli.Context) error { - if c.GlobalBool("debug") { - log.SetLevel(log.DebugLevel) - } else { - log.SetLevel(log.InfoLevel) - } +// tlsConfig translates the command-line options into a tls.Config struct +func tlsConfig(c *cli.Context) (*tls.Config, error) { + var tlsConfig *tls.Config + var err error + caCertFlag := c.GlobalString("tlscacert") + certFlag := c.GlobalString("tlscert") + keyFlag := c.GlobalString("tlskey") - client := newContainerClient(c) - return actions.CheckPrereqs(client) -} + if c.GlobalBool("tls") || c.GlobalBool("tlsverify") { + tlsConfig = &tls.Config{ + InsecureSkipVerify: !c.GlobalBool("tlsverify"), + } -func start(c *cli.Context) { - client := newContainerClient(c) - secs := time.Duration(c.Int("interval")) * time.Second + // Load CA cert + if caCertFlag != "" { + var caCert []byte - for { - wg.Add(1) - if err := actions.Update(client); err != nil { - fmt.Println(err) + if strings.HasPrefix(caCertFlag, "/") { + caCert, err = ioutil.ReadFile(caCertFlag) + if err != nil { + return nil, err + } + } else { + caCert = []byte(caCertFlag) + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig.RootCAs = caCertPool } - wg.Done() - time.Sleep(secs) + // Load client certificate + if certFlag != "" && keyFlag != "" { + var cert tls.Certificate + + if strings.HasPrefix(certFlag, "/") && strings.HasPrefix(keyFlag, "/") { + cert, err = tls.LoadX509KeyPair(certFlag, keyFlag) + if err != nil { + return nil, err + } + } else { + cert, err = tls.X509KeyPair([]byte(certFlag), []byte(keyFlag)) + if err != nil { + return nil, err + } + } + tlsConfig.Certificates = []tls.Certificate{cert} + } } -} -func newContainerClient(c *cli.Context) container.Client { - dockerHost := c.GlobalString("host") - noPull := c.GlobalBool("no-pull") - return container.NewClient(dockerHost, !noPull) + return tlsConfig, nil }