From da42679b87005a7a3e08496dc9f5959234e2a8a8 Mon Sep 17 00:00:00 2001 From: "Lesmiscore (Naoya Ozaki)" Date: Sun, 13 Feb 2022 14:58:21 +0900 Subject: [PATCH] [utils] WebSockets wrapper for non-async functions (#2417) Authored by: Lesmiscore --- yt_dlp/compat.py | 11 ++++++++ yt_dlp/utils.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/yt_dlp/compat.py b/yt_dlp/compat.py index b97d4512e..2bc6a6b7f 100644 --- a/yt_dlp/compat.py +++ b/yt_dlp/compat.py @@ -134,6 +134,16 @@ except AttributeError: asyncio.run = compat_asyncio_run +try: # >= 3.7 + asyncio.tasks.all_tasks +except AttributeError: + asyncio.tasks.all_tasks = asyncio.tasks.Task.all_tasks + +try: + import websockets as compat_websockets +except ImportError: + compat_websockets = None + # Python 3.8+ does not honor %HOME% on windows, but this breaks compatibility with youtube-dl # See https://github.com/yt-dlp/yt-dlp/issues/792 # https://docs.python.org/3/library/os.path.html#os.path.expanduser @@ -303,6 +313,7 @@ __all__ = [ 'compat_urllib_response', 'compat_urlparse', 'compat_urlretrieve', + 'compat_websockets', 'compat_xml_parse_error', 'compat_xpath', 'compat_zip', diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index bb8d65cad..c5489d494 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals +import asyncio import base64 import binascii import calendar @@ -73,6 +74,7 @@ from .compat import ( compat_urllib_parse_unquote_plus, compat_urllib_request, compat_urlparse, + compat_websockets, compat_xpath, ) @@ -5311,3 +5313,70 @@ class Config: def parse_args(self): return self._parser.parse_args(list(self.all_args)) + + +class WebSocketsWrapper(): + """Wraps websockets module to use in non-async scopes""" + + def __init__(self, url, headers=None): + self.loop = asyncio.events.new_event_loop() + self.conn = compat_websockets.connect( + url, extra_headers=headers, ping_interval=None, + close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf')) + + def __enter__(self): + self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop) + return self + + def send(self, *args): + self.run_with_loop(self.pool.send(*args), self.loop) + + def recv(self, *args): + return self.run_with_loop(self.pool.recv(*args), self.loop) + + def __exit__(self, type, value, traceback): + try: + return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop) + finally: + self.loop.close() + self.r_cancel_all_tasks(self.loop) + + # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications + # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class + @staticmethod + def run_with_loop(main, loop): + if not asyncio.coroutines.iscoroutine(main): + raise ValueError(f'a coroutine was expected, got {main!r}') + + try: + return loop.run_until_complete(main) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, 'shutdown_default_executor'): + loop.run_until_complete(loop.shutdown_default_executor()) + + @staticmethod + def _cancel_all_tasks(loop): + to_cancel = asyncio.tasks.all_tasks(loop) + + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete( + asyncio.tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'unhandled exception during asyncio.run() shutdown', + 'exception': task.exception(), + 'task': task, + }) + + +has_websockets = bool(compat_websockets)