[rh:websockets] Migrate websockets to networking framework (#7720)

* Adds a basic WebSocket framework
* Introduces new minimum `websockets` version of 12.0
* Deprecates `WebSocketsWrapper`

Fixes https://github.com/yt-dlp/yt-dlp/issues/8439

Authored by: coletdjnz
pull/5847/merge
coletdjnz 1 year ago committed by GitHub
parent 45d82be65f
commit ccfd70f4c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,3 +6,4 @@ brotlicffi; implementation_name!='cpython'
certifi certifi
requests>=2.31.0,<3 requests>=2.31.0,<3
urllib3>=1.26.17,<3 urllib3>=1.26.17,<3
websockets>=12.0

@ -19,3 +19,8 @@ def handler(request):
pytest.skip(f'{RH_KEY} request handler is not available') pytest.skip(f'{RH_KEY} request handler is not available')
return functools.partial(handler, logger=FakeLogger) return functools.partial(handler, logger=FakeLogger)
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)

@ -52,6 +52,8 @@ from yt_dlp.networking.exceptions import (
from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils._utils import _YDLLogger as FakeLogger
from yt_dlp.utils.networking import HTTPHeaderDict from yt_dlp.utils.networking import HTTPHeaderDict
from test.conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
@ -275,11 +277,6 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode()) self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
def validate_and_send(rh, req):
rh.validate(req)
return rh.send(req)
class TestRequestHandlerBase: class TestRequestHandlerBase:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
@ -872,8 +869,9 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
]) ])
@pytest.mark.parametrize('handler', ['Requests'], indirect=True) @pytest.mark.parametrize('handler', ['Requests'], indirect=True)
def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match): def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
from urllib3.response import HTTPResponse as Urllib3Response
from requests.models import Response as RequestsResponse from requests.models import Response as RequestsResponse
from urllib3.response import HTTPResponse as Urllib3Response
from yt_dlp.networking._requests import RequestsResponseAdapter from yt_dlp.networking._requests import RequestsResponseAdapter
requests_res = RequestsResponse() requests_res = RequestsResponse()
requests_res.raw = Urllib3Response(body=b'', status=200) requests_res.raw = Urllib3Response(body=b'', status=200)
@ -929,13 +927,17 @@ class TestRequestHandlerValidation:
('http', False, {}), ('http', False, {}),
('https', False, {}), ('https', False, {}),
]), ]),
('Websockets', [
('ws', False, {}),
('wss', False, {}),
]),
(NoCheckRH, [('http', False, {})]), (NoCheckRH, [('http', False, {})]),
(ValidationRH, [('http', UnsupportedRequest, {})]) (ValidationRH, [('http', UnsupportedRequest, {})])
] ]
PROXY_SCHEME_TESTS = [ PROXY_SCHEME_TESTS = [
# scheme, expected to fail # scheme, expected to fail
('Urllib', [ ('Urllib', 'http', [
('http', False), ('http', False),
('https', UnsupportedRequest), ('https', UnsupportedRequest),
('socks4', False), ('socks4', False),
@ -944,7 +946,7 @@ class TestRequestHandlerValidation:
('socks5h', False), ('socks5h', False),
('socks', UnsupportedRequest), ('socks', UnsupportedRequest),
]), ]),
('Requests', [ ('Requests', 'http', [
('http', False), ('http', False),
('https', False), ('https', False),
('socks4', False), ('socks4', False),
@ -952,8 +954,11 @@ class TestRequestHandlerValidation:
('socks5', False), ('socks5', False),
('socks5h', False), ('socks5h', False),
]), ]),
(NoCheckRH, [('http', False)]), (NoCheckRH, 'http', [('http', False)]),
(HTTPSupportedRH, [('http', UnsupportedRequest)]), (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
('Websockets', 'ws', [('http', UnsupportedRequest)]),
(NoCheckRH, 'http', [('http', False)]),
(HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
] ]
PROXY_KEY_TESTS = [ PROXY_KEY_TESTS = [
@ -972,7 +977,7 @@ class TestRequestHandlerValidation:
] ]
EXTENSION_TESTS = [ EXTENSION_TESTS = [
('Urllib', [ ('Urllib', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': 'notacookiejar'}, AssertionError),
({'cookiejar': YoutubeDLCookieJar()}, False), ({'cookiejar': YoutubeDLCookieJar()}, False),
({'cookiejar': CookieJar()}, AssertionError), ({'cookiejar': CookieJar()}, AssertionError),
@ -980,17 +985,21 @@ class TestRequestHandlerValidation:
({'timeout': 'notatimeout'}, AssertionError), ({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest), ({'unsupported': 'value'}, UnsupportedRequest),
]), ]),
('Requests', [ ('Requests', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': 'notacookiejar'}, AssertionError),
({'cookiejar': YoutubeDLCookieJar()}, False), ({'cookiejar': YoutubeDLCookieJar()}, False),
({'timeout': 1}, False), ({'timeout': 1}, False),
({'timeout': 'notatimeout'}, AssertionError), ({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest), ({'unsupported': 'value'}, UnsupportedRequest),
]), ]),
(NoCheckRH, [ (NoCheckRH, 'http', [
({'cookiejar': 'notacookiejar'}, False), ({'cookiejar': 'notacookiejar'}, False),
({'somerandom': 'test'}, False), # but any extension is allowed through ({'somerandom': 'test'}, False), # but any extension is allowed through
]), ]),
('Websockets', 'ws', [
({'cookiejar': YoutubeDLCookieJar()}, False),
({'timeout': 2}, False),
]),
] ]
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [ @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
@ -1016,14 +1025,14 @@ class TestRequestHandlerValidation:
run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'})) run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'}))
run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'}) run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'})
@pytest.mark.parametrize('handler,scheme,fail', [ @pytest.mark.parametrize('handler,req_scheme,scheme,fail', [
(handler_tests[0], scheme, fail) (handler_tests[0], handler_tests[1], scheme, fail)
for handler_tests in PROXY_SCHEME_TESTS for handler_tests in PROXY_SCHEME_TESTS
for scheme, fail in handler_tests[1] for scheme, fail in handler_tests[2]
], indirect=['handler']) ], indirect=['handler'])
def test_proxy_scheme(self, handler, scheme, fail): def test_proxy_scheme(self, handler, req_scheme, scheme, fail):
run_validation(handler, fail, Request('http://', proxies={'http': f'{scheme}://example.com'})) run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'}))
run_validation(handler, fail, Request('http://'), proxies={'http': f'{scheme}://example.com'}) run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'})
@pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True) @pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True)
def test_empty_proxy(self, handler): def test_empty_proxy(self, handler):
@ -1035,14 +1044,14 @@ class TestRequestHandlerValidation:
def test_invalid_proxy_url(self, handler, proxy_url): def test_invalid_proxy_url(self, handler, proxy_url):
run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url})) run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url}))
@pytest.mark.parametrize('handler,extensions,fail', [ @pytest.mark.parametrize('handler,scheme,extensions,fail', [
(handler_tests[0], extensions, fail) (handler_tests[0], handler_tests[1], extensions, fail)
for handler_tests in EXTENSION_TESTS for handler_tests in EXTENSION_TESTS
for extensions, fail in handler_tests[1] for extensions, fail in handler_tests[2]
], indirect=['handler']) ], indirect=['handler'])
def test_extension(self, handler, extensions, fail): def test_extension(self, handler, scheme, extensions, fail):
run_validation( run_validation(
handler, fail, Request('http://', extensions=extensions)) handler, fail, Request(f'{scheme}://', extensions=extensions))
def test_invalid_request_type(self): def test_invalid_request_type(self):
rh = self.ValidationRH(logger=FakeLogger()) rh = self.ValidationRH(logger=FakeLogger())
@ -1075,6 +1084,22 @@ class FakeRHYDL(FakeYDL):
self._request_director = self.build_request_director([FakeRH]) self._request_director = self.build_request_director([FakeRH])
class AllUnsupportedRHYDL(FakeYDL):
def __init__(self, *args, **kwargs):
class UnsupportedRH(RequestHandler):
def _send(self, request: Request):
pass
_SUPPORTED_FEATURES = ()
_SUPPORTED_PROXY_SCHEMES = ()
_SUPPORTED_URL_SCHEMES = ()
super().__init__(*args, **kwargs)
self._request_director = self.build_request_director([UnsupportedRH])
class TestRequestDirector: class TestRequestDirector:
def test_handler_operations(self): def test_handler_operations(self):
@ -1234,6 +1259,12 @@ class TestYoutubeDLNetworking:
with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'): with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'):
ydl.urlopen('file://') ydl.urlopen('file://')
@pytest.mark.parametrize('scheme', (['ws', 'wss']))
def test_websocket_unavailable_error(self, scheme):
with AllUnsupportedRHYDL() as ydl:
with pytest.raises(RequestError, match=r'This request requires WebSocket support'):
ydl.urlopen(f'{scheme}://')
def test_legacy_server_connect_error(self): def test_legacy_server_connect_error(self):
with FakeRHYDL() as ydl: with FakeRHYDL() as ydl:
for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'): for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'):

@ -210,6 +210,16 @@ class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestR
self.wfile.write(payload.encode()) self.wfile.write(payload.encode())
class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
def handle(self):
import websockets.sync.server
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
connection.send(json.dumps(self.socks_info))
connection.close()
@contextlib.contextmanager @contextlib.contextmanager
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs): def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
server = server_thread = None server = server_thread = None
@ -252,8 +262,22 @@ class HTTPSocksTestProxyContext(SocksProxyTestContext):
return json.loads(handler.send(request).read().decode()) return json.loads(handler.send(request).read().decode())
class WebSocketSocksTestProxyContext(SocksProxyTestContext):
REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler
def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
handler.validate(request)
ws = handler.send(request)
ws.send('socks_info')
socks_info = ws.recv()
ws.close()
return json.loads(socks_info)
CTX_MAP = { CTX_MAP = {
'http': HTTPSocksTestProxyContext, 'http': HTTPSocksTestProxyContext,
'ws': WebSocketSocksTestProxyContext,
} }
@ -263,7 +287,7 @@ def ctx(request):
class TestSocks4Proxy: class TestSocks4Proxy:
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4_no_auth(self, handler, ctx): def test_socks4_no_auth(self, handler, ctx):
with handler() as rh: with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
@ -271,7 +295,7 @@ class TestSocks4Proxy:
rh, proxies={'all': f'socks4://{server_address}'}) rh, proxies={'all': f'socks4://{server_address}'})
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4_auth(self, handler, ctx): def test_socks4_auth(self, handler, ctx):
with handler() as rh: with handler() as rh:
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address: with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
@ -281,7 +305,7 @@ class TestSocks4Proxy:
rh, proxies={'all': f'socks4://user:@{server_address}'}) rh, proxies={'all': f'socks4://user:@{server_address}'})
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4a_ipv4_target(self, handler, ctx): def test_socks4a_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@ -289,7 +313,7 @@ class TestSocks4Proxy:
assert response['version'] == 4 assert response['version'] == 4
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1') assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks4a_domain_target(self, handler, ctx): def test_socks4a_domain_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
@ -298,7 +322,7 @@ class TestSocks4Proxy:
assert response['ipv4_address'] is None assert response['ipv4_address'] is None
assert response['domain_address'] == 'localhost' assert response['domain_address'] == 'localhost'
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx): def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address: with ctx.socks_server(Socks4ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
@ -308,7 +332,7 @@ class TestSocks4Proxy:
assert response['client_address'][0] == source_address assert response['client_address'][0] == source_address
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('reply_code', [ @pytest.mark.parametrize('reply_code', [
Socks4CD.REQUEST_REJECTED_OR_FAILED, Socks4CD.REQUEST_REJECTED_OR_FAILED,
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD, Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
@ -320,7 +344,7 @@ class TestSocks4Proxy:
with pytest.raises(ProxyError): with pytest.raises(ProxyError):
ctx.socks_info_request(rh) ctx.socks_info_request(rh)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv6_socks4_proxy(self, handler, ctx): def test_ipv6_socks4_proxy(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address: with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}) as rh: with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
@ -329,7 +353,7 @@ class TestSocks4Proxy:
assert response['ipv4_address'] == '127.0.0.1' assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 4 assert response['version'] == 4
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_timeout(self, handler, ctx): def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address: with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh: with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
@ -339,7 +363,7 @@ class TestSocks4Proxy:
class TestSocks5Proxy: class TestSocks5Proxy:
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_no_auth(self, handler, ctx): def test_socks5_no_auth(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -347,7 +371,7 @@ class TestSocks5Proxy:
assert response['auth_methods'] == [0x0] assert response['auth_methods'] == [0x0]
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_user_pass(self, handler, ctx): def test_socks5_user_pass(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address: with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
with handler() as rh: with handler() as rh:
@ -360,7 +384,7 @@ class TestSocks5Proxy:
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS] assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_ipv4_target(self, handler, ctx): def test_socks5_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -368,7 +392,7 @@ class TestSocks5Proxy:
assert response['ipv4_address'] == '127.0.0.1' assert response['ipv4_address'] == '127.0.0.1'
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_domain_target(self, handler, ctx): def test_socks5_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -376,7 +400,7 @@ class TestSocks5Proxy:
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1') assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5h_domain_target(self, handler, ctx): def test_socks5h_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@ -385,7 +409,7 @@ class TestSocks5Proxy:
assert response['domain_address'] == 'localhost' assert response['domain_address'] == 'localhost'
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5h_ip_target(self, handler, ctx): def test_socks5h_ip_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
@ -394,7 +418,7 @@ class TestSocks5Proxy:
assert response['domain_address'] is None assert response['domain_address'] is None
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_socks5_ipv6_destination(self, handler, ctx): def test_socks5_ipv6_destination(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -402,7 +426,7 @@ class TestSocks5Proxy:
assert response['ipv6_address'] == '::1' assert response['ipv6_address'] == '::1'
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv6_socks5_proxy(self, handler, ctx): def test_ipv6_socks5_proxy(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address: with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh: with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@ -413,7 +437,7 @@ class TestSocks5Proxy:
# XXX: is there any feasible way of testing IPv6 source addresses? # XXX: is there any feasible way of testing IPv6 source addresses?
# Same would go for non-proxy source_address test... # Same would go for non-proxy source_address test...
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx): def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address: with ctx.socks_server(Socks5ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
@ -422,7 +446,7 @@ class TestSocks5Proxy:
assert response['client_address'][0] == source_address assert response['client_address'][0] == source_address
assert response['version'] == 5 assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
@pytest.mark.parametrize('reply_code', [ @pytest.mark.parametrize('reply_code', [
Socks5Reply.GENERAL_FAILURE, Socks5Reply.GENERAL_FAILURE,
Socks5Reply.CONNECTION_NOT_ALLOWED, Socks5Reply.CONNECTION_NOT_ALLOWED,
@ -439,7 +463,7 @@ class TestSocks5Proxy:
with pytest.raises(ProxyError): with pytest.raises(ProxyError):
ctx.socks_info_request(rh) ctx.socks_info_request(rh)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True)
def test_timeout(self, handler, ctx): def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address: with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh: with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:

@ -0,0 +1,380 @@
#!/usr/bin/env python3
# Allow direct execution
import os
import sys
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import http.client
import http.cookiejar
import http.server
import json
import random
import ssl
import threading
from yt_dlp import socks
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import websockets
from yt_dlp.networking import Request
from yt_dlp.networking.exceptions import (
CertificateVerifyError,
HTTPError,
ProxyError,
RequestError,
SSLError,
TransportError,
)
from yt_dlp.utils.networking import HTTPHeaderDict
from test.conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
def websocket_handler(websocket):
for message in websocket:
if isinstance(message, bytes):
if message == b'bytes':
return websocket.send('2')
elif isinstance(message, str):
if message == 'headers':
return websocket.send(json.dumps(dict(websocket.request.headers)))
elif message == 'path':
return websocket.send(websocket.request.path)
elif message == 'source_address':
return websocket.send(websocket.remote_address[0])
elif message == 'str':
return websocket.send('1')
return websocket.send(message)
def process_request(self, request):
if request.path.startswith('/gen_'):
status = http.HTTPStatus(int(request.path[5:]))
if 300 <= status.value <= 300:
return websockets.http11.Response(
status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
return self.protocol.reject(status.value, status.phrase)
return self.protocol.accept(request)
def create_websocket_server(**ws_kwargs):
import websockets.sync.server
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
ws_port = wsd.socket.getsockname()[1]
ws_server_thread = threading.Thread(target=wsd.serve_forever)
ws_server_thread.daemon = True
ws_server_thread.start()
return ws_server_thread, ws_port
def create_ws_websocket_server():
return create_websocket_server()
def create_wss_websocket_server():
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
def create_mtls_wss_websocket_server():
certfn = os.path.join(TEST_DIR, 'testcert.pem')
cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.load_verify_locations(cafile=cacertfn)
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance:
@classmethod
def setup_class(cls):
cls.ws_thread, cls.ws_port = create_ws_websocket_server()
cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
cls.wss_thread, cls.wss_port = create_wss_websocket_server()
cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler):
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
assert 'upgrade' in ws.headers
assert ws.status == 101
ws.send('foo')
assert ws.recv() == 'foo'
ws.close()
# https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
@pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_send_types(self, handler, msg, opcode):
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send(msg)
assert int(ws.recv()) == opcode
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_verify_cert(self, handler):
with handler() as rh:
with pytest.raises(CertificateVerifyError):
validate_and_send(rh, Request(self.wss_base_url))
with handler(verify=False) as rh:
ws = validate_and_send(rh, Request(self.wss_base_url))
assert ws.status == 101
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_ssl_error(self, handler):
with handler(verify=False) as rh:
with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
validate_and_send(rh, Request(self.bad_wss_host))
assert not issubclass(exc_info.type, CertificateVerifyError)
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('path,expected', [
# Unicode characters should be encoded with uppercase percent-encoding
('/中文', '/%E4%B8%AD%E6%96%87'),
# don't normalize existing percent encodings
('/%c7%9f', '/%c7%9f'),
])
def test_percent_encode(self, handler, path, expected):
with handler() as rh:
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
ws.send('path')
assert ws.recv() == expected
assert ws.status == 101
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_remove_dot_segments(self, handler):
with handler() as rh:
# This isn't a comprehensive test,
# but it should be enough to check whether the handler is removing dot segments
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
assert ws.status == 101
ws.send('path')
assert ws.recv() == '/test'
ws.close()
# We are restricted to known HTTP status codes in http.HTTPStatus
# Redirects are not supported for websockets
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
def test_raise_http_error(self, handler, status):
with handler() as rh:
with pytest.raises(HTTPError) as exc_info:
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
assert exc_info.value.status == status
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@pytest.mark.parametrize('params,extensions', [
({'timeout': 0.00001}, {}),
({}, {'timeout': 0.00001}),
])
def test_timeout(self, handler, params, extensions):
with handler(**params) as rh:
with pytest.raises(TransportError):
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler):
cookiejar = YoutubeDLCookieJar()
cookiejar.set_cookie(http.cookiejar.Cookie(
version=0, name='test', value='ytdlp', port=None, port_specified=False,
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
path_specified=True, secure=False, expires=None, discard=False, comment=None,
comment_url=None, rest={}))
with handler(cookiejar=cookiejar) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert 'cookie' not in json.loads(ws.recv())
ws.close()
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}'
with handler(source_address=source_address) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('source_address')
assert source_address == ws.recv()
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_response_url(self, handler):
with handler() as rh:
url = f'{self.ws_base_url}/something'
ws = validate_and_send(rh, Request(url))
assert ws.url == url
ws.close()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_request_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
# Global Headers
ws = validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv()))
assert headers['test1'] == 'test'
ws.close()
# Per request headers, merged with global
ws = validate_and_send(rh, Request(
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv()))
assert headers['test1'] == 'test'
assert headers['test2'] == 'changed'
assert headers['test3'] == 'test3'
ws.close()
@pytest.mark.parametrize('client_cert', (
{'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
'client_certificate_password': 'foobar',
},
{
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
'client_certificate_password': 'foobar',
}
))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_mtls(self, handler, client_cert):
with handler(
# Disable client-side validation of unacceptable self-signed testcert.pem
# The test is of a check on the server side, so unaffected
verify=False,
client_cert=client_cert
) as rh:
validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
def create_fake_ws_connection(raised):
import websockets.sync.client
class FakeWsConnection(websockets.sync.client.ClientConnection):
def __init__(self, *args, **kwargs):
class FakeResponse:
body = b''
headers = {}
status_code = 101
reason_phrase = 'test'
self.response = FakeResponse()
def send(self, *args, **kwargs):
raise raised()
def recv(self, *args, **kwargs):
raise raised()
def close(self, *args, **kwargs):
return
return FakeWsConnection()
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
@pytest.mark.parametrize('raised,expected', [
# https://websockets.readthedocs.io/en/stable/reference/exceptions.html
(lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
# Requires a response object. Should be covered by HTTP error tests.
# (lambda: websockets.exceptions.InvalidStatus(), TransportError),
(lambda: websockets.exceptions.InvalidHandshake(), TransportError),
# These are subclasses of InvalidHandshake
(lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
(lambda: websockets.exceptions.NegotiationError(), TransportError),
# Catch-all
(lambda: websockets.exceptions.WebSocketException(), TransportError),
(lambda: TimeoutError(), TransportError),
# These may be raised by our create_connection implementation, which should also be caught
(lambda: OSError(), TransportError),
(lambda: ssl.SSLError(), SSLError),
(lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
(lambda: socks.ProxyError(), ProxyError),
])
def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
import websockets.sync.client
import yt_dlp.networking._websockets
with handler() as rh:
def fake_connect(*args, **kwargs):
raise raised()
monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
with pytest.raises(expected) as exc_info:
rh.send(Request('ws://fake-url'))
assert exc_info.type is expected
@pytest.mark.parametrize('raised,expected,match', [
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
(lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
(lambda: RuntimeError(), TransportError, None),
(lambda: TimeoutError(), TransportError, None),
(lambda: TypeError(), RequestError, None),
(lambda: socks.ProxyError(), ProxyError, None),
# Catch-all
(lambda: websockets.exceptions.WebSocketException(), TransportError, None),
])
def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
from yt_dlp.networking._websockets import WebsocketsResponseAdapter
ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
with pytest.raises(expected, match=match) as exc_info:
ws.send('test')
assert exc_info.type is expected
@pytest.mark.parametrize('raised,expected,match', [
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
(lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
(lambda: RuntimeError(), TransportError, None),
(lambda: TimeoutError(), TransportError, None),
(lambda: socks.ProxyError(), ProxyError, None),
# Catch-all
(lambda: websockets.exceptions.WebSocketException(), TransportError, None),
])
def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
from yt_dlp.networking._websockets import WebsocketsResponseAdapter
ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
with pytest.raises(expected, match=match) as exc_info:
ws.recv()
assert exc_info.type is expected

@ -4052,6 +4052,7 @@ class YoutubeDL:
return self._request_director.send(req) return self._request_director.send(req)
except NoSupportingHandlers as e: except NoSupportingHandlers as e:
for ue in e.unsupported_errors: for ue in e.unsupported_errors:
# FIXME: This depends on the order of errors.
if not (ue.handler and ue.msg): if not (ue.handler and ue.msg):
continue continue
if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower(): if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower():
@ -4061,6 +4062,15 @@ class YoutubeDL:
if 'unsupported proxy type: "https"' in ue.msg.lower(): if 'unsupported proxy type: "https"' in ue.msg.lower():
raise RequestError( raise RequestError(
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests') 'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests')
elif (
re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
and 'websockets' not in self._request_director.handlers
):
raise RequestError(
'This request requires WebSocket support. '
'Ensure one of the following dependencies are installed: websockets',
cause=ue) from ue
raise raise
except SSLError as e: except SSLError as e:
if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e): if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e):

@ -6,7 +6,7 @@ from . import get_suitable_downloader
from .common import FileDownloader from .common import FileDownloader
from .external import FFmpegFD from .external import FFmpegFD
from ..networking import Request from ..networking import Request
from ..utils import DownloadError, WebSocketsWrapper, str_or_none, try_get from ..utils import DownloadError, str_or_none, try_get
class NiconicoDmcFD(FileDownloader): class NiconicoDmcFD(FileDownloader):
@ -64,7 +64,6 @@ class NiconicoLiveFD(FileDownloader):
ws_url = info_dict['url'] ws_url = info_dict['url']
ws_extractor = info_dict['ws'] ws_extractor = info_dict['ws']
ws_origin_host = info_dict['origin'] ws_origin_host = info_dict['origin']
cookies = info_dict.get('cookies')
live_quality = info_dict.get('live_quality', 'high') live_quality = info_dict.get('live_quality', 'high')
live_latency = info_dict.get('live_latency', 'high') live_latency = info_dict.get('live_latency', 'high')
dl = FFmpegFD(self.ydl, self.params or {}) dl = FFmpegFD(self.ydl, self.params or {})
@ -76,12 +75,7 @@ class NiconicoLiveFD(FileDownloader):
def communicate_ws(reconnect): def communicate_ws(reconnect):
if reconnect: if reconnect:
ws = WebSocketsWrapper(ws_url, { ws = self.ydl.urlopen(Request(ws_url, headers={'Origin': f'https://{ws_origin_host}'}))
'Cookies': str_or_none(cookies) or '',
'Origin': f'https://{ws_origin_host}',
'Accept': '*/*',
'User-Agent': self.params['http_headers']['User-Agent'],
})
if self.ydl.params.get('verbose', False): if self.ydl.params.get('verbose', False):
self.to_screen('[debug] Sending startWatching request') self.to_screen('[debug] Sending startWatching request')
ws.send(json.dumps({ ws.send(json.dumps({

@ -2,11 +2,9 @@ import re
from .common import InfoExtractor from .common import InfoExtractor
from ..compat import compat_parse_qs from ..compat import compat_parse_qs
from ..dependencies import websockets
from ..networking import Request from ..networking import Request
from ..utils import ( from ..utils import (
ExtractorError, ExtractorError,
WebSocketsWrapper,
js_to_json, js_to_json,
traverse_obj, traverse_obj,
update_url_query, update_url_query,
@ -167,8 +165,6 @@ class FC2LiveIE(InfoExtractor):
}] }]
def _real_extract(self, url): def _real_extract(self, url):
if not websockets:
raise ExtractorError('websockets library is not available. Please install it.', expected=True)
video_id = self._match_id(url) video_id = self._match_id(url)
webpage = self._download_webpage('https://live.fc2.com/%s/' % video_id, video_id) webpage = self._download_webpage('https://live.fc2.com/%s/' % video_id, video_id)
@ -199,13 +195,9 @@ class FC2LiveIE(InfoExtractor):
ws_url = update_url_query(control_server['url'], {'control_token': control_server['control_token']}) ws_url = update_url_query(control_server['url'], {'control_token': control_server['control_token']})
playlist_data = None playlist_data = None
self.to_screen('%s: Fetching HLS playlist info via WebSocket' % video_id) ws = self._request_webpage(Request(ws_url, headers={
ws = WebSocketsWrapper(ws_url, {
'Cookie': str(self._get_cookies('https://live.fc2.com/'))[12:],
'Origin': 'https://live.fc2.com', 'Origin': 'https://live.fc2.com',
'Accept': '*/*', }), video_id, note='Fetching HLS playlist info via WebSocket')
'User-Agent': self.get_param('http_headers')['User-Agent'],
})
self.write_debug('Sending HLS server request') self.write_debug('Sending HLS server request')

@ -8,12 +8,11 @@ import time
from urllib.parse import urlparse from urllib.parse import urlparse
from .common import InfoExtractor, SearchInfoExtractor from .common import InfoExtractor, SearchInfoExtractor
from ..dependencies import websockets from ..networking import Request
from ..networking.exceptions import HTTPError from ..networking.exceptions import HTTPError
from ..utils import ( from ..utils import (
ExtractorError, ExtractorError,
OnDemandPagedList, OnDemandPagedList,
WebSocketsWrapper,
bug_reports_message, bug_reports_message,
clean_html, clean_html,
float_or_none, float_or_none,
@ -934,8 +933,6 @@ class NiconicoLiveIE(InfoExtractor):
_KNOWN_LATENCY = ('high', 'low') _KNOWN_LATENCY = ('high', 'low')
def _real_extract(self, url): def _real_extract(self, url):
if not websockets:
raise ExtractorError('websockets library is not available. Please install it.', expected=True)
video_id = self._match_id(url) video_id = self._match_id(url)
webpage, urlh = self._download_webpage_handle(f'https://live.nicovideo.jp/watch/{video_id}', video_id) webpage, urlh = self._download_webpage_handle(f'https://live.nicovideo.jp/watch/{video_id}', video_id)
@ -950,17 +947,13 @@ class NiconicoLiveIE(InfoExtractor):
}) })
hostname = remove_start(urlparse(urlh.url).hostname, 'sp.') hostname = remove_start(urlparse(urlh.url).hostname, 'sp.')
cookies = try_get(urlh.url, self._downloader._calc_cookies)
latency = try_get(self._configuration_arg('latency'), lambda x: x[0]) latency = try_get(self._configuration_arg('latency'), lambda x: x[0])
if latency not in self._KNOWN_LATENCY: if latency not in self._KNOWN_LATENCY:
latency = 'high' latency = 'high'
ws = WebSocketsWrapper(ws_url, { ws = self._request_webpage(
'Cookies': str_or_none(cookies) or '', Request(ws_url, headers={'Origin': f'https://{hostname}'}),
'Origin': f'https://{hostname}', video_id=video_id, note='Connecting to WebSocket server')
'Accept': '*/*',
'User-Agent': self.get_param('http_headers')['User-Agent'],
})
self.write_debug('[debug] Sending HLS server request') self.write_debug('[debug] Sending HLS server request')
ws.send(json.dumps({ ws.send(json.dumps({
@ -1034,7 +1027,6 @@ class NiconicoLiveIE(InfoExtractor):
'protocol': 'niconico_live', 'protocol': 'niconico_live',
'ws': ws, 'ws': ws,
'video_id': video_id, 'video_id': video_id,
'cookies': cookies,
'live_latency': latency, 'live_latency': latency,
'origin': hostname, 'origin': hostname,
}) })

@ -21,3 +21,11 @@ except ImportError:
pass pass
except Exception as e: except Exception as e:
warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message()) warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message())
try:
from . import _websockets
except ImportError:
pass
except Exception as e:
warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message())

@ -0,0 +1,159 @@
from __future__ import annotations
import io
import logging
import ssl
import sys
from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
from .common import Response, register_rh, Features
from .exceptions import (
CertificateVerifyError,
HTTPError,
RequestError,
SSLError,
TransportError, ProxyError,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..compat import functools
from ..dependencies import websockets
from ..utils import int_or_none
from ..socks import ProxyError as SocksProxyError
if not websockets:
raise ImportError('websockets is not installed')
import websockets.version
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
if websockets_version < (12, 0):
raise ImportError('Only websockets>=12.0 is supported')
import websockets.sync.client
from websockets.uri import parse_uri
class WebsocketsResponseAdapter(WebSocketResponse):
def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
super().__init__(
fp=io.BytesIO(wsw.response.body or b''),
url=url,
headers=wsw.response.headers,
status=wsw.response.status_code,
reason=wsw.response.reason_phrase,
)
self.wsw = wsw
def close(self):
self.wsw.close()
super().close()
def send(self, message):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
try:
return self.wsw.send(message)
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except TypeError as e:
raise RequestError(cause=e) from e
def recv(self):
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
try:
return self.wsw.recv()
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
raise TransportError(cause=e) from e
@register_rh
class WebsocketsRH(WebSocketRequestHandler):
"""
Websockets request handler
https://websockets.readthedocs.io
https://github.com/python-websockets/websockets
"""
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for name in ('websockets.client', 'websockets.server'):
logger = logging.getLogger(name)
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
logger.addHandler(handler)
if self.verbose:
logger.setLevel(logging.DEBUG)
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('timeout', None)
extensions.pop('cookiejar', None)
def _send(self, request):
timeout = float(request.extensions.get('timeout') or self.timeout)
headers = self._merge_headers(request.headers)
if 'cookie' not in headers:
cookiejar = request.extensions.get('cookiejar') or self.cookiejar
cookie_header = cookiejar.get_cookie_header(request.url)
if cookie_header:
headers['cookie'] = cookie_header
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout
}
proxy = select_proxy(request.url, request.proxies or self.proxies or {})
try:
if proxy:
socks_proxy_options = make_socks_proxy_opts(proxy)
sock = create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
**create_conn_kwargs
)
else:
sock = create_connection(
address=(wsuri.host, wsuri.port),
**create_conn_kwargs
)
conn = websockets.sync.client.connect(
sock=sock,
uri=request.url,
additional_headers=headers,
open_timeout=timeout,
user_agent_header=None,
ssl_context=self._make_sslcontext() if wsuri.secure else None,
close_timeout=0, # not ideal, but prevents yt-dlp hanging
)
return WebsocketsResponseAdapter(conn, url=request.url)
# Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
except SocksProxyError as e:
raise ProxyError(cause=e) from e
except websockets.exceptions.InvalidURI as e:
raise RequestError(cause=e) from e
except ssl.SSLCertVerificationError as e:
raise CertificateVerifyError(cause=e) from e
except ssl.SSLError as e:
raise SSLError(cause=e) from e
except websockets.exceptions.InvalidStatus as e:
raise HTTPError(
Response(
fp=io.BytesIO(e.response.body),
url=request.url,
headers=e.response.headers,
status=e.response.status_code,
reason=e.response.reason_phrase),
) from e
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
raise TransportError(cause=e) from e

@ -0,0 +1,23 @@
from __future__ import annotations
import abc
from .common import Response, RequestHandler
class WebSocketResponse(Response):
def send(self, message: bytes | str):
"""
Send a message to the server.
@param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame.
"""
raise NotImplementedError
def recv(self):
raise NotImplementedError
class WebSocketRequestHandler(RequestHandler, abc.ABC):
pass

@ -1,4 +1,6 @@
"""No longer used and new code should not use. Exists only for API compat.""" """No longer used and new code should not use. Exists only for API compat."""
import asyncio
import atexit
import platform import platform
import struct import struct
import sys import sys
@ -32,6 +34,77 @@ has_certifi = bool(certifi)
has_websockets = bool(websockets) has_websockets = bool(websockets)
class WebSocketsWrapper:
"""Wraps websockets module to use in non-async scopes"""
pool = None
def __init__(self, url, headers=None, connect=True, **ws_kwargs):
self.loop = asyncio.new_event_loop()
# XXX: "loop" is deprecated
self.conn = websockets.connect(
url, extra_headers=headers, ping_interval=None,
close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'), **ws_kwargs)
if connect:
self.__enter__()
atexit.register(self.__exit__, None, None, None)
def __enter__(self):
if not self.pool:
self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
return self
def send(self, *args):
self.run_with_loop(self.pool.send(*args), self.loop)
def recv(self, *args):
return self.run_with_loop(self.pool.recv(*args), self.loop)
def __exit__(self, type, value, traceback):
try:
return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
finally:
self.loop.close()
self._cancel_all_tasks(self.loop)
# taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
# for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
@staticmethod
def run_with_loop(main, loop):
if not asyncio.iscoroutine(main):
raise ValueError(f'a coroutine was expected, got {main!r}')
try:
return loop.run_until_complete(main)
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
if hasattr(loop, 'shutdown_default_executor'):
loop.run_until_complete(loop.shutdown_default_executor())
@staticmethod
def _cancel_all_tasks(loop):
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
# XXX: "loop" is removed in python 3.10+
loop.run_until_complete(
asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})
def load_plugins(name, suffix, namespace): def load_plugins(name, suffix, namespace):
from ..plugins import load_plugins from ..plugins import load_plugins
ret = load_plugins(name, suffix) ret = load_plugins(name, suffix)

@ -1,5 +1,3 @@
import asyncio
import atexit
import base64 import base64
import binascii import binascii
import calendar import calendar
@ -54,7 +52,7 @@ from ..compat import (
compat_os_name, compat_os_name,
compat_shlex_quote, compat_shlex_quote,
) )
from ..dependencies import websockets, xattr from ..dependencies import xattr
__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module __name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module
@ -4923,77 +4921,6 @@ class Config:
return self.parser.parse_args(self.all_args) return self.parser.parse_args(self.all_args)
class WebSocketsWrapper:
"""Wraps websockets module to use in non-async scopes"""
pool = None
def __init__(self, url, headers=None, connect=True):
self.loop = asyncio.new_event_loop()
# XXX: "loop" is deprecated
self.conn = websockets.connect(
url, extra_headers=headers, ping_interval=None,
close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
if connect:
self.__enter__()
atexit.register(self.__exit__, None, None, None)
def __enter__(self):
if not self.pool:
self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
return self
def send(self, *args):
self.run_with_loop(self.pool.send(*args), self.loop)
def recv(self, *args):
return self.run_with_loop(self.pool.recv(*args), self.loop)
def __exit__(self, type, value, traceback):
try:
return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
finally:
self.loop.close()
self._cancel_all_tasks(self.loop)
# taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
# for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
@staticmethod
def run_with_loop(main, loop):
if not asyncio.iscoroutine(main):
raise ValueError(f'a coroutine was expected, got {main!r}')
try:
return loop.run_until_complete(main)
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
if hasattr(loop, 'shutdown_default_executor'):
loop.run_until_complete(loop.shutdown_default_executor())
@staticmethod
def _cancel_all_tasks(loop):
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
# XXX: "loop" is removed in python 3.10+
loop.run_until_complete(
asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})
def merge_headers(*dicts): def merge_headers(*dicts):
"""Merge dicts of http headers case insensitively, prioritizing the latter ones""" """Merge dicts of http headers case insensitively, prioritizing the latter ones"""
return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))} return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}

Loading…
Cancel
Save