Support TLS connections to remote daemons

pull/1/head
Brian DeHamer 9 years ago
parent b7424e5c47
commit e06c46552a

@ -1,6 +1,7 @@
package container package container
import ( import (
"crypto/tls"
"time" "time"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
@ -21,8 +22,8 @@ type Client interface {
IsContainerStale(Container) (bool, error) IsContainerStale(Container) (bool, error)
} }
func NewClient(dockerHost string, pullImages bool) Client { func NewClient(dockerHost string, tlsConfig *tls.Config, pullImages bool) Client {
docker, err := dockerclient.NewDockerClient(dockerHost, nil) docker, err := dockerclient.NewDockerClient(dockerHost, tlsConfig)
if err != nil { if err != nil {
log.Fatalf("Error instantiating Docker client: %s", err) log.Fatalf("Error instantiating Docker client: %s", err)

@ -1,9 +1,13 @@
package main // import "github.com/CenturyLinkLabs/watchtower" package main // import "github.com/CenturyLinkLabs/watchtower"
import ( import (
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"os/signal" "os/signal"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -15,10 +19,22 @@ import (
) )
var ( var (
wg sync.WaitGroup wg sync.WaitGroup
client container.Client
pollInterval time.Duration
) )
func init() {
log.SetLevel(log.InfoLevel)
}
func main() { func main() {
rootCertPath := "/etc/ssl/docker"
if os.Getenv("DOCKER_CERT_PATH") != "" {
rootCertPath = os.Getenv("DOCKER_CERT_PATH")
}
app := cli.NewApp() app := cli.NewApp()
app.Name = "watchtower" app.Name = "watchtower"
app.Usage = "Automatically update running Docker containers" app.Usage = "Automatically update running Docker containers"
@ -27,27 +43,87 @@ func main() {
app.Flags = []cli.Flag{ app.Flags = []cli.Flag{
cli.StringFlag{ cli.StringFlag{
Name: "host, H", Name: "host, H",
Usage: "daemon socket to connect to",
Value: "unix:///var/run/docker.sock", Value: "unix:///var/run/docker.sock",
Usage: "Docker daemon socket to connect to",
EnvVar: "DOCKER_HOST", EnvVar: "DOCKER_HOST",
}, },
cli.IntFlag{ cli.IntFlag{
Name: "interval, i", Name: "interval, i",
Value: 300,
Usage: "poll interval (in seconds)", Usage: "poll interval (in seconds)",
Value: 300,
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "no-pull", Name: "no-pull",
Usage: "do not pull new images", Usage: "do not pull new images",
}, },
cli.BoolFlag{
Name: "tls",
Usage: "use TLS; implied by --tlsverify",
},
cli.BoolFlag{
Name: "tlsverify",
Usage: "use TLS and verify the remote",
EnvVar: "DOCKER_TLS_VERIFY",
},
cli.StringFlag{
Name: "tlscacert",
Usage: "trust certs signed only by this CA",
Value: fmt.Sprintf("%s/ca.pem", rootCertPath),
},
cli.StringFlag{
Name: "tlscert",
Usage: "client certificate for TLS authentication",
Value: fmt.Sprintf("%s/cert.pem", rootCertPath),
},
cli.StringFlag{
Name: "tlskey",
Usage: "client key for TLS authentication",
Value: fmt.Sprintf("%s/key.pem", rootCertPath),
},
cli.BoolFlag{ cli.BoolFlag{
Name: "debug", Name: "debug",
Usage: "enable debug mode with verbose logging", Usage: "enable debug mode with verbose logging",
}, },
} }
if err := app.Run(os.Args); err != nil {
log.Fatal(err)
}
}
func before(c *cli.Context) error {
if c.GlobalBool("debug") {
log.SetLevel(log.DebugLevel)
}
pollInterval = time.Duration(c.Int("interval")) * time.Second
// Set-up container client
tls, err := tlsConfig(c)
if err != nil {
return err
}
client = container.NewClient(c.GlobalString("host"), tls, !c.GlobalBool("no-pull"))
handleSignals() handleSignals()
app.Run(os.Args) return nil
}
func start(*cli.Context) {
if err := actions.CheckPrereqs(client); err != nil {
log.Fatal(err)
}
for {
wg.Add(1)
if err := actions.Update(client); err != nil {
fmt.Println(err)
}
wg.Done()
time.Sleep(pollInterval)
}
} }
func handleSignals() { func handleSignals() {
@ -63,34 +139,56 @@ func handleSignals() {
}() }()
} }
func before(c *cli.Context) error { // tlsConfig translates the command-line options into a tls.Config struct
if c.GlobalBool("debug") { func tlsConfig(c *cli.Context) (*tls.Config, error) {
log.SetLevel(log.DebugLevel) var tlsConfig *tls.Config
} else { var err error
log.SetLevel(log.InfoLevel) caCertFlag := c.GlobalString("tlscacert")
} certFlag := c.GlobalString("tlscert")
keyFlag := c.GlobalString("tlskey")
client := newContainerClient(c) if c.GlobalBool("tls") || c.GlobalBool("tlsverify") {
return actions.CheckPrereqs(client) tlsConfig = &tls.Config{
} InsecureSkipVerify: !c.GlobalBool("tlsverify"),
}
func start(c *cli.Context) { // Load CA cert
client := newContainerClient(c) if caCertFlag != "" {
secs := time.Duration(c.Int("interval")) * time.Second var caCert []byte
for { if strings.HasPrefix(caCertFlag, "/") {
wg.Add(1) caCert, err = ioutil.ReadFile(caCertFlag)
if err := actions.Update(client); err != nil { if err != nil {
fmt.Println(err) return nil, err
}
} else {
caCert = []byte(caCertFlag)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.RootCAs = caCertPool
} }
wg.Done()
time.Sleep(secs) // Load client certificate
if certFlag != "" && keyFlag != "" {
var cert tls.Certificate
if strings.HasPrefix(certFlag, "/") && strings.HasPrefix(keyFlag, "/") {
cert, err = tls.LoadX509KeyPair(certFlag, keyFlag)
if err != nil {
return nil, err
}
} else {
cert, err = tls.X509KeyPair([]byte(certFlag), []byte(keyFlag))
if err != nil {
return nil, err
}
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
} }
}
func newContainerClient(c *cli.Context) container.Client { return tlsConfig, nil
dockerHost := c.GlobalString("host")
noPull := c.GlobalBool("no-pull")
return container.NewClient(dockerHost, !noPull)
} }

Loading…
Cancel
Save