about summary refs log tree commit diff
path: root/youtube_dl/utils.py
diff options
context:
space:
mode:
authorAndrei Lebedev <lebdron@gmail.com>2022-11-03 11:09:37 +0100
committerGitHub <noreply@github.com>2022-11-03 10:09:37 +0000
commit27ed77aabba8c9eb08d66f34092b1bfcc22c482e (patch)
tree7cc41fc5e398009a5cf8e7e4156afb0246aa34d3 /youtube_dl/utils.py
parentc4b19a88169fa76c5eb665d274e7270a0fe452c4 (diff)
downloadyoutube-dl-27ed77aabba8c9eb08d66f34092b1bfcc22c482e.tar.gz
youtube-dl-27ed77aabba8c9eb08d66f34092b1bfcc22c482e.tar.xz
youtube-dl-27ed77aabba8c9eb08d66f34092b1bfcc22c482e.zip
[utils] Backport traverse_obj (etc) from yt-dlp (#31156)
* Backport traverse_obj and closely related function from yt-dlp (code by pukkandan)
* Backport LazyList, variadic(), try_call (code by pukkandan)
* Recast using yt-dlp's newer traverse_obj() implementation and tests (code by grub4k)
* Add tests for Unicode case folding support matching Py3.5+ (requires f102e3d)
* Improve/add tests for variadic, try_call, join_nonempty

Co-authored-by: dirkf <fieldhouse@gmx.net>
Diffstat (limited to 'youtube_dl/utils.py')
-rw-r--r--youtube_dl/utils.py339
1 files changed, 339 insertions, 0 deletions
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 23a65a81c..e3c3ccff9 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -43,6 +43,7 @@ from .compat import (
     compat_HTTPError,
     compat_basestring,
     compat_chr,
+    compat_collections_abc,
     compat_cookiejar,
     compat_ctypes_WINFUNCTYPE,
     compat_etree_fromstring,
@@ -1685,6 +1686,7 @@ USER_AGENTS = {
 
 
 NO_DEFAULT = object()
+IDENTITY = lambda x: x
 
 ENGLISH_MONTH_NAMES = [
     'January', 'February', 'March', 'April', 'May', 'June',
@@ -3867,6 +3869,105 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
         return unrecognized
 
 
+class LazyList(compat_collections_abc.Sequence):
+    """Lazy immutable list from an iterable
+    Note that slices of a LazyList are lists and not LazyList"""
+
+    class IndexError(IndexError):
+        def __init__(self, cause=None):
+            if cause:
+                # reproduce `raise from`
+                self.__cause__ = cause
+            super(IndexError, self).__init__()
+
+    def __init__(self, iterable, **kwargs):
+        # kwarg-only
+        reverse = kwargs.get('reverse', False)
+        _cache = kwargs.get('_cache')
+
+        self._iterable = iter(iterable)
+        self._cache = [] if _cache is None else _cache
+        self._reversed = reverse
+
+    def __iter__(self):
+        if self._reversed:
+            # We need to consume the entire iterable to iterate in reverse
+            for item in self.exhaust():
+                yield item
+            return
+        for item in self._cache:
+            yield item
+        for item in self._iterable:
+            self._cache.append(item)
+            yield item
+
+    def _exhaust(self):
+        self._cache.extend(self._iterable)
+        self._iterable = []  # Discard the emptied iterable to make it pickle-able
+        return self._cache
+
+    def exhaust(self):
+        """Evaluate the entire iterable"""
+        return self._exhaust()[::-1 if self._reversed else 1]
+
+    @staticmethod
+    def _reverse_index(x):
+        return None if x is None else ~x
+
+    def __getitem__(self, idx):
+        if isinstance(idx, slice):
+            if self._reversed:
+                idx = slice(self._reverse_index(idx.start), self._reverse_index(idx.stop), -(idx.step or 1))
+            start, stop, step = idx.start, idx.stop, idx.step or 1
+        elif isinstance(idx, int):
+            if self._reversed:
+                idx = self._reverse_index(idx)
+            start, stop, step = idx, idx, 0
+        else:
+            raise TypeError('indices must be integers or slices')
+        if ((start or 0) < 0 or (stop or 0) < 0
+                or (start is None and step < 0)
+                or (stop is None and step > 0)):
+            # We need to consume the entire iterable to be able to slice from the end
+            # Obviously, never use this with infinite iterables
+            self._exhaust()
+            try:
+                return self._cache[idx]
+            except IndexError as e:
+                raise self.IndexError(e)
+        n = max(start or 0, stop or 0) - len(self._cache) + 1
+        if n > 0:
+            self._cache.extend(itertools.islice(self._iterable, n))
+        try:
+            return self._cache[idx]
+        except IndexError as e:
+            raise self.IndexError(e)
+
+    def __bool__(self):
+        try:
+            self[-1] if self._reversed else self[0]
+        except self.IndexError:
+            return False
+        return True
+
+    def __len__(self):
+        self._exhaust()
+        return len(self._cache)
+
+    def __reversed__(self):
+        return type(self)(self._iterable, reverse=not self._reversed, _cache=self._cache)
+
+    def __copy__(self):
+        return type(self)(self._iterable, reverse=self._reversed, _cache=self._cache)
+
+    def __repr__(self):
+        # repr and str should mimic a list. So we exhaust the iterable
+        return repr(self.exhaust())
+
+    def __str__(self):
+        return repr(self.exhaust())
+
+
 class PagedList(object):
     def __len__(self):
         # This is only useful for tests
@@ -4092,6 +4193,10 @@ def multipart_encode(data, boundary=None):
     return out, content_type
 
 
+def variadic(x, allowed_types=(compat_str, bytes, dict)):
+    return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,)
+
+
 def dict_get(d, key_or_keys, default=None, skip_false_values=True):
     if isinstance(key_or_keys, (list, tuple)):
         for key in key_or_keys:
@@ -4102,6 +4207,23 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True):
     return d.get(key_or_keys, default)
 
 
+def try_call(*funcs, **kwargs):
+
+    # parameter defaults
+    expected_type = kwargs.get('expected_type')
+    fargs = kwargs.get('args', [])
+    fkwargs = kwargs.get('kwargs', {})
+
+    for f in funcs:
+        try:
+            val = f(*fargs, **fkwargs)
+        except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError):
+            pass
+        else:
+            if expected_type is None or isinstance(val, expected_type):
+                return val
+
+
 def try_get(src, getter, expected_type=None):
     if not isinstance(getter, (list, tuple)):
         getter = [getter]
@@ -5835,3 +5957,220 @@ def clean_podcast_url(url):
                 st\.fm # https://podsights.com/docs/
             )/e
         )/''', '', url)
+
+
+def traverse_obj(obj, *paths, **kwargs):
+    """
+    Safely traverse nested `dict`s and `Sequence`s
+
+    >>> obj = [{}, {"key": "value"}]
+    >>> traverse_obj(obj, (1, "key"))
+    "value"
+
+    Each of the provided `paths` is tested and the first producing a valid result will be returned.
+    The next path will also be tested if the path branched but no results could be found.
+    Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
+    A value of None is treated as the absence of a value.
+
+    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
+
+    The keys in the path can be one of:
+        - `None`:           Return the current object.
+        - `str`/`int`:      Return `obj[key]`. For `re.Match, return `obj.group(key)`.
+        - `slice`:          Branch out and return all values in `obj[key]`.
+        - `Ellipsis`:       Branch out and return a list of all values.
+        - `tuple`/`list`:   Branch out and return a list of all matching values.
+                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
+        - `function`:       Branch out and return values filtered by the function.
+                            Read as: `[value for key, value in obj if function(key, value)]`.
+                            For `Sequence`s, `key` is the index of the value.
+        - `dict`            Transform the current object and return a matching dict.
+                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
+
+        `tuple`, `list`, and `dict` all support nested paths and branches.
+
+    @params paths           Paths which to traverse by.
+    Keyword arguments:
+    @param default          Value to return if the paths do not match.
+    @param expected_type    If a `type`, only accept final values of this type.
+                            If any other callable, try to call the function on each result.
+    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
+    @param casesense        If `False`, consider string dictionary keys as case insensitive.
+
+    The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
+
+    @param _is_user_input    Whether the keys are generated from user input.
+                            If `True` strings get converted to `int`/`slice` if needed.
+    @param _traverse_string  Whether to traverse into objects as strings.
+                            If `True`, any non-compatible object will first be
+                            converted into a string and then traversed into.
+
+
+    @returns                The result of the object traversal.
+                            If successful, `get_all=True`, and the path branches at least once,
+                            then a list of results is returned instead.
+                            A list is always returned if the last path branches and no `default` is given.
+    """
+
+    # parameter defaults
+    default = kwargs.get('default', NO_DEFAULT)
+    expected_type = kwargs.get('expected_type')
+    get_all = kwargs.get('get_all', True)
+    casesense = kwargs.get('casesense', True)
+    _is_user_input = kwargs.get('_is_user_input', False)
+    _traverse_string = kwargs.get('_traverse_string', False)
+
+    # instant compat
+    str = compat_str
+
+    is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes))
+    # stand-in until compat_re_Match is added
+    compat_re_Match = type(re.match('a', 'a'))
+    # stand-in until casefold.py is added
+    try:
+        ''.casefold()
+        compat_casefold = lambda s: s.casefold()
+    except AttributeError:
+        compat_casefold = lambda s: s.lower()
+    casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k
+
+    if isinstance(expected_type, type):
+        type_test = lambda val: val if isinstance(val, expected_type) else None
+    else:
+        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
+
+    def from_iterable(iterables):
+        # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
+        for it in iterables:
+            for item in it:
+                yield item
+
+    def apply_key(key, obj):
+        if obj is None:
+            return
+
+        elif key is None:
+            yield obj
+
+        elif isinstance(key, (list, tuple)):
+            for branch in key:
+                _, result = apply_path(obj, branch)
+                for item in result:
+                    yield item
+
+        elif key is Ellipsis:
+            result = []
+            if isinstance(obj, compat_collections_abc.Mapping):
+                result = obj.values()
+            elif is_sequence(obj):
+                result = obj
+            elif isinstance(obj, compat_re_Match):
+                result = obj.groups()
+            elif _traverse_string:
+                result = str(obj)
+            for item in result:
+                yield item
+
+        elif callable(key):
+            if is_sequence(obj):
+                iter_obj = enumerate(obj)
+            elif isinstance(obj, compat_collections_abc.Mapping):
+                iter_obj = obj.items()
+            elif isinstance(obj, compat_re_Match):
+                iter_obj = enumerate(itertools.chain([obj.group()], obj.groups()))
+            elif _traverse_string:
+                iter_obj = enumerate(str(obj))
+            else:
+                return
+            for item in (v for k, v in iter_obj if try_call(key, args=(k, v))):
+                yield item
+
+        elif isinstance(key, dict):
+            iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
+            yield dict((k, v if v is not None else default) for k, v in iter_obj
+                       if v is not None or default is not NO_DEFAULT)
+
+        elif isinstance(obj, compat_collections_abc.Mapping):
+            yield (obj.get(key) if casesense or (key in obj)
+                   else next((v for k, v in obj.items() if casefold(k) == key), None))
+
+        elif isinstance(obj, compat_re_Match):
+            if isinstance(key, int) or casesense:
+                try:
+                    yield obj.group(key)
+                    return
+                except IndexError:
+                    pass
+            if not isinstance(key, str):
+                return
+
+            yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
+
+        else:
+            if _is_user_input:
+                key = (int_or_none(key) if ':' not in key
+                       else slice(*map(int_or_none, key.split(':'))))
+
+            if not isinstance(key, (int, slice)):
+                return
+
+            if not is_sequence(obj):
+                if not _traverse_string:
+                    return
+                obj = str(obj)
+
+            try:
+                yield obj[key]
+            except IndexError:
+                pass
+
+    def apply_path(start_obj, path):
+        objs = (start_obj,)
+        has_branched = False
+
+        for key in variadic(path):
+            if _is_user_input and key == ':':
+                key = Ellipsis
+
+            if not casesense and isinstance(key, str):
+                key = compat_casefold(key)
+
+            if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key):
+                has_branched = True
+
+            key_func = functools.partial(apply_key, key)
+            objs = from_iterable(map(key_func, objs))
+
+        return has_branched, objs
+
+    def _traverse_obj(obj, path, use_list=True):
+        has_branched, results = apply_path(obj, path)
+        results = LazyList(x for x in map(type_test, results) if x is not None)
+
+        if get_all and has_branched:
+            return results.exhaust() if results or use_list else None
+
+        return results[0] if results else None
+
+    for index, path in enumerate(paths, 1):
+        use_list = default is NO_DEFAULT and index == len(paths)
+        result = _traverse_obj(obj, path, use_list)
+        if result is not None:
+            return result
+
+    return None if default is NO_DEFAULT else default
+
+
+def get_first(obj, keys, **kwargs):
+    return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)
+
+
+def join_nonempty(*values, **kwargs):
+
+    # parameter defaults
+    delim = kwargs.get('delim', '-')
+    from_dict = kwargs.get('from_dict')
+
+    if from_dict is not None:
+        values = (traverse_obj(from_dict, variadic(v)) for v in values)
+    return delim.join(map(compat_str, filter(None, values)))