@ -3,15 +3,20 @@
// The sniproxy is an outbound SNI proxy. It receives TLS connections over
// The sniproxy is an outbound SNI proxy. It receives TLS connections over
// Tailscale on one or more TCP ports and sends them out to the same SNI
// Tailscale on one or more TCP ports and sends them out to the same SNI
// hostname & port on the internet. It only does TCP.
// hostname & port on the internet. It can optionally forward one or more
// TCP ports to a specific destination. It only does TCP.
package main
package main
import (
import (
"context"
"context"
"errors"
"expvar"
"flag"
"flag"
"fmt"
"log"
"log"
"net"
"net"
"net/http"
"net/http"
"strconv"
"strings"
"strings"
"time"
"time"
@ -19,27 +24,54 @@ import (
"inet.af/tcpproxy"
"inet.af/tcpproxy"
"tailscale.com/client/tailscale"
"tailscale.com/client/tailscale"
"tailscale.com/hostinfo"
"tailscale.com/hostinfo"
"tailscale.com/metrics"
"tailscale.com/net/netutil"
"tailscale.com/net/netutil"
"tailscale.com/tsnet"
"tailscale.com/tsnet"
"tailscale.com/tsweb"
"tailscale.com/types/nettype"
"tailscale.com/types/nettype"
"tailscale.com/util/clientmetric"
"tailscale.com/util/clientmetric"
)
)
var (
var (
ports = flag . String ( "ports" , "443" , "comma-separated list of ports to proxy" )
ports = flag . String ( "ports" , "443" , "comma-separated list of ports to proxy" )
forwards = flag . String ( "forwards" , "" , "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com" )
wgPort = flag . Int ( "wg-listen-port" , 0 , "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select" )
wgPort = flag . Int ( "wg-listen-port" , 0 , "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select" )
promoteHTTPS = flag . Bool ( "promote-https" , true , "promote HTTP to HTTPS" )
promoteHTTPS = flag . Bool ( "promote-https" , true , "promote HTTP to HTTPS" )
debugPort = flag . Int ( "debug-port" , 8080 , "Listening port for debug/metrics endpoint" )
)
)
var tsMBox = dnsmessage . MustNewName ( "support.tailscale.com." )
var tsMBox = dnsmessage . MustNewName ( "support.tailscale.com." )
var (
// portForward is the state for a single port forwarding entry, as passed to the --forward flag.
numSessions = clientmetric . NewCounter ( "sniproxy_sessions" )
type portForward struct {
numBadAddrPort = clientmetric . NewCounter ( "sniproxy_bad_addrport" )
Port int
dnsResponses = clientmetric . NewCounter ( "sniproxy_dns_responses" )
Proto string
dnsFailures = clientmetric . NewCounter ( "sniproxy_dns_failed" )
Destination string
httpPromoted = clientmetric . NewCounter ( "sniproxy_http_promoted" )
}
)
// parseForward takes a proto/port/destination tuple as an input, as would be passed
// to the --forward command line flag, and returns a *portForward struct of those parameters.
func parseForward ( value string ) ( * portForward , error ) {
parts := strings . Split ( value , "/" )
if len ( parts ) != 3 {
return nil , errors . New ( "cannot parse: " + value )
}
proto := parts [ 0 ]
if proto != "tcp" {
return nil , errors . New ( "unsupported forwarding protocol: " + proto )
}
port , err := strconv . ParseUint ( parts [ 1 ] , 10 , 16 )
if err != nil {
return nil , errors . New ( "bad forwarding port: " + parts [ 1 ] )
}
host := parts [ 2 ]
if host == "" {
return nil , errors . New ( "bad destination: " + value )
}
return & portForward { Port : int ( port ) , Proto : proto , Destination : host } , nil
}
func main ( ) {
func main ( ) {
flag . Parse ( )
flag . Parse ( )
@ -58,6 +90,7 @@ func main() {
log . Fatal ( err )
log . Fatal ( err )
}
}
s . lc = lc
s . lc = lc
s . initMetrics ( )
for _ , portStr := range strings . Split ( * ports , "," ) {
for _ , portStr := range strings . Split ( * ports , "," ) {
ln , err := s . ts . Listen ( "tcp" , ":" + portStr )
ln , err := s . ts . Listen ( "tcp" , ":" + portStr )
@ -68,6 +101,34 @@ func main() {
go s . serve ( ln )
go s . serve ( ln )
}
}
for _ , forwStr := range strings . Split ( * forwards , "," ) {
if forwStr == "" {
continue
}
forw , err := parseForward ( forwStr )
if err != nil {
log . Fatal ( err )
}
ln , err := s . ts . Listen ( "tcp" , ":" + strconv . Itoa ( forw . Port ) )
if err != nil {
log . Fatal ( err )
}
log . Printf ( "Serving on port %d to %s..." , forw . Port , forw . Destination )
// Add an entry to the expvar LabelMap for Prometheus metrics,
// and create a clientmetric to report that same value.
service := portNumberToName ( forw )
s . numTCPsessions . SetInt64 ( service , 0 )
metric := fmt . Sprintf ( "sniproxy_tcp_sessions_%s" , service )
clientmetric . NewCounterFunc ( metric , func ( ) int64 {
return s . numTCPsessions . Get ( service ) . Value ( )
} )
go s . forward ( ln , forw )
}
ln , err := s . ts . Listen ( "udp" , ":53" )
ln , err := s . ts . Listen ( "udp" , ":53" )
if err != nil {
if err != nil {
log . Fatal ( err )
log . Fatal ( err )
@ -83,12 +144,31 @@ func main() {
go s . promoteHTTPS ( ln )
go s . promoteHTTPS ( ln )
}
}
if * debugPort != 0 {
mux := http . NewServeMux ( )
tsweb . Debugger ( mux )
dln , err := s . ts . Listen ( "tcp" , fmt . Sprintf ( ":%d" , * debugPort ) )
if err != nil {
log . Fatal ( err )
}
go func ( ) {
log . Fatal ( http . Serve ( dln , mux ) )
} ( )
}
select { }
select { }
}
}
type server struct {
type server struct {
ts tsnet . Server
ts tsnet . Server
lc * tailscale . LocalClient
lc * tailscale . LocalClient
numTLSsessions expvar . Int
numTCPsessions * metrics . LabelMap
numBadAddrPort expvar . Int
dnsResponses expvar . Int
dnsFailures expvar . Int
httpPromoted expvar . Int
}
}
func ( s * server ) serve ( ln net . Listener ) {
func ( s * server ) serve ( ln net . Listener ) {
@ -101,6 +181,16 @@ func (s *server) serve(ln net.Listener) {
}
}
}
}
func ( s * server ) forward ( ln net . Listener , forw * portForward ) {
for {
c , err := ln . Accept ( )
if err != nil {
log . Fatal ( err )
}
go s . forwardConn ( c , forw )
}
}
func ( s * server ) serveDNS ( ln net . Listener ) {
func ( s * server ) serveDNS ( ln net . Listener ) {
for {
for {
c , err := ln . Accept ( )
c , err := ln . Accept ( )
@ -118,7 +208,7 @@ func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
n , err := c . Read ( buf )
n , err := c . Read ( buf )
if err != nil {
if err != nil {
log . Printf ( "c.Read failed: %v\n " , err )
log . Printf ( "c.Read failed: %v\n " , err )
dnsFailures. Add ( 1 )
s. dnsFailures. Add ( 1 )
return
return
}
}
@ -126,25 +216,25 @@ func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
err = msg . Unpack ( buf [ : n ] )
err = msg . Unpack ( buf [ : n ] )
if err != nil {
if err != nil {
log . Printf ( "dnsmessage unpack failed: %v\n " , err )
log . Printf ( "dnsmessage unpack failed: %v\n " , err )
dnsFailures. Add ( 1 )
s. dnsFailures. Add ( 1 )
return
return
}
}
buf , err = s . dnsResponse ( & msg )
buf , err = s . dnsResponse ( & msg )
if err != nil {
if err != nil {
log . Printf ( "s.dnsResponse failed: %v\n" , err )
log . Printf ( "s.dnsResponse failed: %v\n" , err )
dnsFailures. Add ( 1 )
s. dnsFailures. Add ( 1 )
return
return
}
}
_ , err = c . Write ( buf )
_ , err = c . Write ( buf )
if err != nil {
if err != nil {
log . Printf ( "c.Write failed: %v\n" , err )
log . Printf ( "c.Write failed: %v\n" , err )
dnsFailures. Add ( 1 )
s. dnsFailures. Add ( 1 )
return
return
}
}
dnsResponses. Add ( 1 )
s. dnsResponses. Add ( 1 )
}
}
func ( s * server ) serveConn ( c net . Conn ) {
func ( s * server ) serveConn ( c net . Conn ) {
@ -152,7 +242,7 @@ func (s *server) serveConn(c net.Conn) {
_ , port , err := net . SplitHostPort ( addrPortStr )
_ , port , err := net . SplitHostPort ( addrPortStr )
if err != nil {
if err != nil {
log . Printf ( "bogus addrPort %q" , addrPortStr )
log . Printf ( "bogus addrPort %q" , addrPortStr )
numBadAddrPort. Add ( 1 )
s. numBadAddrPort. Add ( 1 )
c . Close ( )
c . Close ( )
return
return
}
}
@ -165,7 +255,7 @@ func (s *server) serveConn(c net.Conn) {
return netutil . NewOneConnListener ( c , nil ) , nil
return netutil . NewOneConnListener ( c , nil ) , nil
}
}
p . AddSNIRouteFunc ( addrPortStr , func ( ctx context . Context , sniName string ) ( t tcpproxy . Target , ok bool ) {
p . AddSNIRouteFunc ( addrPortStr , func ( ctx context . Context , sniName string ) ( t tcpproxy . Target , ok bool ) {
numSessions. Add ( 1 )
s. numTL Ss essions. Add ( 1 )
return & tcpproxy . DialProxy {
return & tcpproxy . DialProxy {
Addr : net . JoinHostPort ( sniName , port ) ,
Addr : net . JoinHostPort ( sniName , port ) ,
DialContext : dialer . DialContext ,
DialContext : dialer . DialContext ,
@ -174,6 +264,49 @@ func (s *server) serveConn(c net.Conn) {
p . Start ( )
p . Start ( )
}
}
// portNumberToName returns a human-readable name for several port numbers commonly forwarded,
// and "tcp###" for everything else. It is used for metric label names.
func portNumberToName ( forw * portForward ) string {
switch forw . Port {
case 22 :
return "ssh"
case 1433 :
return "sqlserver"
case 3306 :
return "mysql"
case 3389 :
return "rdp"
case 5432 :
return "postgres"
default :
return fmt . Sprintf ( "%s%d" , forw . Proto , forw . Port )
}
}
// forwardConn sets up a forwarder for a TCP connection. It does not inspect of the data
// like the SNI forwarding does, it merely forwards all data to the destination specified
// in the --forward=tcp/22/github.com argument.
func ( s * server ) forwardConn ( c net . Conn , forw * portForward ) {
addrPortStr := c . LocalAddr ( ) . String ( )
var dialer net . Dialer
dialer . Timeout = 30 * time . Second
var p tcpproxy . Proxy
p . ListenFunc = func ( net , laddr string ) ( net . Listener , error ) {
return netutil . NewOneConnListener ( c , nil ) , nil
}
dial := & tcpproxy . DialProxy {
Addr : fmt . Sprintf ( "%s:%d" , forw . Destination , forw . Port ) ,
DialContext : dialer . DialContext ,
}
p . AddRoute ( addrPortStr , dial )
s . numTCPsessions . Add ( portNumberToName ( forw ) , 1 )
p . Start ( )
}
func ( s * server ) dnsResponse ( req * dnsmessage . Message ) ( buf [ ] byte , err error ) {
func ( s * server ) dnsResponse ( req * dnsmessage . Message ) ( buf [ ] byte , err error ) {
resp := dnsmessage . NewBuilder ( buf ,
resp := dnsmessage . NewBuilder ( buf ,
dnsmessage . Header {
dnsmessage . Header {
@ -235,8 +368,36 @@ func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
func ( s * server ) promoteHTTPS ( ln net . Listener ) {
func ( s * server ) promoteHTTPS ( ln net . Listener ) {
err := http . Serve ( ln , http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
err := http . Serve ( ln , http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
httpPromoted. Add ( 1 )
s. httpPromoted. Add ( 1 )
http . Redirect ( w , r , "https://" + r . Host + r . RequestURI , http . StatusFound )
http . Redirect ( w , r , "https://" + r . Host + r . RequestURI , http . StatusFound )
} ) )
} ) )
log . Fatalf ( "promoteHTTPS http.Serve: %v" , err )
log . Fatalf ( "promoteHTTPS http.Serve: %v" , err )
}
}
// initMetrics sets up local prometheus metrics, and creates clientmetrics to report those
// same counters.
func ( s * server ) initMetrics ( ) {
stats := new ( metrics . Set )
stats . Set ( "tls_sessions" , & s . numTLSsessions )
clientmetric . NewCounterFunc ( "sniproxy_tls_sessions" , s . numTLSsessions . Value )
s . numTCPsessions = & metrics . LabelMap { Label : "proto" }
stats . Set ( "tcp_sessions" , s . numTCPsessions )
// clientmetric doesn't have a good way to implement a Map type.
// We create clientmetrics dynamically when parsing the --forwards argument
stats . Set ( "bad_addrport" , & s . numBadAddrPort )
clientmetric . NewCounterFunc ( "sniproxy_bad_addrport" , s . numBadAddrPort . Value )
stats . Set ( "dns_responses" , & s . dnsResponses )
clientmetric . NewCounterFunc ( "sniproxy_dns_responses" , s . dnsResponses . Value )
stats . Set ( "dns_failed" , & s . dnsFailures )
clientmetric . NewCounterFunc ( "sniproxy_dns_failed" , s . dnsFailures . Value )
stats . Set ( "http_promoted" , & s . httpPromoted )
clientmetric . NewCounterFunc ( "sniproxy_http_promoted" , s . httpPromoted . Value )
expvar . Publish ( "sniproxy" , stats )
}