diff --git a/cmd/pgproxy/README.md b/cmd/pgproxy/README.md new file mode 100644 index 000000000..e80c82663 --- /dev/null +++ b/cmd/pgproxy/README.md @@ -0,0 +1,40 @@ +# pgproxy + +The pgproxy server is a proxy for the Postgres wire protocol. + +The proxy runs an in-process Tailscale instance, accepts postgres +client connections over Tailscale only, and proxies them to the +configured upstream postgres server. + +This proxy exists because postgres clients default to very insecure +connection settings: either they "prefer" but do not require TLS; or +they set sslmode=require, which merely requires that a TLS handshake +took place, but don't verify the server's TLS certificate or the +presented TLS hostname. In other words, sslmode=require enforces that +a TLS session is created, but that session can trivially be +machine-in-the-middled to steal credentials, data, inject malicious +queries, and so forth. + +Because this flaw is in the client's validation of the TLS session, +you have no way of reliably detecting the misconfiguration +server-side. You could fix the configuration of all the clients you +know of, but the default makes it very easy to accidentally regress. + +Instead of trying to verify client configuration over time, this proxy +removes the need for postgres clients to be configured correctly: the +upstream database is configured to only accept connections from the +proxy, and the proxy is only available to clients over Tailscale. + +Therefore, clients must use the proxy to connect to the database. The +client<>proxy connection is secured end-to-end by Tailscale, which the +proxy enforces by verifying that the connecting client is a known +current Tailscale peer. The proxy<>server connection is established by +the proxy itself, using strict TLS verification settings, and the +client is only allowed to communicate with the server once we've +established that the upstream connection is safe to use. + +A couple side benefits: because clients can only connect via +Tailscale, you can use Tailscale ACLs as an extra layer of defense on +top of the postgres user/password authentication. And, the proxy can +maintain an audit log of who connected to the database, complete with +the strongly authenticated Tailscale identity of the client. diff --git a/cmd/pgproxy/pgproxy.go b/cmd/pgproxy/pgproxy.go new file mode 100644 index 000000000..e3fa25954 --- /dev/null +++ b/cmd/pgproxy/pgproxy.go @@ -0,0 +1,366 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The pgproxy server is a proxy for the Postgres wire protocol. +package main + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + crand "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "expvar" + "flag" + "fmt" + "io" + "log" + "math/big" + "net" + "net/http" + "os" + "strings" + "time" + + "tailscale.com/client/tailscale" + "tailscale.com/metrics" + "tailscale.com/tsnet" + "tailscale.com/tsweb" + "tailscale.com/types/logger" +) + +var ( + hostname = flag.String("hostname", "", "Tailscale hostname to serve on") + port = flag.Int("port", 5432, "Listening port for client connections") + debugPort = flag.Int("debug-port", 80, "Listening port for debug/metrics endpoint") + upstreamAddr = flag.String("upstream-addr", "", "Address of the upstream Postgres server, in host:port format") + upstreamCA = flag.String("upstream-ca-file", "", "File containing the PEM-encoded CA certificate for the upstream server") + tailscaleDir = flag.String("state-dir", "", "Directory in which to store the Tailscale auth state") +) + +func main() { + flag.Parse() + if *hostname == "" { + log.Fatal("missing --hostname") + } + if *upstreamAddr == "" { + log.Fatal("missing --upstream-addr") + } + if *upstreamCA == "" { + log.Fatal("missing --upstream-ca-file") + } + if *tailscaleDir == "" { + log.Fatal("missing --state-dir") + } + + ts := &tsnet.Server{ + Dir: *tailscaleDir, + Hostname: *hostname, + // Make the stdout logs a clean audit log of connections. + Logf: logger.Discard, + } + + if os.Getenv("TS_AUTHKEY") == "" { + log.Print("Note: you need to run this with TS_AUTHKEY=... the first time, to join your tailnet of choice.") + } + + tsclient, err := ts.LocalClient() + if err != nil { + log.Fatalf("getting tsnet API client: %v", err) + } + + p, err := newProxy(*upstreamAddr, *upstreamCA, tsclient) + if err != nil { + log.Fatal(err) + } + expvar.Publish("pgproxy", p.Expvar()) + + if *debugPort != 0 { + mux := http.NewServeMux() + tsweb.Debugger(mux) + srv := &http.Server{ + Handler: mux, + } + dln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort)) + if err != nil { + log.Fatal(err) + } + go func() { + log.Fatal(srv.Serve(dln)) + }() + } + + ln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *port)) + if err != nil { + log.Fatal(err) + } + log.Printf("serving access to %s on port %d", *upstreamAddr, *port) + log.Fatal(p.Serve(ln)) +} + +// proxy is a postgres wire protocol proxy, which strictly enforces +// the security of the TLS connection to its upstream regardless of +// what the client's TLS configuration is. +type proxy struct { + upstreamAddr string // "my.database.com:5432" + upstreamHost string // "my.database.com" + upstreamCertPool *x509.CertPool + downstreamCert []tls.Certificate + client *tailscale.LocalClient + + activeSessions expvar.Int + startedSessions expvar.Int + errors metrics.LabelMap +} + +// newProxy returns a proxy that forwards connections to +// upstreamAddr. The upstream's TLS session is verified using the CA +// cert(s) in upstreamCAPath. +func newProxy(upstreamAddr, upstreamCAPath string, client *tailscale.LocalClient) (*proxy, error) { + bs, err := os.ReadFile(upstreamCAPath) + if err != nil { + return nil, err + } + upstreamCertPool := x509.NewCertPool() + if !upstreamCertPool.AppendCertsFromPEM(bs) { + return nil, fmt.Errorf("invalid CA cert in %q", upstreamCAPath) + } + + h, _, err := net.SplitHostPort(upstreamAddr) + if err != nil { + return nil, err + } + downstreamCert, err := mkSelfSigned(h) + if err != nil { + return nil, err + } + + return &proxy{ + upstreamAddr: upstreamAddr, + upstreamHost: h, + upstreamCertPool: upstreamCertPool, + downstreamCert: []tls.Certificate{downstreamCert}, + client: client, + errors: metrics.LabelMap{Label: "kind"}, + }, nil +} + +// Expvar returns p's monitoring metrics. +func (p *proxy) Expvar() expvar.Var { + ret := &metrics.Set{} + ret.Set("sessions_active", &p.activeSessions) + ret.Set("sessions_started", &p.startedSessions) + ret.Set("session_errors", &p.errors) + return ret +} + +// Serve accepts postgres client connections on ln and proxies them to +// the configured upstream. ln can be any net.Listener, but all client +// connections must originate from tailscale IPs that can be verified +// with WhoIs. +func (p *proxy) Serve(ln net.Listener) error { + var lastSessionID int64 + for { + c, err := ln.Accept() + if err != nil { + return err + } + id := time.Now().UnixNano() + if id == lastSessionID { + // Bluntly enforce SID uniqueness, even if collisions are + // fantastically unlikely (but OSes vary in how much timer + // precision they expose to the OS, so id might be rounded + // e.g. to the same millisecond) + id++ + } + lastSessionID = id + go func(sessionID int64) { + if err := p.serve(sessionID, c); err != nil { + log.Printf("%d: session ended with error: %v", sessionID, err) + } + }(id) + } +} + +var ( + // sslStart is the magic bytes that postgres clients use to indicate + // that they want to do a TLS handshake. Servers should respond with + // the single byte "S" before starting a normal TLS handshake. + sslStart = [8]byte{0, 0, 0, 8, 0x04, 0xd2, 0x16, 0x2f} + // plaintextStart is the magic bytes that postgres clients use to + // indicate that they're starting a plaintext authentication + // handshake. + plaintextStart = [8]byte{0, 0, 0, 86, 0, 3, 0, 0} +) + +// serve proxies the postgres client on c to the proxy's upstream, +// enforcing strict TLS to the upstream. +func (p *proxy) serve(sessionID int64, c net.Conn) error { + defer c.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + whois, err := p.client.WhoIs(ctx, c.RemoteAddr().String()) + if err != nil { + p.errors.Add("whois-failed", 1) + return fmt.Errorf("getting client identity: %v", err) + } + + // Before anything else, log the connection attempt. + user, machine := "", "" + if whois.Node != nil { + if whois.Node.Hostinfo.ShareeNode() { + machine = "external-device" + } else { + machine = strings.TrimSuffix(whois.Node.Name, ".") + } + } + if whois.UserProfile != nil { + user = whois.UserProfile.LoginName + if user == "tagged-devices" && whois.Node != nil { + user = strings.Join(whois.Node.Tags, ",") + } + } + if user == "" || machine == "" { + p.errors.Add("no-ts-identity", 1) + return fmt.Errorf("couldn't identify source user and machine (user %q, machine %q)", user, machine) + } + log.Printf("%d: session start, from %s (machine %s, user %s)", sessionID, c.RemoteAddr(), machine, user) + start := time.Now() + defer func() { + elapsed := time.Since(start) + log.Printf("%d: session end, from %s (machine %s, user %s), lasted %s", sessionID, c.RemoteAddr(), machine, user, elapsed.Round(time.Millisecond)) + }() + + // Read the client's opening message, to figure out if it's trying + // to TLS or not. + var buf [8]byte + if _, err := io.ReadFull(c, buf[:len(sslStart)]); err != nil { + p.errors.Add("network-error", 1) + return fmt.Errorf("initial magic read: %v", err) + } + var clientIsTLS bool + switch { + case buf == sslStart: + clientIsTLS = true + case buf == plaintextStart: + clientIsTLS = false + default: + p.errors.Add("client-bad-protocol", 1) + return fmt.Errorf("unrecognized initial packet = % 02x", buf) + } + + // Dial & verify upstream connection. + var d net.Dialer + d.Timeout = 10 * time.Second + upc, err := d.Dial("tcp", p.upstreamAddr) + if err != nil { + p.errors.Add("network-error", 1) + return fmt.Errorf("upstream dial: %v", err) + } + defer upc.Close() + if _, err := upc.Write(sslStart[:]); err != nil { + p.errors.Add("network-error", 1) + return fmt.Errorf("upstream write of start-ssl magic: %v", err) + } + if _, err := io.ReadFull(upc, buf[:1]); err != nil { + p.errors.Add("network-error", 1) + return fmt.Errorf("reading upstream start-ssl response: %v", err) + } + if buf[0] != 'S' { + p.errors.Add("upstream-bad-protocol", 1) + return fmt.Errorf("upstream didn't acknowldge start-ssl, said %q", buf[0]) + } + tlsConf := &tls.Config{ + ServerName: p.upstreamHost, + RootCAs: p.upstreamCertPool, + MinVersion: tls.VersionTLS12, + } + uptc := tls.Client(upc, tlsConf) + if err = uptc.HandshakeContext(ctx); err != nil { + p.errors.Add("upstream-tls", 1) + return fmt.Errorf("upstream TLS handshake: %v", err) + } + + // Accept the client conn and set it up the way the client wants. + var clientConn net.Conn + if clientIsTLS { + io.WriteString(c, "S") // yeah, we're good to speak TLS + s := tls.Server(c, &tls.Config{ + ServerName: p.upstreamHost, + Certificates: p.downstreamCert, + MinVersion: tls.VersionTLS12, + }) + if err = uptc.HandshakeContext(ctx); err != nil { + p.errors.Add("client-tls", 1) + return fmt.Errorf("client TLS handshake: %v", err) + } + clientConn = s + } else { + // Repeat the header we read earlier up to the server. + if _, err := uptc.Write(plaintextStart[:]); err != nil { + p.errors.Add("network-error", 1) + return fmt.Errorf("sending initial client bytes to upstream: %v", err) + } + clientConn = c + } + + // Finally, proxy the client to the upstream. + errc := make(chan error, 1) + go func() { + _, err := io.Copy(uptc, clientConn) + errc <- err + }() + go func() { + _, err := io.Copy(clientConn, uptc) + errc <- err + }() + if err := <-errc; err != nil { + // Don't increment error counts here, because the most common + // cause of termination is client or server closing the + // connection normally, and it'll obscure "interesting" + // handshake errors. + return fmt.Errorf("session terminated with error: %v", err) + } + return nil +} + +// mkSelfSigned creates and returns a self-signed TLS certificate for +// hostname. +func mkSelfSigned(hostname string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader) + if err != nil { + return tls.Certificate{}, err + } + pub := priv.Public() + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"pgproxy"}, + }, + DNSNames: []string{hostname}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(crand.Reader, &template, &template, pub, priv) + if err != nil { + return tls.Certificate{}, err + } + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + return tls.Certificate{}, err + } + + return tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + Leaf: cert, + }, nil +}