diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 21fb5fdd3..015d5ae96 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -398,6 +398,11 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) return nil } +func getServeHTTPContext(r *http.Request) (c *serveHTTPContext, ok bool) { + c, ok = r.Context().Value(serveHTTPContextKey{}).(*serveHTTPContext) + return c, ok +} + func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) { var z ipn.HTTPHandlerView // zero value @@ -405,7 +410,7 @@ func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, return z, "", false } - sctx, ok := r.Context().Value(serveHTTPContextKey{}).(*serveHTTPContext) + sctx, ok := getServeHTTPContext(r) if !ok { b.logf("[unexpected] localbackend: no serveHTTPContext in request") return z, "", false @@ -446,11 +451,8 @@ func (b *LocalBackend) proxyHandlerForBackend(backend string) (*httputil.Reverse Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(u) r.Out.Host = r.In.Host - r.Out.Header.Set("X-Forwarded-Host", r.In.Host) - r.Out.Header.Set("X-Forwarded-Proto", "https") - if c, ok := r.Out.Context().Value(serveHTTPContextKey{}).(*serveHTTPContext); ok { - r.Out.Header.Set("X-Forwarded-For", c.SrcAddr.Addr().String()) - } + addProxyForwardedHeaders(r) + b.addTailscaleIdentityHeaders(r) }, Transport: &http.Transport{ DialContext: b.dialer.SystemDial, @@ -468,6 +470,36 @@ func (b *LocalBackend) proxyHandlerForBackend(backend string) (*httputil.Reverse return rp, nil } +func addProxyForwardedHeaders(r *httputil.ProxyRequest) { + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + r.Out.Header.Set("X-Forwarded-Proto", "https") + if c, ok := getServeHTTPContext(r.Out); ok { + r.Out.Header.Set("X-Forwarded-For", c.SrcAddr.Addr().String()) + } +} + +func (b *LocalBackend) addTailscaleIdentityHeaders(r *httputil.ProxyRequest) { + // Clear any incoming values squatting in the headers. + r.Out.Header.Del("Tailscale-User-Login") + r.Out.Header.Del("Tailscale-User-Name") + + c, ok := getServeHTTPContext(r.Out) + if !ok { + return + } + node, user, ok := b.WhoIs(c.SrcAddr) + if !ok { + return // traffic from outside of Tailnet (funneled) + } + if node.IsTagged() { + // 2023-06-14: Not setting identity headers for tagged nodes. + // Only currently set for nodes with user identities. + return + } + r.Out.Header.Set("Tailscale-User-Login", user.LoginName) + r.Out.Header.Set("Tailscale-User-Name", user.DisplayName) +} + func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { h, mountPoint, ok := b.getServeHandler(r) if !ok { diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index b78e5c63f..362803fcf 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -10,13 +10,22 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/netip" "net/url" "os" "path/filepath" + "strings" "testing" "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/types/logid" + "tailscale.com/types/netmap" "tailscale.com/util/cmpx" + "tailscale.com/util/must" + "tailscale.com/wgengine" ) func TestExpandProxyArg(t *testing.T) { @@ -160,6 +169,139 @@ func TestGetServeHandler(t *testing.T) { } } +func TestServeHTTPProxy(t *testing.T) { + sys := &tsd.System{} + e, err := wgengine.NewUserspaceEngine(t.Logf, wgengine.Config{SetSubsystem: sys.Set}) + if err != nil { + t.Fatal(err) + } + sys.Set(e) + sys.Set(new(mem.Store)) + b, err := NewLocalBackend(t.Logf, logid.PublicID{}, sys, 0) + if err != nil { + t.Fatal(err) + } + defer b.Shutdown() + dir := t.TempDir() + b.SetVarRoot(dir) + + pm := must.Get(newProfileManager(new(mem.Store), t.Logf)) + pm.currentProfile = &ipn.LoginProfile{ID: "id0"} + b.pm = pm + + b.netMap = &netmap.NetworkMap{ + SelfNode: &tailcfg.Node{ + Name: "example.ts.net", + }, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfile{ + tailcfg.UserID(1): { + LoginName: "someone@example.com", + DisplayName: "Some One", + }, + }, + } + b.nodeByAddr = map[netip.Addr]*tailcfg.Node{ + netip.MustParseAddr("100.150.151.152"): { + ComputedName: "some-peer", + User: tailcfg.UserID(1), + }, + netip.MustParseAddr("100.150.151.153"): { + ComputedName: "some-tagged-peer", + Tags: []string{"tag:server", "tag:test"}, + User: tailcfg.UserID(1), + }, + } + + // Start test serve endpoint. + testServ := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Piping all the headers through the response writer + // so we can check their values in tests below. + for key, val := range r.Header { + w.Header().Add(key, strings.Join(val, ",")) + } + }, + )) + defer testServ.Close() + + conf := &ipn.ServeConfig{ + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "example.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: testServ.URL}, + }}, + }, + } + if err := b.SetServeConfig(conf); err != nil { + t.Fatal(err) + } + + type headerCheck struct { + header string + want string + } + + tests := []struct { + name string + srcIP string + wantHeaders []headerCheck + }{ + { + name: "request-from-user-within-tailnet", + srcIP: "100.150.151.152", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.150.151.152"}, + {"Tailscale-User-Login", "someone@example.com"}, + {"Tailscale-User-Name", "Some One"}, + }, + }, + { + name: "request-from-tagged-node-within-tailnet", + srcIP: "100.150.151.153", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.150.151.153"}, + {"Tailscale-User-Login", ""}, + {"Tailscale-User-Name", ""}, + }, + }, + { + name: "request-from-outside-tailnet", + srcIP: "100.160.161.162", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.160.161.162"}, + {"Tailscale-User-Login", ""}, + {"Tailscale-User-Name", ""}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Path: "/"}, + TLS: &tls.ConnectionState{ServerName: "example.ts.net"}, + } + req = req.WithContext(context.WithValue(req.Context(), serveHTTPContextKey{}, &serveHTTPContext{ + DestPort: 443, + SrcAddr: netip.MustParseAddrPort(tt.srcIP + ":1234"), // random src port for tests + })) + + w := httptest.NewRecorder() + b.serveWebHandler(w, req) + + // Verify the headers. + h := w.Result().Header + for _, c := range tt.wantHeaders { + if got := h.Get(c.header); got != c.want { + t.Errorf("invalid %q header; want=%q, got=%q", c.header, c.want, got) + } + } + }) + } +} + func TestServeFileOrDirectory(t *testing.T) { td := t.TempDir() writeFile := func(suffix, contents string) {