@ -10,30 +10,34 @@ package main
import (
"context"
"errors"
"expvar"
"flag"
"fmt"
"log"
"net"
"net/http"
"net/netip"
"os"
"sort"
"strconv"
"strings"
"time"
"github.com/peterbourgon/ff/v3"
"golang.org/x/net/dns/dnsmessage"
" inet.af/tcpproxy "
" tailscale.com/appc "
"tailscale.com/client/tailscale"
"tailscale.com/hostinfo"
"tailscale.com/ metrics "
"tailscale.com/ net/netutil "
"tailscale.com/ ipn "
"tailscale.com/ tailcfg "
"tailscale.com/tsnet"
"tailscale.com/tsweb"
"tailscale.com/types/appctype"
"tailscale.com/types/ipproto"
"tailscale.com/types/nettype"
"tailscale.com/util/clientmetric"
"tailscale.com/util/ mak "
)
const configCapKey = "tailscale.com/sniproxy"
var tsMBox = dnsmessage . MustNewName ( "support.tailscale.com." )
// portForward is the state for a single port forwarding entry, as passed to the --forward flag.
@ -68,6 +72,7 @@ func parseForward(value string) (*portForward, error) {
}
func main ( ) {
// Parse flags
fs := flag . NewFlagSet ( "sniproxy" , flag . ContinueOnError )
var (
ports = fs . String ( "ports" , "443" , "comma-separated list of ports to proxy" )
@ -77,334 +82,214 @@ func main() {
debugPort = fs . Int ( "debug-port" , 8893 , "Listening port for debug/metrics endpoint" )
hostname = fs . String ( "hostname" , "" , "Hostname to register the service under" )
)
err := ff . Parse ( fs , os . Args [ 1 : ] , ff . WithEnvVarPrefix ( "TS_APPC" ) )
if err != nil {
log . Fatal ( "ff.Parse" )
}
if * ports == "" {
log . Fatal ( "no ports" )
}
hostinfo . SetApp ( "sniproxy" )
var ts tsnet . Server
defer ts . Close ( )
ctx , cancel := context . WithCancel ( context . Background ( ) )
defer cancel ( )
run ( ctx , & ts , * wgPort , * hostname , * promoteHTTPS , * debugPort , * ports , * forwards )
}
// run actually runs the sniproxy. Its separate from main() to assist in testing.
func run ( ctx context . Context , ts * tsnet . Server , wgPort int , hostname string , promoteHTTPS bool , debugPort int , ports , forwards string ) {
// Wire up Tailscale node + app connector server
hostinfo . SetApp ( "sniproxy" )
var s server
s . ts . Port = uint16 ( * wgPort )
s . ts . Hostname = * hostname
defer s . ts . Close ( )
s . ts = ts
s . ts . Port = uint16 ( wgPort )
s . ts . Hostname = hostname
lc , err := s . ts . LocalClient ( )
if err != nil {
log . Fatal ( err )
log . Fatal f ( "LocalClient() failed: %v" , err )
}
s . lc = lc
s . initMetrics ( )
for _ , portStr := range strings . Split ( * ports , "," ) {
ln , err := s . ts . Listen ( "tcp" , ":" + portStr )
if err != nil {
log . Fatal ( err )
}
log . Printf ( "Serving on port %v ..." , portStr )
go s . serve ( ln )
}
for _ , forwStr := range strings . Split ( * forwards , "," ) {
if forwStr == "" {
continue
}
forw , err := parseForward ( forwStr )
if err != nil {
log . Fatal ( err )
}
ln , err := s . ts . Listen ( "tcp" , ":" + strconv . Itoa ( forw . Port ) )
if err != nil {
log . Fatal ( err )
}
log . Printf ( "Serving on port %d to %s..." , forw . Port , forw . Destination )
// Add an entry to the expvar LabelMap for Prometheus metrics,
// and create a clientmetric to report that same value.
service := portNumberToName ( forw )
s . numTCPsessions . SetInt64 ( service , 0 )
metric := fmt . Sprintf ( "sniproxy_tcp_sessions_%s" , service )
clientmetric . NewCounterFunc ( metric , func ( ) int64 {
return s . numTCPsessions . Get ( service ) . Value ( )
} )
go s . forward ( ln , forw )
}
s . ts . RegisterFallbackTCPHandler ( s . appc . HandleTCPFlow )
// Start special-purpose listeners: dns, http promotion, debug server
ln , err := s . ts . Listen ( "udp" , ":53" )
if err != nil {
log . Fatal ( err )
log . Fatalf ( "failed listening on port 53: %v" , err )
}
defer ln . Close ( )
go s . serveDNS ( ln )
if * promoteHTTPS {
if promoteHTTPS {
ln , err := s . ts . Listen ( "tcp" , ":80" )
if err != nil {
log . Fatal ( err )
log . Fatalf ( "failed listening on port 80: %v" , err )
}
defer ln . Close ( )
log . Printf ( "Promoting HTTP to HTTPS ..." )
go s . promoteHTTPS ( ln )
}
if * debugPort != 0 {
if debugPort != 0 {
mux := http . NewServeMux ( )
tsweb . Debugger ( mux )
dln , err := s . ts . Listen ( "tcp" , fmt . Sprintf ( ":%d" , * debugPort ) )
dln , err := s . ts . Listen ( "tcp" , fmt . Sprintf ( ":%d" , debugPort ) )
if err != nil {
log . Fatal ( err )
log . Fatalf ( "failed listening on debug port: %v" , err )
}
defer dln . Close ( )
go func ( ) {
log . Fatal ( http . Serve ( dln , mux ) )
log . Fatalf ( "debug serve: %v" , http . Serve ( dln , mux ) )
} ( )
}
select { }
}
type server struct {
ts tsnet . Server
lc * tailscale . LocalClient
numTLSsessions expvar . Int
numTCPsessions * metrics . LabelMap
numBadAddrPort expvar . Int
dnsResponses expvar . Int
dnsFailures expvar . Int
httpPromoted expvar . Int
}
func ( s * server ) serve ( ln net . Listener ) {
for {
c , err := ln . Accept ( )
// Finally, start mainloop to configure app connector based on information
// in the netmap.
// We set the NotifyInitialNetMap flag so we will always get woken with the
// current netmap, before only being woken on changes.
bus , err := lc . WatchIPNBus ( ctx , ipn . NotifyWatchEngineUpdates | ipn . NotifyInitialNetMap | ipn . NotifyNoPrivateKeys )
if err != nil {
log . Fatal ( err )
log . Fatalf ( "watching IPN bus: %v" , err )
}
go s . serveConn ( c )
}
}
func ( s * server ) forward ( ln net . Listener , forw * portForward ) {
defer bus . Close ( )
for {
c, err := ln . Accep t( )
msg , err := bus . Next ( )
if err != nil {
log . Fatal ( err )
if errors . Is ( err , context . Canceled ) {
return
}
go s . forwardConn ( c , forw )
log . Fatalf ( "reading IPN bus: %v" , err )
}
}
func ( s * server ) serveDNS ( ln net . Listener ) {
for {
c , err := ln . Accept ( )
if err != nil {
log . Fatal ( err )
}
go s . serveDNSConn ( c . ( nettype . ConnPacketConn ) )
}
}
// NetMap contains app-connector configuration
if nm := msg . NetMap ; nm != nil && nm . SelfNode . Valid ( ) {
sn := nm . SelfNode . AsStruct ( )
func ( s * server ) serveDNSConn ( c nettype . ConnPacketConn ) {
defer c . Close ( )
c . SetReadDeadline ( time . Now ( ) . Add ( 5 * time . Second ) )
buf := make ( [ ] byte , 1500 )
n , err := c . Read ( buf )
var c appctype . AppConnectorConfig
nmConf , err := tailcfg . UnmarshalNodeCapJSON [ appctype . AppConnectorConfig ] ( sn . CapMap , configCapKey )
if err != nil {
log . Printf ( " c.Read failed: %v\n ", err )
s . dnsFailures . Add ( 1 )
return
log . Printf ( "failed to read app connector configuration from coordination server: %v" , err )
} else if len ( nmConf ) > 0 {
c = nmConf [ 0 ]
}
var msg dnsmessage . Message
err = msg . Unpack ( buf [ : n ] )
if err != nil {
log . Printf ( "dnsmessage unpack failed: %v\n " , err )
s . dnsFailures . Add ( 1 )
return
if c . AdvertiseRoutes {
if err := s . advertiseRoutesFromConfig ( ctx , & c ) ; err != nil {
log . Printf ( "failed to advertise routes: %v" , err )
}
buf , err = s . dnsResponse ( & msg )
if err != nil {
log . Printf ( "s.dnsResponse failed: %v\n" , err )
s . dnsFailures . Add ( 1 )
return
}
_ , err = c . Write ( buf )
if err != nil {
log . Printf ( "c.Write failed: %v\n" , err )
s . dnsFailures . Add ( 1 )
return
// Backwards compatibility: combine any configuration from control with flags specified
// on the command line. This is intentionally done after we advertise any routes
// because its never correct to advertise the nodes native IP addresses.
s . mergeConfigFromFlags ( & c , ports , forwards )
s . appc . Configure ( & c )
}
}
}
s . dnsResponses . Add ( 1 )
type server struct {
appc appc . Server
ts * tsnet . Server
lc * tailscale . LocalClient
}
func ( s * server ) serveConn ( c net . Conn ) {
addrPortStr := c . LocalAddr ( ) . String ( )
_ , port , err := net . SplitHostPort ( addrPortStr )
if err != nil {
log . Printf ( "bogus addrPort %q" , addrPortStr )
s . numBadAddrPort . Add ( 1 )
c . Close ( )
return
func ( s * server ) advertiseRoutesFromConfig ( ctx context . Context , c * appctype . AppConnectorConfig ) error {
// Collect the set of addresses to advertise, using a map
// to avoid duplicate entries.
addrs := map [ netip . Addr ] struct { } { }
for _ , c := range c . SNIProxy {
for _ , ip := range c . Addrs {
addrs [ ip ] = struct { } { }
}
}
for _ , c := range c . DNAT {
for _ , ip := range c . Addrs {
addrs [ ip ] = struct { } { }
}
}
var dialer net . Dialer
dialer . Timeout = 5 * time . Second
var p tcpproxy . Proxy
p . ListenFunc = func ( net , laddr string ) ( net . Listener , error ) {
return netutil . NewOneConnListener ( c , nil ) , nil
var routes [ ] netip . Prefix
for a := range addrs {
routes = append ( routes , netip . PrefixFrom ( a , a . BitLen ( ) ) )
}
p . AddSNIRouteFunc ( addrPortStr , func ( ctx context . Context , sniName string ) ( t tcpproxy . Target , ok bool ) {
s . numTLSsessions . Add ( 1 )
return & tcpproxy . DialProxy {
Addr : net . JoinHostPort ( sniName , port ) ,
DialContext : dialer . DialContext ,
} , true
sort . SliceStable ( routes , func ( i , j int ) bool {
return routes [ i ] . Addr ( ) . Less ( routes [ j ] . Addr ( ) ) // determinism r us
} )
p . Start ( )
}
// portNumberToName returns a human-readable name for several port numbers commonly forwarded,
// and "tcp###" for everything else. It is used for metric label names.
func portNumberToName ( forw * portForward ) string {
switch forw . Port {
case 22 :
return "ssh"
case 1433 :
return "sqlserver"
case 3306 :
return "mysql"
case 3389 :
return "rdp"
case 5432 :
return "postgres"
default :
return fmt . Sprintf ( "%s%d" , forw . Proto , forw . Port )
}
_ , err := s . lc . EditPrefs ( ctx , & ipn . MaskedPrefs {
Prefs : ipn . Prefs {
AdvertiseRoutes : routes ,
} ,
AdvertiseRoutesSet : true ,
} )
return err
}
// forwardConn sets up a forwarder for a TCP connection. It does not inspect of the data
// like the SNI forwarding does, it merely forwards all data to the destination specified
// in the --forward=tcp/22/github.com argument.
func ( s * server ) forwardConn ( c net . Conn , forw * portForward ) {
addrPortStr := c . LocalAddr ( ) . String ( )
var dialer net . Dialer
dialer . Timeout = 30 * time . Second
func ( s * server ) mergeConfigFromFlags ( out * appctype . AppConnectorConfig , ports , forwards string ) {
ip4 , ip6 := s . ts . TailscaleIPs ( )
var p tcpproxy . Proxy
p . ListenFunc = func ( net , laddr string ) ( net . Listener , error ) {
return netutil . NewOneConnListener ( c , nil ) , nil
sniConfigFromFlags := appctype . SNIProxyConfig {
Addrs : [ ] netip . Addr { ip4 , ip6 } ,
}
dial := & tcpproxy . DialProxy {
Addr : fmt . Sprintf ( "%s:%d" , forw . Destination , forw . Port ) ,
DialContext : dialer . DialContext ,
if ports != "" {
for _ , portStr := range strings . Split ( ports , "," ) {
port , err := strconv . ParseUint ( portStr , 10 , 16 )
if err != nil {
log . Fatalf ( "invalid port: %s" , portStr )
}
p . AddRoute ( addrPortStr , dial )
s . numTCPsessions . Add ( portNumberToName ( forw ) , 1 )
p . Start ( )
}
func ( s * server ) dnsResponse ( req * dnsmessage . Message ) ( buf [ ] byte , err error ) {
resp := dnsmessage . NewBuilder ( buf ,
dnsmessage . Header {
ID : req . Header . ID ,
Response : true ,
Authoritative : true ,
sniConfigFromFlags . IP = append ( sniConfigFromFlags . IP , tailcfg . ProtoPortRange {
Proto : int ( ipproto . TCP ) ,
Ports : tailcfg . PortRange { First : uint16 ( port ) , Last : uint16 ( port ) } ,
} )
resp . EnableCompression ( )
if len ( req . Questions ) == 0 {
buf , _ = resp . Finish ( )
return
}
}
q := req . Questions [ 0 ]
err = resp . StartQuestions ( )
var forwardConfigFromFlags [ ] appctype . DNATConfig
for _ , forwStr := range strings . Split ( forwards , "," ) {
if forwStr == "" {
continue
}
forw , err := parseForward ( forwStr )
if err != nil {
return
log . Printf ( "invalid forwarding spec: %v" , err )
continue
}
resp . Question ( q )
ip4 , ip6 := s . ts . TailscaleIPs ( )
err = resp . StartAnswers ( )
if err != nil {
return
forwardConfigFromFlags = append ( forwardConfigFromFlags , appctype . DNATConfig {
Addrs : [ ] netip . Addr { ip4 , ip6 } ,
To : [ ] string { forw . Destination } ,
IP : [ ] tailcfg . ProtoPortRange {
{
Proto : int ( ipproto . TCP ) ,
Ports : tailcfg . PortRange { First : uint16 ( forw . Port ) , Last : uint16 ( forw . Port ) } ,
} ,
} ,
} )
}
switch q . Type {
case dnsmessage . TypeAAAA :
err = resp . AAAAResource (
dnsmessage . ResourceHeader { Name : q . Name , Class : q . Class , TTL : 120 } ,
dnsmessage . AAAAResource { AAAA : ip6 . As16 ( ) } ,
)
if len ( forwardConfigFromFlags ) == 0 && len ( sniConfigFromFlags . IP ) == 0 {
return // no config specified on the command line
}
case dnsmessage . TypeA :
err = resp . AResource (
dnsmessage . ResourceHeader { Name : q . Name , Class : q . Class , TTL : 120 } ,
dnsmessage . AResource { A : ip4 . As4 ( ) } ,
)
case dnsmessage . TypeSOA :
err = resp . SOAResource (
dnsmessage . ResourceHeader { Name : q . Name , Class : q . Class , TTL : 120 } ,
dnsmessage . SOAResource { NS : q . Name , MBox : tsMBox , Serial : 2023030600 ,
Refresh : 120 , Retry : 120 , Expire : 120 , MinTTL : 60 } ,
)
case dnsmessage . TypeNS :
err = resp . NSResource (
dnsmessage . ResourceHeader { Name : q . Name , Class : q . Class , TTL : 120 } ,
dnsmessage . NSResource { NS : tsMBox } ,
)
mak . Set ( & out . SNIProxy , "flags" , sniConfigFromFlags )
for i , forward := range forwardConfigFromFlags {
mak . Set ( & out . DNAT , appctype . ConfigID ( fmt . Sprintf ( "flags_%d" , i ) ) , forward )
}
}
func ( s * server ) serveDNS ( ln net . Listener ) {
for {
c , err := ln . Accept ( )
if err != nil {
log . Printf ( "serveDNS accept: %v" , err )
return
}
return resp . Finish ( )
go s . appc . HandleDNS ( c . ( nettype . ConnPacketConn ) )
}
}
func ( s * server ) promoteHTTPS ( ln net . Listener ) {
err := http . Serve ( ln , http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
s . httpPromoted . Add ( 1 )
http . Redirect ( w , r , "https://" + r . Host + r . RequestURI , http . StatusFound )
} ) )
log . Fatalf ( "promoteHTTPS http.Serve: %v" , err )
}
// initMetrics sets up local prometheus metrics, and creates clientmetrics to report those
// same counters.
func ( s * server ) initMetrics ( ) {
stats := new ( metrics . Set )
stats . Set ( "tls_sessions" , & s . numTLSsessions )
clientmetric . NewCounterFunc ( "sniproxy_tls_sessions" , s . numTLSsessions . Value )
s . numTCPsessions = & metrics . LabelMap { Label : "proto" }
stats . Set ( "tcp_sessions" , s . numTCPsessions )
// clientmetric doesn't have a good way to implement a Map type.
// We create clientmetrics dynamically when parsing the --forwards argument
stats . Set ( "bad_addrport" , & s . numBadAddrPort )
clientmetric . NewCounterFunc ( "sniproxy_bad_addrport" , s . numBadAddrPort . Value )
stats . Set ( "dns_responses" , & s . dnsResponses )
clientmetric . NewCounterFunc ( "sniproxy_dns_responses" , s . dnsResponses . Value )
stats . Set ( "dns_failed" , & s . dnsFailures )
clientmetric . NewCounterFunc ( "sniproxy_dns_failed" , s . dnsFailures . Value )
stats . Set ( "http_promoted" , & s . httpPromoted )
clientmetric . NewCounterFunc ( "sniproxy_http_promoted" , s . httpPromoted . Value )
expvar . Publish ( "sniproxy" , stats )
}