From 2c956e30bea76678e7c2ec1204f2be398a64e94d Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 29 Sep 2025 17:57:04 -0700 Subject: [PATCH] ipn/ipnlocal: proxy h2c grpc using net/http.Transport instead of x/net/http2 (Kinda related: #17351) Updates #17305 Change-Id: I47df2612732a5713577164e74652bc9fa3cd14b3 Signed-off-by: Brad Fitzpatrick --- ipn/ipnlocal/serve.go | 22 +++++----- ipn/ipnlocal/serve_test.go | 88 +++++++++++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 12 deletions(-) diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index dc4142404..3c967fd1e 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -34,7 +34,6 @@ import ( "unicode/utf8" "go4.org/mem" - "golang.org/x/net/http2" "tailscale.com/ipn" "tailscale.com/net/netutil" "tailscale.com/syncs" @@ -761,8 +760,8 @@ type reverseProxy struct { insecure bool backend string lb *LocalBackend - httpTransport lazy.SyncValue[*http.Transport] // transport for non-h2c backends - h2cTransport lazy.SyncValue[*http2.Transport] // transport for h2c backends + httpTransport lazy.SyncValue[*http.Transport] // transport for non-h2c backends + h2cTransport lazy.SyncValue[*http.Transport] // transport for h2c backends // closed tracks whether proxy is closed/currently closing. closed atomic.Bool } @@ -770,9 +769,7 @@ type reverseProxy struct { // close ensures that any open backend connections get closed. func (rp *reverseProxy) close() { rp.closed.Store(true) - if h2cT := rp.h2cTransport.Get(func() *http2.Transport { - return nil - }); h2cT != nil { + if h2cT := rp.h2cTransport.Get(func() *http.Transport { return nil }); h2cT != nil { h2cT.CloseIdleConnections() } if httpTransport := rp.httpTransport.Get(func() *http.Transport { @@ -843,14 +840,17 @@ func (rp *reverseProxy) getTransport() *http.Transport { // getH2CTransport returns the Transport used for GRPC requests to the backend. // The Transport gets created lazily, at most once. -func (rp *reverseProxy) getH2CTransport() *http2.Transport { - return rp.h2cTransport.Get(func() *http2.Transport { - return &http2.Transport{ - AllowHTTP: true, - DialTLSContext: func(ctx context.Context, network string, addr string, _ *tls.Config) (net.Conn, error) { +func (rp *reverseProxy) getH2CTransport() http.RoundTripper { + return rp.h2cTransport.Get(func() *http.Transport { + var p http.Protocols + p.SetUnencryptedHTTP2(true) + tr := &http.Transport{ + Protocols: &p, + DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { return rp.lb.dialer.SystemDial(ctx, "tcp", rp.url.Host) }, } + return tr }) } diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index a081ed27b..b4461d12f 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -15,6 +15,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/http/httptest" "net/netip" @@ -881,7 +882,7 @@ func mustCreateURL(t *testing.T, u string) url.URL { func newTestBackend(t *testing.T, opts ...any) *LocalBackend { var logf logger.Logf = logger.Discard - const debug = true + const debug = false if debug { logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... ") } @@ -1085,3 +1086,88 @@ func TestEncTailscaleHeaderValue(t *testing.T) { } } } + +func TestServeGRPCProxy(t *testing.T) { + const msg = "some-response\n" + backend := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Path-Was", r.RequestURI) + w.Header().Set("Proto-Was", r.Proto) + io.WriteString(w, msg) + })) + backend.EnableHTTP2 = true + backend.Config.Protocols = new(http.Protocols) + backend.Config.Protocols.SetHTTP1(true) + backend.Config.Protocols.SetUnencryptedHTTP2(true) + backend.Start() + defer backend.Close() + + backendURL := must.Get(url.Parse(backend.URL)) + + lb := newTestBackend(t) + rp := &reverseProxy{ + logf: t.Logf, + url: backendURL, + backend: backend.URL, + lb: lb, + } + + req := func(method, urlStr string, opt ...any) *http.Request { + req := httptest.NewRequest(method, urlStr, nil) + for _, o := range opt { + switch v := o.(type) { + case int: + req.ProtoMajor = v + case string: + req.Header.Set("Content-Type", v) + default: + panic(fmt.Sprintf("unsupported option type %T", v)) + } + } + return req + } + + tests := []struct { + name string + req *http.Request + wantPath string + wantProto string + wantBody string + }{ + { + name: "non-gRPC", + req: req("GET", "http://foo/bar"), + wantPath: "/bar", + wantProto: "HTTP/1.1", + }, + { + name: "gRPC-but-not-http2", + req: req("GET", "http://foo/bar", "application/grpc"), + wantPath: "/bar", + wantProto: "HTTP/1.1", + }, + { + name: "gRPC--http2", + req: req("GET", "http://foo/bar", 2, "application/grpc"), + wantPath: "/bar", + wantProto: "HTTP/2.0", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + rp.ServeHTTP(rec, tt.req) + + res := rec.Result() + got := must.Get(io.ReadAll(res.Body)) + if got, want := res.Header.Get("Path-Was"), tt.wantPath; want != got { + t.Errorf("Path-Was %q, want %q", got, want) + } + if got, want := res.Header.Get("Proto-Was"), tt.wantProto; want != got { + t.Errorf("Proto-Was %q, want %q", got, want) + } + if string(got) != msg { + t.Errorf("got body %q, want %q", got, msg) + } + }) + } +}