[networking] Add strict Request extension checking (#7604)

Authored by: coletdjnz
Co-authored-by: pukkandan <pukkandan.ytdlp@gmail.com>
pull/7681/head
coletdjnz 1 year ago committed by GitHub
parent 11de6fec9c
commit 86aea0d3a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -804,10 +804,10 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
assert not isinstance(exc_info.value, TransportError) assert not isinstance(exc_info.value, TransportError)
def run_validation(handler, fail, req, **handler_kwargs): def run_validation(handler, error, req, **handler_kwargs):
with handler(**handler_kwargs) as rh: with handler(**handler_kwargs) as rh:
if fail: if error:
with pytest.raises(UnsupportedRequest): with pytest.raises(error):
rh.validate(req) rh.validate(req)
else: else:
rh.validate(req) rh.validate(req)
@ -824,6 +824,9 @@ class TestRequestHandlerValidation:
_SUPPORTED_PROXY_SCHEMES = None _SUPPORTED_PROXY_SCHEMES = None
_SUPPORTED_URL_SCHEMES = None _SUPPORTED_URL_SCHEMES = None
def _check_extensions(self, extensions):
extensions.clear()
class HTTPSupportedRH(ValidationRH): class HTTPSupportedRH(ValidationRH):
_SUPPORTED_URL_SCHEMES = ('http',) _SUPPORTED_URL_SCHEMES = ('http',)
@ -834,26 +837,26 @@ class TestRequestHandlerValidation:
('https', False, {}), ('https', False, {}),
('data', False, {}), ('data', False, {}),
('ftp', False, {}), ('ftp', False, {}),
('file', True, {}), ('file', UnsupportedRequest, {}),
('file', False, {'enable_file_urls': True}), ('file', False, {'enable_file_urls': True}),
]), ]),
(NoCheckRH, [('http', False, {})]), (NoCheckRH, [('http', False, {})]),
(ValidationRH, [('http', True, {})]) (ValidationRH, [('http', UnsupportedRequest, {})])
] ]
PROXY_SCHEME_TESTS = [ PROXY_SCHEME_TESTS = [
# scheme, expected to fail # scheme, expected to fail
('Urllib', [ ('Urllib', [
('http', False), ('http', False),
('https', True), ('https', UnsupportedRequest),
('socks4', False), ('socks4', False),
('socks4a', False), ('socks4a', False),
('socks5', False), ('socks5', False),
('socks5h', False), ('socks5h', False),
('socks', True), ('socks', UnsupportedRequest),
]), ]),
(NoCheckRH, [('http', False)]), (NoCheckRH, [('http', False)]),
(HTTPSupportedRH, [('http', True)]), (HTTPSupportedRH, [('http', UnsupportedRequest)]),
] ]
PROXY_KEY_TESTS = [ PROXY_KEY_TESTS = [
@ -863,8 +866,22 @@ class TestRequestHandlerValidation:
('unrelated', False), ('unrelated', False),
]), ]),
(NoCheckRH, [('all', False)]), (NoCheckRH, [('all', False)]),
(HTTPSupportedRH, [('all', True)]), (HTTPSupportedRH, [('all', UnsupportedRequest)]),
(HTTPSupportedRH, [('no', True)]), (HTTPSupportedRH, [('no', UnsupportedRequest)]),
]
EXTENSION_TESTS = [
('Urllib', [
({'cookiejar': 'notacookiejar'}, AssertionError),
({'cookiejar': CookieJar()}, False),
({'timeout': 1}, False),
({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest),
]),
(NoCheckRH, [
({'cookiejar': 'notacookiejar'}, False),
({'somerandom': 'test'}, False), # but any extension is allowed through
]),
] ]
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [ @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
@ -907,15 +924,16 @@ class TestRequestHandlerValidation:
@pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1']) @pytest.mark.parametrize('proxy_url', ['//example.com', 'example.com', '127.0.0.1'])
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True) @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_missing_proxy_scheme(self, handler, proxy_url): def test_missing_proxy_scheme(self, handler, proxy_url):
run_validation(handler, True, Request('http://', proxies={'http': 'example.com'})) run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': 'example.com'}))
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_cookiejar_extension(self, handler):
run_validation(handler, True, Request('http://', extensions={'cookiejar': 'notacookiejar'}))
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True) @pytest.mark.parametrize('handler,extensions,fail', [
def test_timeout_extension(self, handler): (handler_tests[0], extensions, fail)
run_validation(handler, True, Request('http://', extensions={'timeout': 'notavalidtimeout'})) for handler_tests in EXTENSION_TESTS
for extensions, fail in handler_tests[1]
], indirect=['handler'])
def test_extension(self, handler, extensions, fail):
run_validation(
handler, fail, Request('http://', extensions=extensions))
def test_invalid_request_type(self): def test_invalid_request_type(self):
rh = self.ValidationRH(logger=FakeLogger()) rh = self.ValidationRH(logger=FakeLogger())

@ -385,6 +385,11 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
if self.enable_file_urls: if self.enable_file_urls:
self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file') self._SUPPORTED_URL_SCHEMES = (*self._SUPPORTED_URL_SCHEMES, 'file')
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('cookiejar', None)
extensions.pop('timeout', None)
def _create_instance(self, proxies, cookiejar): def _create_instance(self, proxies, cookiejar):
opener = urllib.request.OpenerDirector() opener = urllib.request.OpenerDirector()
handlers = [ handlers = [

@ -21,6 +21,7 @@ from .exceptions import (
TransportError, TransportError,
UnsupportedRequest, UnsupportedRequest,
) )
from ..compat.types import NoneType
from ..utils import ( from ..utils import (
bug_reports_message, bug_reports_message,
classproperty, classproperty,
@ -147,6 +148,7 @@ class RequestHandler(abc.ABC):
a proxy url with an url scheme not in this list will raise an UnsupportedRequest. a proxy url with an url scheme not in this list will raise an UnsupportedRequest.
- `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum. - `_SUPPORTED_FEATURES`: a tuple of supported features, as defined in Features enum.
The above may be set to None to disable the checks. The above may be set to None to disable the checks.
Parameters: Parameters:
@ -169,9 +171,14 @@ class RequestHandler(abc.ABC):
Requests may have additional optional parameters defined as extensions. Requests may have additional optional parameters defined as extensions.
RequestHandler subclasses may choose to support custom extensions. RequestHandler subclasses may choose to support custom extensions.
If an extension is supported, subclasses should extend _check_extensions(extensions)
to pop and validate the extension.
- Extensions left in `extensions` are treated as unsupported and UnsupportedRequest will be raised.
The following extensions are defined for RequestHandler: The following extensions are defined for RequestHandler:
- `cookiejar`: Cookiejar to use for this request - `cookiejar`: Cookiejar to use for this request.
- `timeout`: socket timeout to use for this request - `timeout`: socket timeout to use for this request.
To enable these, add extensions.pop('<extension>', None) to _check_extensions
Apart from the url protocol, proxies dict may contain the following keys: Apart from the url protocol, proxies dict may contain the following keys:
- `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol. - `all`: proxy to use for all protocols. Used as a fallback if no proxy is set for a specific protocol.
@ -263,26 +270,19 @@ class RequestHandler(abc.ABC):
if scheme not in self._SUPPORTED_PROXY_SCHEMES: if scheme not in self._SUPPORTED_PROXY_SCHEMES:
raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"') raise UnsupportedRequest(f'Unsupported proxy type: "{scheme}"')
def _check_cookiejar_extension(self, extensions):
if not extensions.get('cookiejar'):
return
if not isinstance(extensions['cookiejar'], CookieJar):
raise UnsupportedRequest('cookiejar is not a CookieJar')
def _check_timeout_extension(self, extensions):
if extensions.get('timeout') is None:
return
if not isinstance(extensions['timeout'], (float, int)):
raise UnsupportedRequest('timeout is not a float or int')
def _check_extensions(self, extensions): def _check_extensions(self, extensions):
self._check_cookiejar_extension(extensions) """Check extensions for unsupported extensions. Subclasses should extend this."""
self._check_timeout_extension(extensions) assert isinstance(extensions.get('cookiejar'), (CookieJar, NoneType))
assert isinstance(extensions.get('timeout'), (float, int, NoneType))
def _validate(self, request): def _validate(self, request):
self._check_url_scheme(request) self._check_url_scheme(request)
self._check_proxies(request.proxies or self.proxies) self._check_proxies(request.proxies or self.proxies)
self._check_extensions(request.extensions) extensions = request.extensions.copy()
self._check_extensions(extensions)
if extensions:
# TODO: add support for optional extensions
raise UnsupportedRequest(f'Unsupported extensions: {", ".join(extensions.keys())}')
@wrap_request_errors @wrap_request_errors
def validate(self, request: Request): def validate(self, request: Request):

Loading…
Cancel
Save