diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 394ffbf6c..dbde63aa7 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -100,6 +100,8 @@ type Server struct { // If empty, the Tailscale default is used. ControlURL string + getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error) + initOnce sync.Once initErr error lb *ipnlocal.LocalBackend @@ -842,18 +844,28 @@ func (s *Server) ListenTLS(network, addr string) (net.Listener, error) { return nil, errors.New("tsnet: you must enable HTTPS in the admin panel to proceed. See https://tailscale.com/s/https") } - lc, err := s.LocalClient() // do local client first before listening. + ln, err := s.listen(network, addr, listenOnTailnet) if err != nil { return nil, err } + return tls.NewListener(ln, &tls.Config{ + GetCertificate: s.getCert, + }), nil +} - ln, err := s.listen(network, addr, listenOnTailnet) +// getCert is the GetCertificate function used by ListenTLS. +// +// It calls GetCertificate on the localClient, passing in the ClientHelloInfo. +// For testing, if s.getCertForTesting is set, it will call that instead. +func (s *Server) getCert(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + if s.getCertForTesting != nil { + return s.getCertForTesting(hi) + } + lc, err := s.LocalClient() if err != nil { return nil, err } - return tls.NewListener(ln, &tls.Config{ - GetCertificate: lc.GetCertificate, - }), nil + return lc.GetCertificate(hi) } // FunnelOption is an option passed to ListenFunnel to configure the listener. @@ -909,10 +921,7 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L return nil, err } - lc, err := s.LocalClient() - if err != nil { - return nil, err - } + lc := s.localClient // May not have funnel enabled. Enable it. srvConfig, err := lc.GetServeConfig(ctx) @@ -944,7 +953,7 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L return nil, err } return tls.NewListener(ln, &tls.Config{ - GetCertificate: lc.GetCertificate, + GetCertificate: s.getCert, }), nil } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 0dab542f0..e170de5e0 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -4,27 +4,40 @@ package tsnet import ( + "bufio" "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "errors" "flag" "fmt" "io" + "io/ioutil" + "math/big" "net" "net/http" "net/http/httptest" "net/netip" "os" "path/filepath" + "strings" + "sync" "testing" "time" "golang.org/x/net/proxy" + "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" "tailscale.com/tailcfg" "tailscale.com/tstest/integration" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/logger" + "tailscale.com/util/must" ) // TestListener_Server ensures that the listener type always keeps the Server @@ -93,6 +106,10 @@ func startControl(t *testing.T) (controlURL string) { derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1") control := &testcontrol.Server{ DERPMap: derpMap, + DNSConfig: &tailcfg.DNSConfig{ + Proxied: true, + }, + MagicDNSDomain: "tail-scale.ts.net", } control.HTTPTestServer = httptest.NewUnstartedServer(control) control.HTTPTestServer.Start() @@ -102,17 +119,96 @@ func startControl(t *testing.T) (controlURL string) { return controlURL } +type testCertIssuer struct { + mu sync.Mutex + certs map[string]*tls.Certificate + + root *x509.Certificate + rootKey *ecdsa.PrivateKey +} + +func newCertIssuer() *testCertIssuer { + rootKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + t := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "root", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + rootDER, err := x509.CreateCertificate(rand.Reader, t, t, &rootKey.PublicKey, rootKey) + if err != nil { + panic(err) + } + rootCA, err := x509.ParseCertificate(rootDER) + if err != nil { + panic(err) + } + return &testCertIssuer{ + certs: make(map[string]*tls.Certificate), + root: rootCA, + rootKey: rootKey, + } +} + +func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + tci.mu.Lock() + defer tci.mu.Unlock() + cert, ok := tci.certs[chi.ServerName] + if ok { + return cert, nil + } + + certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + certTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{chi.ServerName}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + certDER, err := x509.CreateCertificate(rand.Reader, certTmpl, tci.root, &certPrivKey.PublicKey, tci.rootKey) + if err != nil { + return nil, err + } + cert = &tls.Certificate{ + Certificate: [][]byte{certDER, tci.root.Raw}, + PrivateKey: certPrivKey, + } + tci.certs[chi.ServerName] = cert + return cert, nil +} + +func (tci *testCertIssuer) Pool() *x509.CertPool { + p := x509.NewCertPool() + p.AddCert(tci.root) + return p +} + +var testCertRoot = newCertIssuer() + func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) (*Server, netip.Addr) { t.Helper() tmp := filepath.Join(t.TempDir(), hostname) os.MkdirAll(tmp, 0755) s := &Server{ - Dir: tmp, - ControlURL: controlURL, - Hostname: hostname, - Store: new(mem.Store), - Ephemeral: true, + Dir: tmp, + ControlURL: controlURL, + Hostname: hostname, + Store: new(mem.Store), + Ephemeral: true, + getCertForTesting: testCertRoot.getCert, } if !*verboseNodes { s.Logf = logger.Discard @@ -368,3 +464,112 @@ func TestListenerCleanup(t *testing.T) { t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err) } } + +func TestFunnel(t *testing.T) { + ctx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer dialCancel() + + controlURL := startControl(t) + s1, _ := startServer(t, ctx, controlURL, "s1") + s2, _ := startServer(t, ctx, controlURL, "s2") + + ln := must.Get(s1.ListenFunnel("tcp", ":443")) + defer ln.Close() + wantSrcAddrPort := netip.MustParseAddrPort("127.0.0.1:1234") + wantTarget := ipn.HostPort("s1.tail-scale.ts.net:443") + srv := &http.Server{ + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + tc, ok := c.(*tls.Conn) + if !ok { + t.Errorf("ConnContext called with non-TLS conn: %T", c) + } + if fc, ok := tc.NetConn().(*ipn.FunnelConn); !ok { + t.Errorf("ConnContext called with non-FunnelConn: %T", c) + } else if fc.Src != wantSrcAddrPort { + t.Errorf("ConnContext called with wrong SrcAddrPort; got %v, want %v", fc.Src, wantSrcAddrPort) + } else if fc.Target != wantTarget { + t.Errorf("ConnContext called with wrong Target; got %q, want %q", fc.Target, wantTarget) + } + return ctx + }, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "hello") + }), + } + go srv.Serve(ln) + + c := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialIngressConn(s2, s1, addr) + }, + TLSClientConfig: &tls.Config{ + RootCAs: testCertRoot.Pool(), + }, + }, + } + resp, err := c.Get("https://s1.tail-scale.ts.net:443") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("unexpected status code: %v", resp.StatusCode) + return + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != "hello" { + t.Errorf("unexpected body: %q", body) + } +} + +func dialIngressConn(from, to *Server, target string) (net.Conn, error) { + toLC := must.Get(to.LocalClient()) + toStatus := must.Get(toLC.StatusWithoutPeers(context.Background())) + peer6 := toStatus.Self.PeerAPIURL[1] // IPv6 + toPeerAPI, ok := strings.CutPrefix(peer6, "http://") + if !ok { + return nil, fmt.Errorf("unexpected PeerAPIURL %q", peer6) + } + + dialCtx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second) + outConn, err := from.Dial(dialCtx, "tcp", toPeerAPI) + dialCancel() + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", "/v0/ingress", nil) + if err != nil { + return nil, err + } + req.Host = toPeerAPI + req.Header.Set("Tailscale-Ingress-Src", "127.0.0.1:1234") + req.Header.Set("Tailscale-Ingress-Target", target) + if err := req.Write(outConn); err != nil { + return nil, err + } + + br := bufio.NewReader(outConn) + res, err := http.ReadResponse(br, req) + if err != nil { + return nil, err + } + defer res.Body.Close() // just to appease vet + if res.StatusCode != 101 { + return nil, fmt.Errorf("unexpected status code: %v", res.StatusCode) + } + return &bufferedConn{outConn, br}, nil +} + +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +func (c *bufferedConn) Read(b []byte) (int, error) { + return c.reader.Read(b) +} diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 2c724258c..dc1e6603a 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -26,6 +26,7 @@ import ( "github.com/klauspost/compress/zstd" "go4.org/mem" + "golang.org/x/exp/slices" "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" "tailscale.com/smallzstd" @@ -39,11 +40,12 @@ const msgLimit = 1 << 20 // encrypted message length limit // Server is a control plane server. Its zero value is ready for use. // Everything is stored in-memory in one tailnet. type Server struct { - Logf logger.Logf // nil means to use the log package - DERPMap *tailcfg.DERPMap // nil means to use prod DERP map - RequireAuth bool - Verbose bool - DNSConfig *tailcfg.DNSConfig // nil means no DNS config + Logf logger.Logf // nil means to use the log package + DERPMap *tailcfg.DERPMap // nil means to use prod DERP map + RequireAuth bool + Verbose bool + DNSConfig *tailcfg.DNSConfig // nil means no DNS config + MagicDNSDomain string // ExplicitBaseURL or HTTPTestServer must be set. ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL @@ -328,6 +330,15 @@ func (s *Server) AddFakeNode() { // TODO: send updates to other (non-fake?) nodes } +func (s *Server) AllUsers() (users []*tailcfg.User) { + s.mu.Lock() + defer s.mu.Unlock() + for _, u := range s.users { + users = append(users, u.Clone()) + } + return users +} + func (s *Server) AllNodes() (nodes []*tailcfg.Node) { s.mu.Lock() defer s.mu.Unlock() @@ -494,6 +505,11 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. Addresses: allowedIPs, AllowedIPs: allowedIPs, Hostinfo: req.Hostinfo.View(), + Name: req.Hostinfo.Hostname, + Capabilities: []string{ + tailcfg.NodeAttrFunnel, + tailcfg.CapabilityFunnelPorts + "?ports=8080,443", + }, } requireAuth := s.RequireAuth if requireAuth && s.nodeKeyAuthed[nk] { @@ -729,6 +745,20 @@ var keepAliveMsg = &struct { KeepAlive: true, } +func packetFilterWithIngressCaps() []tailcfg.FilterRule { + out := slices.Clone(tailcfg.FilterAllowAll) + out = append(out, tailcfg.FilterRule{ + SrcIPs: []string{"*"}, + CapGrant: []tailcfg.CapGrant{ + { + Dsts: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + Caps: []string{tailcfg.CapabilityIngress}, + }, + }, + }) + return out +} + // MapResponse generates a MapResponse for a MapRequest. // // No updates to s are done here. @@ -741,16 +771,24 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, } user, _ := s.getUser(nk) t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC) + dns := s.DNSConfig + if dns != nil && s.MagicDNSDomain != "" { + dns = dns.Clone() + dns.CertDomains = []string{ + fmt.Sprintf(node.Hostinfo.Hostname() + "." + s.MagicDNSDomain), + } + } + res = &tailcfg.MapResponse{ Node: node, DERPMap: s.DERPMap, Domain: string(user.Domain), CollectServices: "true", - PacketFilter: tailcfg.FilterAllowAll, + PacketFilter: packetFilterWithIngressCaps(), Debug: &tailcfg.Debug{ DisableUPnP: "true", }, - DNSConfig: s.DNSConfig, + DNSConfig: dns, ControlTime: &t, } for _, p := range s.AllNodes() { @@ -761,6 +799,13 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, sort.Slice(res.Peers, func(i, j int) bool { return res.Peers[i].ID < res.Peers[j].ID }) + for _, u := range s.AllUsers() { + res.UserProfiles = append(res.UserProfiles, tailcfg.UserProfile{ + ID: u.ID, + LoginName: u.LoginName, + DisplayName: u.DisplayName, + }) + } v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128)