diff --git a/net/proxymux/mux.go b/net/proxymux/mux.go new file mode 100644 index 000000000..f16759c49 --- /dev/null +++ b/net/proxymux/mux.go @@ -0,0 +1,145 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package proxymux splits a net.Listener in two, routing SOCKS5 +// connections to one and HTTP requests to the other. +// +// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the +// same listener. +package proxymux + +import ( + "io" + "net" + "sync" + "time" +) + +// SplitSOCKSAndHTTP accepts connections on ln and passes connections +// through to either socksListener or httpListener, depending the +// first byte sent by the client. +func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { + sl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + hl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + + go splitSOCKSAndHTTPListener(ln, sl, hl) + + return sl, hl +} + +func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { + for { + conn, err := ln.Accept() + if err != nil { + sl.Close() + hl.Close() + return + } + go routeConn(conn, sl, hl) + } +} + +func routeConn(c net.Conn, socksListener, httpListener *listener) { + if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { + c.Close() + return + } + + var b [1]byte + if _, err := io.ReadFull(c, b[:]); err != nil { + c.Close() + return + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + c.Close() + return + } + + conn := &connWithOneByte{ + Conn: c, + b: b[0], + } + + // First byte of a SOCKS5 session is a version byte set to 5. + var ln *listener + if b[0] == 5 { + ln = socksListener + } else { + ln = httpListener + } + select { + case ln.c <- conn: + case <-ln.closed: + c.Close() + } +} + +type listener struct { + addr net.Addr + c chan net.Conn + mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. + closed chan struct{} +} + +func (ln *listener) Accept() (net.Conn, error) { + // Once closed, reliably stay closed, don't race with attempts at + // further connections. + select { + case <-ln.closed: + return nil, net.ErrClosed + default: + } + select { + case ret := <-ln.c: + return ret, nil + case <-ln.closed: + return nil, net.ErrClosed + } +} + +func (ln *listener) Close() error { + ln.mu.Lock() + defer ln.mu.Unlock() + select { + case <-ln.closed: + // Already closed + default: + close(ln.closed) + } + return nil +} + +func (ln *listener) Addr() net.Addr { + return ln.addr +} + +// connWithOneByte is a net.Conn that returns b for the first read +// request, then forwards everything else to Conn. +type connWithOneByte struct { + net.Conn + + b byte + bRead bool +} + +func (c *connWithOneByte) Read(bs []byte) (int, error) { + if c.bRead { + return c.Conn.Read(bs) + } + if len(bs) == 0 { + return 0, nil + } + c.bRead = true + bs[0] = c.b + return 1, nil +} diff --git a/net/proxymux/mux_test.go b/net/proxymux/mux_test.go new file mode 100644 index 000000000..fd37eb136 --- /dev/null +++ b/net/proxymux/mux_test.go @@ -0,0 +1,172 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxymux + +import ( + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "testing" + + "tailscale.com/net/socks5" +) + +func TestSplitSOCKSAndHTTP(t *testing.T) { + s := mkWorld(t) + defer s.Close() + + s.checkURL(s.httpClient, false) + s.checkURL(s.socksClient, false) +} + +func TestSplitSOCKSAndHTTPCloseSocks(t *testing.T) { + s := mkWorld(t) + defer s.Close() + + s.socksListener.Close() + s.checkURL(s.httpClient, false) + s.checkURL(s.socksClient, true) +} + +func TestSplitSOCKSAndHTTPCloseHTTP(t *testing.T) { + s := mkWorld(t) + defer s.Close() + + s.httpListener.Close() + s.checkURL(s.httpClient, true) + s.checkURL(s.socksClient, false) +} + +func TestSplitSOCKSAndHTTPCloseBoth(t *testing.T) { + s := mkWorld(t) + defer s.Close() + + s.httpListener.Close() + s.socksListener.Close() + s.checkURL(s.httpClient, true) + s.checkURL(s.socksClient, true) +} + +type world struct { + t *testing.T + + // targetListener/target is the HTTP server the client wants to + // reach. It unconditionally responds with HTTP 418 "I'm a + // teapot". + targetListener net.Listener + target http.Server + targetURL string + + // httpListener/httpProxy is an HTTP proxy that can proxy to + // target. + httpListener net.Listener + httpProxy http.Server + + // socksListener/socksProxy is a SOCKS5 proxy that can dial + // targetListener. + socksListener net.Listener + socksProxy *socks5.Server + + // jointListener is the mux that serves both HTTP and SOCKS5 + // proxying. + jointListener net.Listener + + // httpClient and socksClient are HTTP clients configured to proxy + // through httpProxy and socksProxy respectively. + httpClient *http.Client + socksClient *http.Client +} + +func (s *world) checkURL(c *http.Client, wantErr bool) { + s.t.Helper() + resp, err := c.Get(s.targetURL) + if wantErr { + if err == nil { + s.t.Errorf("HTTP request succeeded unexpectedly: got HTTP code %d, wanted failure", resp.StatusCode) + } + } else if err != nil { + s.t.Errorf("HTTP request failed: %v", err) + } else if c := resp.StatusCode; c != http.StatusTeapot { + s.t.Errorf("unexpected status code: got %d, want %d", c, http.StatusTeapot) + } +} + +func (s *world) Close() { + s.jointListener.Close() + s.socksListener.Close() + s.httpProxy.Close() + s.httpListener.Close() + s.target.Close() + s.targetListener.Close() +} + +func mkWorld(t *testing.T) (ret *world) { + t.Helper() + + ret = &world{ + t: t, + } + var err error + + ret.targetListener, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + ret.target = http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + }), + } + go ret.target.Serve(ret.targetListener) + ret.targetURL = fmt.Sprintf("http://%s/", ret.targetListener.Addr().String()) + + ret.jointListener, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + ret.socksListener, ret.httpListener = SplitSOCKSAndHTTP(ret.jointListener) + + httpProxy := http.Server{ + Handler: httputil.NewSingleHostReverseProxy(&url.URL{ + Scheme: "http", + Host: ret.targetListener.Addr().String(), + Path: "/", + }), + } + go httpProxy.Serve(ret.httpListener) + + socksProxy := socks5.Server{} + go socksProxy.Serve(ret.socksListener) + + ret.httpClient = &http.Client{ + Transport: &http.Transport{ + Proxy: func(*http.Request) (*url.URL, error) { + return &url.URL{ + Scheme: "http", + Host: ret.jointListener.Addr().String(), + Path: "/", + }, nil + }, + DisableKeepAlives: true, // one connection per request + }, + } + + ret.socksClient = &http.Client{ + Transport: &http.Transport{ + Proxy: func(*http.Request) (*url.URL, error) { + return &url.URL{ + Scheme: "socks5", + Host: ret.jointListener.Addr().String(), + Path: "/", + }, nil + }, + DisableKeepAlives: true, // one connection per request + }, + } + + return ret +}