diff --git a/ipn/ipnserver/proxyconnect.go b/ipn/ipnserver/proxyconnect.go index eb8c55991..8f330add1 100644 --- a/ipn/ipnserver/proxyconnect.go +++ b/ipn/ipnserver/proxyconnect.go @@ -37,8 +37,7 @@ func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) return } - tr := logpolicy.NewLogtailTransport(logHost) - back, err := tr.DialContext(ctx, "tcp", hostPort) + back, err := logpolicy.DialContext(ctx, "tcp", hostPort) if err != nil { s.logf("error CONNECT dialing %v: %v", hostPort, err) http.Error(w, "Connect failure", http.StatusBadGateway) diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index 9ccfa3d3b..9dd9593cc 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -667,11 +667,59 @@ func (p *Policy) Shutdown(ctx context.Context) error { return nil } -// NewLogtailTransport returns an HTTP Transport particularly suited to uploading -// logs to the given host name. This includes: -// - If DNS lookup fails, consult the bootstrap DNS list of Tailscale hostnames. +// DialContext is a net.Dialer.DialContext specialized for use by logtail. +// It does the following: +// - If DNS lookup fails, consults the bootstrap DNS list of Tailscale hostnames. // - If TLS connection fails, try again using LetsEncrypt's built-in root certificate, // for the benefit of older OS platforms which might not include it. +func DialContext(ctx context.Context, netw, addr string) (net.Conn, error) { + nd := netns.FromDialer(log.Printf, &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: netknob.PlatformTCPKeepAlive(), + }) + t0 := time.Now() + c, err := nd.DialContext(ctx, netw, addr) + d := time.Since(t0).Round(time.Millisecond) + if err == nil { + dialLog.Printf("dialed %q in %v", addr, d) + return c, nil + } + + if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") { + if c, err := safesocket.Connect(safesocket.DefaultConnectionStrategy("")); err == nil { + fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr) + br := bufio.NewReader(c) + res, err := http.ReadResponse(br, nil) + if err == nil && res.StatusCode != 200 { + err = errors.New(res.Status) + } + if err != nil { + log.Printf("logtail: CONNECT response error from tailscaled: %v", err) + c.Close() + } else { + dialLog.Printf("connected via tailscaled") + return c, nil + } + } + } + + // If we failed to dial, try again with bootstrap DNS. + log.Printf("logtail: dial %q failed: %v (in %v), trying bootstrap...", addr, err, d) + dnsCache := &dnscache.Resolver{ + Forward: dnscache.Get().Forward, // use default cache's forwarder + UseLastGood: true, + LookupIPFallback: dnsfallback.Lookup, + } + dialer := dnscache.Dialer(nd.DialContext, dnsCache) + c, err = dialer(ctx, netw, addr) + if err == nil { + log.Printf("logtail: bootstrap dial succeeded") + } + return c, err +} + +// NewLogtailTransport returns an HTTP Transport particularly suited to uploading +// logs to the given host name. See DialContext for details on how it works. func NewLogtailTransport(host string) *http.Transport { // Start with a copy of http.DefaultTransport and tweak it a bit. tr := http.DefaultTransport.(*http.Transport).Clone() @@ -685,51 +733,7 @@ func NewLogtailTransport(host string) *http.Transport { tr.DisableCompression = true // Log whenever we dial: - tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { - nd := netns.FromDialer(log.Printf, &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: netknob.PlatformTCPKeepAlive(), - }) - t0 := time.Now() - c, err := nd.DialContext(ctx, netw, addr) - d := time.Since(t0).Round(time.Millisecond) - if err == nil { - dialLog.Printf("dialed %q in %v", addr, d) - return c, nil - } - - if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") { - if c, err := safesocket.Connect(safesocket.DefaultConnectionStrategy("")); err == nil { - fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr) - br := bufio.NewReader(c) - res, err := http.ReadResponse(br, nil) - if err == nil && res.StatusCode != 200 { - err = errors.New(res.Status) - } - if err != nil { - log.Printf("logtail: CONNECT response error from tailscaled: %v", err) - c.Close() - } else { - dialLog.Printf("connected via tailscaled") - return c, nil - } - } - } - - // If we failed to dial, try again with bootstrap DNS. - log.Printf("logtail: dial %q failed: %v (in %v), trying bootstrap...", addr, err, d) - dnsCache := &dnscache.Resolver{ - Forward: dnscache.Get().Forward, // use default cache's forwarder - UseLastGood: true, - LookupIPFallback: dnsfallback.Lookup, - } - dialer := dnscache.Dialer(nd.DialContext, dnsCache) - c, err = dialer(ctx, netw, addr) - if err == nil { - log.Printf("logtail: bootstrap dial succeeded") - } - return c, err - } + tr.DialContext = DialContext // We're contacting exactly 1 hostname, so the default's 100 // max idle conns is very high for our needs. Even 2 is @@ -762,7 +766,7 @@ func goVersion() string { type noopPretendSuccessTransport struct{} func (noopPretendSuccessTransport) RoundTrip(req *http.Request) (*http.Response, error) { - io.ReadAll(req.Body) + io.Copy(io.Discard, req.Body) req.Body.Close() return &http.Response{ StatusCode: 200,