From 7b0127e1e11186bcbb80a18b1b530d864a5dbada Mon Sep 17 00:00:00 2001 From: Simon Sawicki <37424085+Grub4K@users.noreply.github.com> Date: Sun, 9 Oct 2022 03:31:37 +0200 Subject: [PATCH] [utils] `traverse_obj`: Allow `re.Match` objects (#5174) Authored by: Grub4K --- test/test_utils.py | 20 ++++++++++++++++++++ yt_dlp/utils.py | 22 +++++++++++++++++++--- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 6f3f6cb91..90085a9c0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2,6 +2,7 @@ # Allow direct execution import os +import re import sys import unittest @@ -2080,6 +2081,25 @@ Line 1 with self.assertRaises(TypeError, msg='too many params should result in error'): traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True) + # Test re.Match as input obj + mobj = re.fullmatch(r'0(12)(?P3)(4)?', '0123') + self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None], + msg='`...` on a `re.Match` should give its `groups()`') + self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'], + msg='function on a `re.Match` should give groupno, value starting at 0') + self.assertEqual(traverse_obj(mobj, 'group'), '3', + msg='str key on a `re.Match` should give group with that name') + self.assertEqual(traverse_obj(mobj, 2), '3', + msg='int key on a `re.Match` should give group with that name') + self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3', + msg='str key on a `re.Match` should respect casesense') + self.assertEqual(traverse_obj(mobj, 'fail'), None, + msg='failing str key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None, + msg='failing str key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, 8), None, + msg='failing int key on a `re.Match` should return `default`') + if __name__ == '__main__': unittest.main() diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 7d8e97162..cb14908c7 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5305,13 +5305,14 @@ def traverse_obj( Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. + Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. A value of None is treated as the absence of a value. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The keys in the path can be one of: - `None`: Return the current object. - - `str`/`int`: Return `obj[key]`. + - `str`/`int`: Return `obj[key]`. For `re.Match, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. - `tuple`/`list`: Branch out and return a list of all matching values. @@ -5322,7 +5323,7 @@ def traverse_obj( - `dict` Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. - `tuple`, `list`, and `dict` all support nested paths and branches + `tuple`, `list`, and `dict` all support nested paths and branches. @params paths Paths which to traverse by. @param default Value to return if the paths do not match. @@ -5370,6 +5371,8 @@ def traverse_obj( yield from obj.values() elif is_sequence(obj): yield from obj + elif isinstance(obj, re.Match): + yield from obj.groups() elif traverse_string: yield from str(obj) @@ -5378,6 +5381,8 @@ def traverse_obj( iter_obj = enumerate(obj) elif isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() + elif isinstance(obj, re.Match): + iter_obj = enumerate((obj.group(), *obj.groups())) elif traverse_string: iter_obj = enumerate(str(obj)) else: @@ -5389,10 +5394,21 @@ def traverse_obj( yield {k: v if v is not None else default for k, v in iter_obj if v is not None or default is not NO_DEFAULT} - elif isinstance(obj, dict): + elif isinstance(obj, collections.abc.Mapping): yield (obj.get(key) if casesense or (key in obj) else next((v for k, v in obj.items() if casefold(k) == key), None)) + elif isinstance(obj, re.Match): + if isinstance(key, int) or casesense: + with contextlib.suppress(IndexError): + yield obj.group(key) + return + + if not isinstance(key, str): + return + + yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) + else: if is_user_input: key = (int_or_none(key) if ':' not in key