# Reconstructed from a whitespace-collapsed rendering of the patch described
# below. Only the definitions that the patch adds to youtube_dl/utils.py are
# reproduced here, re-formatted as regular Python source.
#
#   From 27ed77aabba8c9eb08d66f34092b1bfcc22c482e Mon Sep 17 00:00:00 2001
#   From: Andrei Lebedev
#   Date: Thu, 3 Nov 2022 11:09:37 +0100
#   Subject: [utils] Backport traverse_obj (etc) from yt-dlp (#31156)
#
#   * Backport traverse_obj and closely related function from yt-dlp
#     (code by pukkandan)
#   * Backport LazyList, variadic(), try_call (code by pukkandan)
#   * Recast using yt-dlp's newer traverse_obj() implementation and tests
#     (code by grub4k)
#   * Add tests for Unicode case folding support matching Py3.5+
#     (requires f102e3d)
#   * Improve/add tests for variadic, try_call, join_nonempty
#
#   Co-authored-by: dirkf
#
#   Target: youtube_dl/utils.py (index 23a65a81c..e3c3ccff9)
#
# The patch also adds `compat_collections_abc` to the existing
# `from .compat import (...)` list of utils.py.
# NOTE(review): `compat_str`, `functools`, `itertools`, `re` and `int_or_none`
# are used below but are not defined in the visible part of the patch; they
# are presumably provided by the rest of utils.py - confirm when applying.
#
# The hunk context of the patch (parts of detect_exe_version, PagedList,
# multipart_encode, dict_get, try_get and clean_podcast_url) is only
# partially visible there and is therefore not reproduced.


# Sentinel meaning "no default was supplied" (lets None be a real default)
NO_DEFAULT = object()
# Function returning its argument unchanged
IDENTITY = lambda x: x


class LazyList(compat_collections_abc.Sequence):
    """Lazy immutable list from an iterable
    Note that slices of a LazyList are lists and not LazyList

    Items are pulled from the iterable only when needed and are kept in an
    internal cache, so that they can be read again.

    Keyword arguments:
    @param reverse  If true, the list presents the items in reverse order.
    @param _cache   List used as item cache (shared, not copied); used by
                    `__reversed__` and `__copy__`.
    """

    class IndexError(IndexError):
        """IndexError raised by LazyList; optionally records the original error"""

        def __init__(self, cause=None):
            if cause:
                # reproduce `raise from`
                self.__cause__ = cause
            # A bare `IndexError` inside this method is the builtin, not this
            # class, so name the nested class explicitly
            super(LazyList.IndexError, self).__init__()

    def __init__(self, iterable, **kwargs):
        # kwarg-only
        reverse = kwargs.get('reverse', False)
        _cache = kwargs.get('_cache')

        self._iterable = iter(iterable)
        self._cache = [] if _cache is None else _cache
        self._reversed = reverse

    def __iter__(self):
        if self._reversed:
            # We need to consume the entire iterable to iterate in reverse
            for item in self.exhaust():
                yield item
            return
        # First replay what has been read already, then continue reading
        for item in self._cache:
            yield item
        for item in self._iterable:
            self._cache.append(item)
            yield item

    def _exhaust(self):
        """Read all remaining items into the cache and return the cache"""
        self._cache.extend(self._iterable)
        self._iterable = []  # Discard the emptied iterable to make it pickle-able
        return self._cache

    def exhaust(self):
        """Evaluate the entire iterable"""
        return self._exhaust()[::-1 if self._reversed else 1]

    @staticmethod
    def _reverse_index(x):
        """Map an index counted from one end to the same item counted from the other"""
        return None if x is None else ~x

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            if self._reversed:
                idx = slice(self._reverse_index(idx.start), self._reverse_index(idx.stop), -(idx.step or 1))
            start, stop, step = idx.start, idx.stop, idx.step or 1
        elif isinstance(idx, int):
            if self._reversed:
                idx = self._reverse_index(idx)
            start, stop, step = idx, idx, 0
        else:
            raise TypeError('indices must be integers or slices')
        if ((start or 0) < 0 or (stop or 0) < 0
                or (start is None and step < 0)
                or (stop is None and step > 0)):
            # We need to consume the entire iterable to be able to slice from the end
            # Obviously, never use this with infinite iterables
            self._exhaust()
            try:
                return self._cache[idx]
            except IndexError as e:
                raise self.IndexError(e)
        # Number of further items needed so that the cache covers the request
        n = max(start or 0, stop or 0) - len(self._cache) + 1
        if n > 0:
            self._cache.extend(itertools.islice(self._iterable, n))
        try:
            return self._cache[idx]
        except IndexError as e:
            raise self.IndexError(e)

    def __bool__(self):
        # Only the first item (in presentation order) is looked up
        try:
            self[-1] if self._reversed else self[0]
        except self.IndexError:
            return False
        return True

    # Python 2 spells the truth-value hook `__nonzero__`; without this alias
    # truth testing would fall back to `__len__` and exhaust the iterable
    __nonzero__ = __bool__

    def __len__(self):
        self._exhaust()
        return len(self._cache)

    def __reversed__(self):
        # The new object shares this object's iterable and cache
        return type(self)(self._iterable, reverse=not self._reversed, _cache=self._cache)

    def __copy__(self):
        # The new object shares this object's iterable and cache
        return type(self)(self._iterable, reverse=self._reversed, _cache=self._cache)

    def __repr__(self):
        # repr and str should mimic a list. So we exhaust the iterable
        return repr(self.exhaust())

    def __str__(self):
        return repr(self.exhaust())


def variadic(x, allowed_types=(compat_str, bytes, dict)):
    """
    Return `x` itself if it is an iterable that is not an instance of
    `allowed_types`, otherwise return the 1-tuple `(x,)`
    """
    return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,)


def try_call(*funcs, **kwargs):
    """
    Call each of `funcs` in turn and return the first acceptable result

    A call that raises AttributeError, KeyError, TypeError, IndexError or
    ZeroDivisionError is skipped; other exceptions propagate.

    Keyword arguments:
    @param expected_type    If not None, only a result that is an instance
                            of this type is accepted.
    @param args             Positional arguments passed to each function.
    @param kwargs           Keyword arguments passed to each function.

    @returns                The first accepted result, or None if there is none.
    """

    # parameter defaults
    expected_type = kwargs.get('expected_type')
    fargs = kwargs.get('args', [])
    fkwargs = kwargs.get('kwargs', {})

    for f in funcs:
        try:
            val = f(*fargs, **fkwargs)
        except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError):
            pass
        else:
            if expected_type is None or isinstance(val, expected_type):
                return val


def traverse_obj(obj, *paths, **kwargs):
    """
    Safely traverse nested `dict`s and `Sequence`s

    >>> obj = [{}, {"key": "value"}]
    >>> traverse_obj(obj, (1, "key"))
    'value'

    Each of the provided `paths` is tested and the first producing a valid result will be returned.
    The next path will also be tested if the path branched but no results could be found.
    Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
    A value of None is treated as the absence of a value.

    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.

    The keys in the path can be one of:
        - `None`:           Return the current object.
        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
        - `slice`:          Return the sub-sequence `obj[key]` as a single value
                            (this does not branch; later keys apply to the slice itself).
        - `Ellipsis`:       Branch out and return a list of all values.
        - `tuple`/`list`:   Branch out and return a list of all matching values.
                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
        - `function`:       Branch out and return values filtered by the function.
                            Read as: `[value for key, value in obj if function(key, value)]`.
                            For `Sequence`s, `key` is the index of the value.
        - `dict`            Transform the current object and return a matching dict.
                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.

        `tuple`, `list`, and `dict` all support nested paths and branches.
        A key that cannot be looked up in a `Mapping` (e.g. an unhashable one)
        is treated as not matching.

    @params paths           Paths which to traverse by.
    Keyword arguments:
    @param default          Value to return if the paths do not match.
    @param expected_type    If a `type`, only accept final values of this type.
                            If any other callable, try to call the function on each result.
    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
    @param casesense        If `False`, consider string dictionary keys as case insensitive.

    The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API

    @param _is_user_input    Whether the keys are generated from user input.
                            If `True` strings get converted to `int`/`slice` if needed.
    @param _traverse_string  Whether to traverse into objects as strings.
                            If `True`, any non-compatible object will first be
                            converted into a string and then traversed into.


    @returns                The result of the object traversal.
                            If successful, `get_all=True`, and the path branches at least once,
                            then a list of results is returned instead.
                            A list is always returned if the last path branches and no `default` is given.
    """

    # parameter defaults
    default = kwargs.get('default', NO_DEFAULT)
    expected_type = kwargs.get('expected_type')
    get_all = kwargs.get('get_all', True)
    casesense = kwargs.get('casesense', True)
    _is_user_input = kwargs.get('_is_user_input', False)
    _traverse_string = kwargs.get('_traverse_string', False)

    # instant compat
    str = compat_str

    is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes))
    # stand-in until compat_re_Match is added
    compat_re_Match = type(re.match('a', 'a'))
    # stand-in until casefold.py is added
    try:
        ''.casefold()
        compat_casefold = lambda s: s.casefold()
    except AttributeError:
        compat_casefold = lambda s: s.lower()
    casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k

    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))

    def from_iterable(iterables):
        # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
        for it in iterables:
            for item in it:
                yield item

    def apply_key(key, obj):
        # Generate the value(s) selected by a single `key` from `obj`
        if obj is None:
            return

        elif key is None:
            yield obj

        elif isinstance(key, (list, tuple)):
            for branch in key:
                _, result = apply_path(obj, branch)
                for item in result:
                    yield item

        elif key is Ellipsis:
            result = []
            if isinstance(obj, compat_collections_abc.Mapping):
                result = obj.values()
            elif is_sequence(obj):
                result = obj
            elif isinstance(obj, compat_re_Match):
                result = obj.groups()
            elif _traverse_string:
                result = str(obj)
            for item in result:
                yield item

        elif callable(key):
            if is_sequence(obj):
                iter_obj = enumerate(obj)
            elif isinstance(obj, compat_collections_abc.Mapping):
                iter_obj = obj.items()
            elif isinstance(obj, compat_re_Match):
                iter_obj = enumerate(itertools.chain([obj.group()], obj.groups()))
            elif _traverse_string:
                iter_obj = enumerate(str(obj))
            else:
                return
            for item in (v for k, v in iter_obj if try_call(key, args=(k, v))):
                yield item

        elif isinstance(key, dict):
            iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
            yield dict((k, v if v is not None else default) for k, v in iter_obj
                       if v is not None or default is not NO_DEFAULT)

        elif isinstance(obj, compat_collections_abc.Mapping):
            # An unhashable key (e.g. a `slice` before Python 3.12) makes the
            # lookup raise TypeError; going through try_call() turns that
            # into "no match" rather than aborting the traversal
            yield (try_call(obj.get, args=(key,))
                   if casesense or try_call(obj.__contains__, args=(key,))
                   else next((v for k, v in obj.items() if casefold(k) == key), None))

        elif isinstance(obj, compat_re_Match):
            if isinstance(key, int) or casesense:
                try:
                    yield obj.group(key)
                    return
                except IndexError:
                    pass
            if not isinstance(key, str):
                return

            # Fall back to comparing the key with the case-folded group names
            yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

        else:
            if _is_user_input:
                key = (int_or_none(key) if ':' not in key
                       else slice(*map(int_or_none, key.split(':'))))

            if not isinstance(key, (int, slice)):
                return

            if not is_sequence(obj):
                if not _traverse_string:
                    return
                obj = str(obj)

            try:
                yield obj[key]
            except IndexError:
                pass

    def apply_path(start_obj, path):
        # Apply the keys of `path` one after another, starting at `start_obj`;
        # returns whether any key branched, and an iterable of the results
        objs = (start_obj,)
        has_branched = False

        for key in variadic(path):
            if _is_user_input and key == ':':
                key = Ellipsis

            if not casesense and isinstance(key, str):
                key = compat_casefold(key)

            if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key):
                has_branched = True

            key_func = functools.partial(apply_key, key)
            objs = from_iterable(map(key_func, objs))

        return has_branched, objs

    def _traverse_obj(obj, path, use_list=True):
        has_branched, results = apply_path(obj, path)
        # None results (including those rejected by `expected_type`) are dropped
        results = LazyList(x for x in map(type_test, results) if x is not None)

        if get_all and has_branched:
            return results.exhaust() if results or use_list else None

        return results[0] if results else None

    for index, path in enumerate(paths, 1):
        # Only the last path may produce an empty list, and only without a default
        use_list = default is NO_DEFAULT and index == len(paths)
        result = _traverse_obj(obj, path, use_list)
        if result is not None:
            return result

    return None if default is NO_DEFAULT else default


def get_first(obj, keys, **kwargs):
    """
    Return the first result of traversing each value of `obj` by `keys`

    Calls `traverse_obj()` with the path `(Ellipsis,) + keys` and
    `get_all=False`; other keyword arguments are passed through.
    """
    return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)


def join_nonempty(*values, **kwargs):
    """
    Join the truthy `values`, converted to strings, with a delimiter

    Keyword arguments:
    @param delim        The delimiter (default '-').
    @param from_dict    If not None, each value is taken as a `traverse_obj()`
                        path and replaced by what it yields from this object.
    """

    # parameter defaults
    delim = kwargs.get('delim', '-')
    from_dict = kwargs.get('from_dict')

    if from_dict is not None:
        values = (traverse_obj(from_dict, variadic(v)) for v in values)
    return delim.join(map(compat_str, filter(None, values)))