cmd/tailscaled, wgengine: remove --fake, replace with netstack

And add a --socks5-server flag.

And fix a race in SOCKS5 replies where the response header was written
concurrently with the copy from the backend.

Co-authored with Naman Sood.

Updates #707
Updates #504

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
naman/netstack-incoming
Brad Fitzpatrick 4 years ago
parent d74cddcc56
commit 38dc6fe758

@ -96,7 +96,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
💣 tailscale.com/net/netstat from tailscale.com/ipn/ipnserver 💣 tailscale.com/net/netstat from tailscale.com/ipn/ipnserver
tailscale.com/net/packet from tailscale.com/wgengine+ tailscale.com/net/packet from tailscale.com/wgengine+
tailscale.com/net/portmapper from tailscale.com/net/netcheck+ tailscale.com/net/portmapper from tailscale.com/net/netcheck+
tailscale.com/net/socks5 from tailscale.com/wgengine/netstack tailscale.com/net/socks5 from tailscale.com/cmd/tailscaled
tailscale.com/net/stun from tailscale.com/net/netcheck+ tailscale.com/net/stun from tailscale.com/net/netcheck+
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
tailscale.com/net/tsaddr from tailscale.com/ipn/ipnlocal+ tailscale.com/net/tsaddr from tailscale.com/ipn/ipnlocal+

@ -14,6 +14,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
"os" "os"
@ -21,19 +22,23 @@ import (
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strconv" "strconv"
"sync"
"syscall" "syscall"
"time" "time"
"tailscale.com/ipn/ipnserver" "tailscale.com/ipn/ipnserver"
"tailscale.com/logpolicy" "tailscale.com/logpolicy"
"tailscale.com/net/socks5"
"tailscale.com/paths" "tailscale.com/paths"
"tailscale.com/types/flagtype" "tailscale.com/types/flagtype"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/version" "tailscale.com/version"
"tailscale.com/wgengine" "tailscale.com/wgengine"
"tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/magicsock"
"tailscale.com/wgengine/netstack" "tailscale.com/wgengine/netstack"
"tailscale.com/wgengine/router" "tailscale.com/wgengine/router"
"tailscale.com/wgengine/tstun"
) )
// globalStateKey is the ipn.StateKey that tailscaled loads on // globalStateKey is the ipn.StateKey that tailscaled loads on
@ -62,13 +67,13 @@ func defaultTunName() string {
var args struct { var args struct {
cleanup bool cleanup bool
fake bool
debug string debug string
tunname string tunname string
port uint16 port uint16
statepath string statepath string
socketpath string socketpath string
verbose int verbose int
socksAddr string // listen address for SOCKS5 server
} }
var ( var (
@ -94,9 +99,9 @@ func main() {
printVersion := false printVersion := false
flag.IntVar(&args.verbose, "verbose", 0, "log verbosity level; 0 is default, 1 or higher are increasingly verbose") flag.IntVar(&args.verbose, "verbose", 0, "log verbosity level; 0 is default, 1 or higher are increasingly verbose")
flag.BoolVar(&args.cleanup, "cleanup", false, "clean up system state and exit") flag.BoolVar(&args.cleanup, "cleanup", false, "clean up system state and exit")
flag.BoolVar(&args.fake, "fake", false, "use userspace fake tunnel+routing instead of kernel TUN interface")
flag.StringVar(&args.debug, "debug", "", "listen address ([ip]:port) of optional debug server") flag.StringVar(&args.debug, "debug", "", "listen address ([ip]:port) of optional debug server")
flag.StringVar(&args.tunname, "tun", defaultTunName(), "tunnel interface name") flag.StringVar(&args.socksAddr, "socks5-server", "", `optional [ip]:port to run a SOCK5 server (e.g. "localhost:1080")`)
flag.StringVar(&args.tunname, "tun", defaultTunName(), `tunnel interface name; use "userspace-networking" (beta) to not use TUN`)
flag.Var(flagtype.PortValue(&args.port, magicsock.DefaultPort), "port", "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") flag.Var(flagtype.PortValue(&args.port, magicsock.DefaultPort), "port", "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
flag.StringVar(&args.statepath, "state", paths.DefaultTailscaledStateFile(), "path of state file") flag.StringVar(&args.statepath, "state", paths.DefaultTailscaledStateFile(), "path of state file")
flag.StringVar(&args.socketpath, "socket", paths.DefaultTailscaledSocket(), "path of the service unix socket") flag.StringVar(&args.socketpath, "socket", paths.DefaultTailscaledSocket(), "path of the service unix socket")
@ -190,23 +195,73 @@ func run() error {
go runDebugServer(debugMux, args.debug) go runDebugServer(debugMux, args.debug)
} }
var e wgengine.Engine var socksListener net.Listener
if args.fake { if args.socksAddr != "" {
var impl wgengine.FakeImplFactory var err error
if args.tunname == "userspace-networking" { socksListener, err = net.Listen("tcp", args.socksAddr)
impl = netstack.Create if err != nil {
log.Fatalf("SOCKS5 listener: %v", err)
} }
e, err = wgengine.NewFakeUserspaceEngine(logf, 0, impl) }
conf := wgengine.Config{
ListenPort: args.port,
}
if args.tunname == "userspace-networking" {
conf.TUN = tstun.NewFakeTUN()
conf.RouterGen = router.NewFake
} else { } else {
e, err = wgengine.NewUserspaceEngine(logf, wgengine.Config{ conf.TUNName = args.tunname
TUNName: args.tunname,
ListenPort: args.port,
})
} }
e, err := wgengine.NewUserspaceEngine(logf, conf)
if err != nil { if err != nil {
logf("wgengine.New: %v", err) logf("wgengine.New: %v", err)
return err return err
} }
var ns *netstack.Impl
if args.tunname == "userspace-networking" {
tunDev, magicConn := e.(wgengine.InternalsGetter).GetInternals()
ns, err = netstack.Create(logf, tunDev, e, magicConn)
if err != nil {
log.Fatalf("netstack.Create: %v", err)
}
if err := ns.Start(); err != nil {
log.Fatalf("failed to start netstack: %v", err)
}
}
if socksListener != nil {
srv := &socks5.Server{
Logf: logger.WithPrefix(logf, "socks5: "),
}
if args.tunname == "userspace-networking" {
srv.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
return ns.DialContextTCP(ctx, addr)
}
} else {
var mu sync.Mutex
var dns netstack.DNSMap
e.AddNetworkMapCallback(func(nm *netmap.NetworkMap) {
mu.Lock()
defer mu.Unlock()
dns = netstack.DNSMapFromNetworkMap(nm)
})
srv.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
ipp, err := dns.Resolve(ctx, addr)
if err != nil {
return nil, err
}
var d net.Dialer
return d.DialContext(ctx, network, ipp.String())
}
}
go func() {
log.Fatalf("SOCKS5 server exited: %v", srv.Serve(socksListener))
}()
}
e = wgengine.NewWatchdog(e) e = wgengine.NewWatchdog(e)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())

@ -41,7 +41,7 @@ func TestLocalLogLines(t *testing.T) {
// set up a LocalBackend, super bare bones. No functional data. // set up a LocalBackend, super bare bones. No functional data.
store := &ipn.MemoryStore{} store := &ipn.MemoryStore{}
e, err := wgengine.NewFakeUserspaceEngine(logListen.Logf, 0, nil) e, err := wgengine.NewFakeUserspaceEngine(logListen.Logf, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -56,7 +56,7 @@ func TestRunMultipleAccepts(t *testing.T) {
} }
} }
eng, err := wgengine.NewFakeUserspaceEngine(logf, 0, nil) eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -108,7 +108,7 @@ func (s *Server) Serve(l net.Listener) error {
conn := &Conn{clientConn: c, srv: s} conn := &Conn{clientConn: c, srv: s}
err := conn.Run() err := conn.Run()
if err != nil { if err != nil {
s.logf("socks5: client connection failed: %v", err) s.logf("client connection failed: %v", err)
conn.clientConn.Close() conn.clientConn.Close()
} }
}() }()
@ -123,7 +123,6 @@ type Conn struct {
srv *Server srv *Server
clientConn net.Conn clientConn net.Conn
serverConn net.Conn
request *request request *request
} }
@ -153,11 +152,7 @@ func (c *Conn) handleRequest() error {
return fmt.Errorf("unsupported command %v", req.command) return fmt.Errorf("unsupported command %v", req.command)
} }
c.request = req c.request = req
return c.createReply()
}
func (c *Conn) createReply() error {
var err error
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
srv, err := c.srv.dial( srv, err := c.srv.dial(
@ -171,14 +166,12 @@ func (c *Conn) createReply() error {
c.clientConn.Write(buf) c.clientConn.Write(buf)
return err return err
} }
c.serverConn = srv defer srv.Close()
serverAddr, serverPortStr, err := net.SplitHostPort(c.serverConn.LocalAddr().String()) serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
if err != nil { if err != nil {
return err return err
} }
serverPort, _ := strconv.Atoi(serverPortStr) serverPort, _ := strconv.Atoi(serverPortStr)
go io.Copy(c.clientConn, c.serverConn)
go io.Copy(c.serverConn, c.clientConn)
var bindAddrType addrType var bindAddrType addrType
if ip := net.ParseIP(serverAddr); ip != nil { if ip := net.ParseIP(serverAddr); ip != nil {
@ -190,7 +183,6 @@ func (c *Conn) createReply() error {
} else { } else {
bindAddrType = domainName bindAddrType = domainName
} }
res := &response{ res := &response{
reply: success, reply: success,
bindAddrType: bindAddrType, bindAddrType: bindAddrType,
@ -203,7 +195,23 @@ func (c *Conn) createReply() error {
buf, _ = res.marshal() buf, _ = res.marshal()
} }
c.clientConn.Write(buf) c.clientConn.Write(buf)
return err
errc := make(chan error, 2)
go func() {
_, err := io.Copy(c.clientConn, srv)
if err != nil {
err = fmt.Errorf("from backend to client: %w", err)
}
errc <- err
}()
go func() {
_, err := io.Copy(srv, c.clientConn)
if err != nil {
err = fmt.Errorf("from client to backend: %w", err)
}
errc <- err
}()
return <-errc
} }
// parseClientGreeting parses a request initiation packet // parseClientGreeting parses a request initiation packet

@ -33,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/net/socks5"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
@ -55,13 +54,13 @@ type Impl struct {
logf logger.Logf logf logger.Logf
mu sync.Mutex mu sync.Mutex
dns map[string]netaddr.IP // Magic DNS names (both base + FQDN) => first IP dns DNSMap
} }
const nicID = 1 const nicID = 1
// Create creates and populates a new Impl. // Create creates and populates a new Impl.
func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (wgengine.FakeImpl, error) { func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) {
if mc == nil { if mc == nil {
return nil, errors.New("nil magicsock.Conn") return nil, errors.New("nil magicsock.Conn")
} }
@ -121,33 +120,40 @@ func (ns *Impl) Start() error {
ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
go ns.injectOutbound() go ns.injectOutbound()
ns.tundev.PostFilterIn = ns.injectInbound ns.tundev.PostFilterIn = ns.injectInbound
go ns.socks5Server()
return nil return nil
} }
func (ns *Impl) updateDNS(nm *netmap.NetworkMap) { // DNSMap maps MagicDNS names (both base + FQDN) to their first IP.
ns.mu.Lock() // It should not be mutated once created.
defer ns.mu.Unlock() type DNSMap map[string]netaddr.IP
ns.dns = make(map[string]netaddr.IP)
func DNSMapFromNetworkMap(nm *netmap.NetworkMap) DNSMap {
ret := make(DNSMap)
suffix := nm.MagicDNSSuffix() suffix := nm.MagicDNSSuffix()
if nm.Name != "" && len(nm.Addresses) > 0 { if nm.Name != "" && len(nm.Addresses) > 0 {
ip := nm.Addresses[0].IP ip := nm.Addresses[0].IP
ns.dns[strings.TrimRight(nm.Name, ".")] = ip ret[strings.TrimRight(nm.Name, ".")] = ip
if dnsname.HasSuffix(nm.Name, suffix) { if dnsname.HasSuffix(nm.Name, suffix) {
ns.dns[dnsname.TrimSuffix(nm.Name, suffix)] = ip ret[dnsname.TrimSuffix(nm.Name, suffix)] = ip
} }
} }
for _, p := range nm.Peers { for _, p := range nm.Peers {
if p.Name != "" && len(p.Addresses) > 0 { if p.Name != "" && len(p.Addresses) > 0 {
ip := p.Addresses[0].IP ip := p.Addresses[0].IP
ns.dns[strings.TrimRight(p.Name, ".")] = ip ret[strings.TrimRight(p.Name, ".")] = ip
if dnsname.HasSuffix(p.Name, suffix) { if dnsname.HasSuffix(p.Name, suffix) {
ns.dns[dnsname.TrimSuffix(p.Name, suffix)] = ip ret[dnsname.TrimSuffix(p.Name, suffix)] = ip
} }
} }
} }
return ret
}
func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
ns.mu.Lock()
defer ns.mu.Unlock()
ns.dns = DNSMapFromNetworkMap(nm)
} }
func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
@ -198,8 +204,9 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
} }
} }
// resolve resolves addr into an IP:port. // Resolve resolves addr into an IP:port using first the MagicDNS contents
func (ns *Impl) resolve(ctx context.Context, addr string) (netaddr.IPPort, error) { // of m, else using the system resolver.
func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error) {
ipp, pippErr := netaddr.ParseIPPort(addr) ipp, pippErr := netaddr.ParseIPPort(addr)
if pippErr == nil { if pippErr == nil {
return ipp, nil return ipp, nil
@ -222,9 +229,7 @@ func (ns *Impl) resolve(ctx context.Context, addr string) (netaddr.IPPort, error
// Host is not an IP, so assume it's a DNS name. // Host is not an IP, so assume it's a DNS name.
// Try MagicDNS first, else otherwise a real DNS lookup. // Try MagicDNS first, else otherwise a real DNS lookup.
ns.mu.Lock() ip := m[host]
ip := ns.dns[host]
ns.mu.Unlock()
if !ip.IsZero() { if !ip.IsZero() {
return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
} }
@ -242,8 +247,12 @@ func (ns *Impl) resolve(ctx context.Context, addr string) (netaddr.IPPort, error
return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil return netaddr.IPPort{IP: ip, Port: uint16(port16)}, nil
} }
func (ns *Impl) dialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) { func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) {
remoteIPPort, err := ns.resolve(ctx, addr) ns.mu.Lock()
dnsMap := ns.dns
ns.mu.Unlock()
remoteIPPort, err := dnsMap.Resolve(ctx, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -342,7 +351,7 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, address stri
} }
cancel() cancel()
}() }()
server, err := ns.dialContextTCP(ctx, address) server, err := ns.DialContextTCP(ctx, address)
if err != nil { if err != nil {
ns.logf("netstack: could not connect to server %s: %s", address, err) ns.logf("netstack: could not connect to server %s: %s", address, err)
return return
@ -361,21 +370,6 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, address stri
ns.logf("[v2] netstack: forwarder connection to %s closed", address) ns.logf("[v2] netstack: forwarder connection to %s closed", address)
} }
func (ns *Impl) socks5Server() {
ln, err := net.Listen("tcp", "localhost:1080")
if err != nil {
ns.logf("could not start SOCKS5 listener: %v", err)
return
}
srv := &socks5.Server{
Logf: ns.logf,
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
return ns.dialContextTCP(ctx, addr)
},
}
ns.logf("SOCKS5 server exited: %v", srv.Serve(ln))
}
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
ns.logf("[v2] UDP ForwarderRequest: %v", r) ns.logf("[v2] UDP ForwarderRequest: %v", r)
var wq waiter.Queue var wq waiter.Queue

@ -131,6 +131,15 @@ type userspaceEngine struct {
// Lock ordering: magicsock.Conn.mu, wgLock, then mu. // Lock ordering: magicsock.Conn.mu, wgLock, then mu.
} }
// InternalsGetter is implemented by Engines that can export their internals.
type InternalsGetter interface {
GetInternals() (*tstun.TUN, *magicsock.Conn)
}
func (e *userspaceEngine) GetInternals() (*tstun.TUN, *magicsock.Conn) {
return e.tundev, e.magicConn
}
// RouterGen is the signature for a function that creates a // RouterGen is the signature for a function that creates a
// router.Router. // router.Router.
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error) type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
@ -157,36 +166,18 @@ type Config struct {
// If zero, a port is automatically selected. // If zero, a port is automatically selected.
ListenPort uint16 ListenPort uint16
// Fake determines whether this engine is running in fake mode, // Fake determines whether this engine should automatically
// which disables such features as DNS configuration and unrestricted ICMP Echo responses. // reply to ICMP pings.
Fake bool Fake bool
// FakeImplFactory, if non-nil, creates a FakeImpl to use as a fake engine
// implementation. Two values are typical: nil, for a basic ping-only fake
// implementation, and netstack.Create, which creates a userspace network
// stack using gvisor's netstack. The desire to keep netstack out of some
// binaries is why the FakeImpl interface exists, so wgengine need not
// depend on gvisor.
FakeImplFactory FakeImplFactory
}
// FakeImpl is a fake or alternate version of Engine that can be started. See
// Config.FakeImplFactory for details.
type FakeImpl interface {
Start() error
} }
// FakeImplFactory is the type of a function used to create FakeImpls. func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) {
type FakeImplFactory func(logger.Logf, *tstun.TUN, Engine, *magicsock.Conn) (FakeImpl, error)
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16, impl FakeImplFactory) (Engine, error) {
logf("Starting userspace wireguard engine (with fake TUN device)") logf("Starting userspace wireguard engine (with fake TUN device)")
return NewUserspaceEngine(logf, Config{ return NewUserspaceEngine(logf, Config{
TUN: tstun.NewFakeTUN(), TUN: tstun.NewFakeTUN(),
RouterGen: router.NewFake, RouterGen: router.NewFake,
ListenPort: listenPort, ListenPort: listenPort,
Fake: true, Fake: true,
FakeImplFactory: impl,
}) })
} }
@ -292,18 +283,7 @@ func newUserspaceEngine(logf logger.Logf, rawTUNDev tun.Device, conf Config) (_
// Respond to all pings only in fake mode. // Respond to all pings only in fake mode.
if conf.Fake { if conf.Fake {
if f := conf.FakeImplFactory; f != nil { e.tundev.PostFilterIn = echoRespondToAll
impl, err := f(logf, e.tundev, e, e.magicConn)
if err != nil {
return nil, err
}
if err := impl.Start(); err != nil {
return nil, err
}
} else {
// Respond to all pings only in fake mode.
e.tundev.PostFilterIn = echoRespondToAll
}
} }
e.tundev.PreFilterOut = e.handleLocalPackets e.tundev.PreFilterOut = e.handleLocalPackets

@ -84,7 +84,7 @@ func TestNoteReceiveActivity(t *testing.T) {
} }
func TestUserspaceEngineReconfig(t *testing.T) { func TestUserspaceEngineReconfig(t *testing.T) {
e, err := NewFakeUserspaceEngine(t.Logf, 0, nil) e, err := NewFakeUserspaceEngine(t.Logf, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -17,7 +17,7 @@ func TestWatchdog(t *testing.T) {
t.Run("default watchdog does not fire", func(t *testing.T) { t.Run("default watchdog does not fire", func(t *testing.T) {
t.Parallel() t.Parallel()
e, err := NewFakeUserspaceEngine(t.Logf, 0, nil) e, err := NewFakeUserspaceEngine(t.Logf, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -35,7 +35,7 @@ func TestWatchdog(t *testing.T) {
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) { t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
t.Parallel() t.Parallel()
e, err := NewFakeUserspaceEngine(t.Logf, 0, nil) e, err := NewFakeUserspaceEngine(t.Logf, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

Loading…
Cancel
Save