diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go new file mode 100644 index 000000000..bf07a1a29 --- /dev/null +++ b/control/controlhttp/client.go @@ -0,0 +1,242 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package controlhttp implements the Tailscale 2021 control protocol +// base transport over HTTP. +// +// This tunnels the protocol in control/controlbase over HTTP with a +// variety of compatibility fallbacks for handling picky or deep +// inspecting proxies. +// +// In the happy path, a client makes a single cleartext HTTP request +// to the server, the server responds with 101 Switching Protocols, +// and the control base protocol takes place over plain TCP. +// +// In the compatibility path, the client does the above over HTTPS, +// resulting in double encryption (once for the control transport, and +// once for the outer TLS layer). +package controlhttp + +import ( + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "net/http/httptrace" + "net/url" + + "tailscale.com/control/controlbase" + "tailscale.com/net/dnscache" + "tailscale.com/net/dnsfallback" + "tailscale.com/net/netns" + "tailscale.com/net/tlsdial" + "tailscale.com/net/tshttpproxy" + "tailscale.com/types/key" +) + +// upgradeHeader is the value of the Upgrade HTTP header used to +// indicate the Tailscale control protocol. +const ( + upgradeHeaderValue = "tailscale-control-protocol" + handshakeHeaderName = "X-Tailscale-Handshake" +) + +// Dial connects to the HTTP server at addr, requests to switch to the +// Tailscale control protocol, and returns an established control +// protocol connection. +// +// If Dial fails to connect using addr, it also tries to tunnel over +// TLS to :443 as a compatibility fallback. +func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*controlbase.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + a := &dialParams{ + ctx: ctx, + host: host, + httpPort: port, + httpsPort: "443", + machineKey: machineKey, + controlKey: controlKey, + proxyFunc: tshttpproxy.ProxyFromEnvironment, + } + return a.dial() +} + +type dialParams struct { + ctx context.Context + host string + httpPort string + httpsPort string + machineKey key.MachinePrivate + controlKey key.MachinePublic + proxyFunc func(*http.Request) (*url.URL, error) // or nil + + // For tests only + insecureTLS bool +} + +func (a *dialParams) dial() (*controlbase.Conn, error) { + init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey) + if err != nil { + return nil, err + } + + u := &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(a.host, a.httpPort), + Path: "/switch", + } + conn, httpErr := a.tryURL(u, init) + if httpErr == nil { + ret, err := cont(a.ctx, conn) + if err != nil { + conn.Close() + return nil, err + } + return ret, nil + } + + // Connecting over plain HTTP failed, assume it's an HTTP proxy + // being difficult and see if we can get through over HTTPS. + u.Scheme = "https" + u.Host = net.JoinHostPort(a.host, a.httpsPort) + init, cont, err = controlbase.ClientDeferred(a.machineKey, a.controlKey) + if err != nil { + return nil, err + } + conn, tlsErr := a.tryURL(u, init) + if tlsErr == nil { + ret, err := cont(a.ctx, conn) + if err != nil { + conn.Close() + return nil, err + } + return ret, nil + } + + return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", httpErr, tlsErr) +} + +func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) { + dns := &dnscache.Resolver{ + Forward: dnscache.Get().Forward, + LookupIPFallback: dnsfallback.Lookup, + UseLastGood: true, + } + dialer := netns.NewDialer(log.Printf) + tr := http.DefaultTransport.(*http.Transport).Clone() + defer tr.CloseIdleConnections() + tr.Proxy = a.proxyFunc + tshttpproxy.SetTransportGetProxyConnectHeader(tr) + tr.DialContext = dnscache.Dialer(dialer.DialContext, dns) + // Disable HTTP2, since h2 can't do protocol switching. + tr.TLSClientConfig.NextProtos = []string{} + tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} + tr.TLSClientConfig = tlsdial.Config(a.host, tr.TLSClientConfig) + if a.insecureTLS { + tr.TLSClientConfig.InsecureSkipVerify = true + tr.TLSClientConfig.VerifyConnection = nil + } + tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dns, tr.TLSClientConfig) + tr.DisableCompression = true + + // (mis)use httptrace to extract the underlying net.Conn from the + // transport. We make exactly 1 request using this transport, so + // there will be exactly 1 GotConn call. Additionally, the + // transport handles 101 Switching Protocols correctly, such that + // the Conn will not be reused or kept alive by the transport once + // the response has been handed back from RoundTrip. + // + // In theory, the machinery of net/http should make it such that + // the trace callback happens-before we get the response, but + // there's no promise of that. So, to make sure, we use a buffered + // channel as a synchronization step to avoid data races. + // + // Note that even though we're able to extract a net.Conn via this + // mechanism, we must still keep using the eventual resp.Body to + // read from, because it includes a buffer we can't get rid of. If + // the server never sends any data after sending the HTTP + // response, we could get away with it, but violating this + // assumption leads to very mysterious transport errors (lockups, + // unexpected EOFs...), and we're bound to forget someday and + // introduce a protocol optimization at a higher level that starts + // eagerly transmitting from the server. + connCh := make(chan net.Conn, 1) + trace := httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + connCh <- info.Conn + }, + } + ctx := httptrace.WithClientTrace(a.ctx, &trace) + req := &http.Request{ + Method: "POST", + URL: u, + Header: http.Header{ + "Upgrade": []string{upgradeHeaderValue}, + "Connection": []string{"upgrade"}, + handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, + }, + } + req = req.WithContext(ctx) + + resp, err := tr.RoundTrip(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status) + } + + // From here on, the underlying net.Conn is ours to use, but there + // is still a read buffer attached to it within resp.Body. So, we + // must direct I/O through resp.Body, but we can still use the + // underlying net.Conn for stuff like deadlines. + var switchedConn net.Conn + select { + case switchedConn = <-connCh: + default: + } + if switchedConn == nil { + resp.Body.Close() + return nil, fmt.Errorf("httptrace didn't provide a connection") + } + + if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue { + resp.Body.Close() + return nil, fmt.Errorf("server switched to unexpected protocol %q", next) + } + + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + resp.Body.Close() + return nil, errors.New("http Transport did not provide a writable body") + } + + return &wrappedConn{switchedConn, rwc}, nil +} + +type wrappedConn struct { + net.Conn + rwc io.ReadWriteCloser +} + +func (w *wrappedConn) Read(bs []byte) (int, error) { + return w.rwc.Read(bs) +} + +func (w *wrappedConn) Write(bs []byte) (int, error) { + return w.rwc.Write(bs) +} + +func (w *wrappedConn) Close() error { + return w.rwc.Close() +} diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go new file mode 100644 index 000000000..799eb1b19 --- /dev/null +++ b/control/controlhttp/http_test.go @@ -0,0 +1,398 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlhttp + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "log" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "sync" + "testing" + + "tailscale.com/control/controlbase" + "tailscale.com/net/socks5" + "tailscale.com/types/key" +) + +func TestControlHTTP(t *testing.T) { + tests := []struct { + name string + proxy proxy + }{ + // direct connection + { + name: "no_proxy", + proxy: nil, + }, + // SOCKS5 + { + name: "socks5", + proxy: &socksProxy{}, + }, + // HTTP->HTTP + { + name: "http_to_http", + proxy: &httpProxy{ + useTLS: false, + allowConnect: false, + allowHTTP: true, + }, + }, + // HTTP->HTTPS + { + name: "http_to_https", + proxy: &httpProxy{ + useTLS: false, + allowConnect: true, + allowHTTP: false, + }, + }, + // HTTP->any (will pick HTTP) + { + name: "http_to_any", + proxy: &httpProxy{ + useTLS: false, + allowConnect: true, + allowHTTP: true, + }, + }, + // HTTPS->HTTP + { + name: "https_to_http", + proxy: &httpProxy{ + useTLS: true, + allowConnect: false, + allowHTTP: true, + }, + }, + // HTTPS->HTTPS + { + name: "https_to_https", + proxy: &httpProxy{ + useTLS: true, + allowConnect: true, + allowHTTP: false, + }, + }, + // HTTPS->any (will pick HTTP) + { + name: "https_to_any", + proxy: &httpProxy{ + useTLS: true, + allowConnect: true, + allowHTTP: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testControlHTTP(t, test.proxy) + }) + } +} + +func testControlHTTP(t *testing.T, proxy proxy) { + client, server := key.NewMachine(), key.NewMachine() + + sch := make(chan serverResult, 1) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := AcceptHTTP(context.Background(), w, r, server) + if err != nil { + log.Print(err) + } + res := serverResult{ + err: err, + } + if conn != nil { + res.clientAddr = conn.RemoteAddr().String() + res.version = conn.ProtocolVersion() + res.peer = conn.Peer() + res.conn = conn + } + sch <- res + }) + + httpLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("HTTP listen: %v", err) + } + httpsLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("HTTPS listen: %v", err) + } + + httpServer := &http.Server{Handler: handler} + go httpServer.Serve(httpLn) + defer httpServer.Close() + + httpsServer := &http.Server{ + Handler: handler, + TLSConfig: tlsConfig(t), + } + go httpsServer.ServeTLS(httpsLn, "", "") + defer httpsServer.Close() + + //ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + //defer cancel() + + a := dialParams{ + ctx: context.Background(), //ctx, + host: "localhost", + httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), + httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), + machineKey: client, + controlKey: server.Public(), + insecureTLS: true, + } + + if proxy != nil { + proxyEnv := proxy.Start(t) + defer proxy.Close() + proxyURL, err := url.Parse(proxyEnv) + if err != nil { + t.Fatal(err) + } + a.proxyFunc = func(*http.Request) (*url.URL, error) { + return proxyURL, nil + } + } else { + a.proxyFunc = func(*http.Request) (*url.URL, error) { + return nil, nil + } + } + + conn, err := a.dial() + if err != nil { + t.Fatalf("dialing controlhttp: %v", err) + } + defer conn.Close() + si := <-sch + if si.conn != nil { + defer si.conn.Close() + } + if si.err != nil { + t.Fatalf("controlhttp server got error: %v", err) + } + if clientVersion := conn.ProtocolVersion(); si.version != clientVersion { + t.Fatalf("client and server don't agree on protocol version: %d vs %d", clientVersion, si.version) + } + if si.peer != client.Public() { + t.Fatalf("server got peer pubkey %s, want %s", si.peer, client.Public()) + } + if spub := conn.Peer(); spub != server.Public() { + t.Fatalf("client got peer pubkey %s, want %s", spub, server.Public()) + } + if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) { + t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr) + } +} + +type serverResult struct { + err error + clientAddr string + version int + peer key.MachinePublic + conn *controlbase.Conn +} + +type proxy interface { + Start(*testing.T) string + Close() + ConnIsFromProxy(string) bool +} + +type socksProxy struct { + sync.Mutex + proxy socks5.Server + ln net.Listener + clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy +} + +func (s *socksProxy) Start(t *testing.T) (url string) { + t.Helper() + s.Lock() + defer s.Unlock() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for SOCKS server: %v", err) + } + s.ln = ln + s.clientConnAddrs = map[string]bool{} + s.proxy.Logf = t.Logf + s.proxy.Dialer = s.dialAndRecord + go s.proxy.Serve(ln) + return fmt.Sprintf("socks5://%s", ln.Addr().String()) +} + +func (s *socksProxy) Close() { + s.Lock() + defer s.Unlock() + s.ln.Close() +} + +func (s *socksProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + s.Lock() + defer s.Unlock() + s.clientConnAddrs[conn.LocalAddr().String()] = true + return conn, nil +} + +func (s *socksProxy) ConnIsFromProxy(addr string) bool { + s.Lock() + defer s.Unlock() + return s.clientConnAddrs[addr] +} + +type httpProxy struct { + useTLS bool // take incoming connections over TLS + allowConnect bool // allow CONNECT for TLS + allowHTTP bool // allow plain HTTP proxying + + sync.Mutex + ln net.Listener + rp httputil.ReverseProxy + s http.Server + clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy +} + +func (h *httpProxy) Start(t *testing.T) (url string) { + t.Helper() + h.Lock() + defer h.Unlock() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for HTTP proxy: %v", err) + } + h.ln = ln + h.rp = httputil.ReverseProxy{ + Director: func(*http.Request) {}, + Transport: &http.Transport{ + DialContext: h.dialAndRecord, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{}, + }, + } + h.clientConnAddrs = map[string]bool{} + h.s.Handler = h + if h.useTLS { + h.s.TLSConfig = tlsConfig(t) + go h.s.ServeTLS(h.ln, "", "") + return fmt.Sprintf("https://%s", ln.Addr().String()) + } else { + go h.s.Serve(h.ln) + return fmt.Sprintf("http://%s", ln.Addr().String()) + } +} + +func (h *httpProxy) Close() { + h.Lock() + defer h.Unlock() + h.s.Close() +} + +func (h *httpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + if !h.allowHTTP { + http.Error(w, "http proxy not allowed", 500) + return + } + h.rp.ServeHTTP(w, r) + return + } + + if !h.allowConnect { + http.Error(w, "connect not allowed", 500) + return + } + + dst := r.RequestURI + c, err := h.dialAndRecord(context.Background(), "tcp", dst) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + defer c.Close() + + cc, ccbuf, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), 500) + return + } + defer cc.Close() + + io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n") + + errc := make(chan error, 1) + go func() { + _, err := io.Copy(cc, c) + errc <- err + }() + go func() { + _, err := io.Copy(c, ccbuf) + errc <- err + }() + <-errc +} + +func (h *httpProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + h.Lock() + defer h.Unlock() + h.clientConnAddrs[conn.LocalAddr().String()] = true + return conn, nil +} + +func (h *httpProxy) ConnIsFromProxy(addr string) bool { + h.Lock() + defer h.Unlock() + return h.clientConnAddrs[addr] +} + +func tlsConfig(t *testing.T) *tls.Config { + // Cert and key taken from the example code in the crypto/tls + // package. + certPem := []byte(`-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----`) + keyPem := []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----`) + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + t.Fatal(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + } +} diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go new file mode 100644 index 000000000..92bd9ec9b --- /dev/null +++ b/control/controlhttp/server.go @@ -0,0 +1,95 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlhttp + +import ( + "bufio" + "context" + "encoding/base64" + "errors" + "fmt" + "net" + "net/http" + + "tailscale.com/control/controlbase" + "tailscale.com/types/key" +) + +// AcceptHTTP upgrades the HTTP request given by w and r into a +// Tailscale control protocol base transport connection. +// +// AcceptHTTP always writes an HTTP response to w. The caller must not +// attempt their own response after calling AcceptHTTP. +func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { + next := r.Header.Get("Upgrade") + if next == "" { + http.Error(w, "missing next protocol", http.StatusBadRequest) + return nil, errors.New("no next protocol in HTTP request") + } + if next != upgradeHeaderValue { + http.Error(w, "unknown next protocol", http.StatusBadRequest) + return nil, fmt.Errorf("client requested unhandled next protocol %q", next) + } + + initB64 := r.Header.Get(handshakeHeaderName) + if initB64 == "" { + http.Error(w, "missing Tailscale handshake header", http.StatusBadRequest) + return nil, errors.New("no tailscale handshake header in HTTP request") + } + init, err := base64.StdEncoding.DecodeString(initB64) + if err != nil { + http.Error(w, "invalid tailscale handshake header", http.StatusBadRequest) + return nil, fmt.Errorf("decoding base64 handshake header: %v", err) + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "make request over HTTP/1", http.StatusBadRequest) + return nil, errors.New("can't hijack client connection") + } + + w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Connection", "upgrade") + w.WriteHeader(http.StatusSwitchingProtocols) + + conn, brw, err := hijacker.Hijack() + if err != nil { + return nil, fmt.Errorf("hijacking client connection: %w", err) + } + if err := brw.Flush(); err != nil { + conn.Close() + return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err) + } + if brw.Reader.Buffered() > 0 { + conn = &drainBufConn{conn, brw.Reader} + } + + nc, err := controlbase.Server(ctx, conn, private, init) + if err != nil { + conn.Close() + return nil, fmt.Errorf("noise handshake failed: %w", err) + } + + return nc, nil +} + +// drainBufConn is a net.Conn with an initial bunch of bytes in a +// bufio.Reader. Read drains the bufio.Reader until empty, then passes +// through subsequent reads to the Conn directly. +type drainBufConn struct { + net.Conn + r *bufio.Reader +} + +func (b *drainBufConn) Read(bs []byte) (int, error) { + if b.r == nil { + return b.Conn.Read(bs) + } + n, err := b.r.Read(bs) + if b.r.Buffered() == 0 { + b.r = nil + } + return n, err +}