about summary refs log tree commit diff
path: root/youtube_dl/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'youtube_dl/utils.py')
-rw-r--r--youtube_dl/utils.py102
1 files changed, 70 insertions, 32 deletions
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index e1b05b307..cd4303566 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -49,11 +49,14 @@ from .compat import (
     compat_cookiejar,
     compat_ctypes_WINFUNCTYPE,
     compat_datetime_timedelta_total_seconds,
+    compat_etree_Element,
     compat_etree_fromstring,
+    compat_etree_iterfind,
     compat_expanduser,
     compat_html_entities,
     compat_html_entities_html5,
     compat_http_client,
+    compat_http_cookies,
     compat_integer_types,
     compat_kwargs,
     compat_ncompress as ncompress,
@@ -6253,15 +6256,16 @@ if __debug__:
 
 def traverse_obj(obj, *paths, **kwargs):
     """
-    Safely traverse nested `dict`s and `Iterable`s
+    Safely traverse nested `dict`s and `Iterable`s, etc
 
     >>> obj = [{}, {"key": "value"}]
     >>> traverse_obj(obj, (1, "key"))
-    "value"
+    'value'
 
     Each of the provided `paths` is tested and the first producing a valid result will be returned.
     The next path will also be tested if the path branched but no results could be found.
-    Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
+    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`, `xml.etree.ElementTree`
+    (xpath) and `http.cookies.Morsel`.
     Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
 
     The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@@ -6269,8 +6273,9 @@ def traverse_obj(obj, *paths, **kwargs):
     The keys in the path can be one of:
         - `None`:           Return the current object.
         - `set`:            Requires the only item in the set to be a type or function,
-                            like `{type}`/`{func}`. If a `type`, returns only values
-                            of this type. If a function, returns `func(obj)`.
+                            like `{type}`/`{type, type, ...}`/`{func}`. If one or more `type`s,
+                            return only values that have one of the types. If a function,
+                            return `func(obj)`.
         - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
         - `slice`:          Branch out and return all values in `obj[key]`.
         - `Ellipsis`:       Branch out and return a list of all values.
@@ -6282,8 +6287,10 @@ def traverse_obj(obj, *paths, **kwargs):
                             For `Iterable`s, `key` is the enumeration count of the value.
                             For `re.Match`es, `key` is the group number (0 = full match)
                             as well as additionally any group names, if given.
-        - `dict`            Transform the current object and return a matching dict.
+        - `dict`:           Transform the current object and return a matching dict.
                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
+        - `any`-builtin:    Take the first matching object and return it, resetting branching.
+        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.
 
         `tuple`, `list`, and `dict` all support nested paths and branches.
 
@@ -6299,10 +6306,8 @@ def traverse_obj(obj, *paths, **kwargs):
     @param get_all          If `False`, return the first matching result, otherwise all matching ones.
     @param casesense        If `False`, consider string dictionary keys as case insensitive.
 
-    The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
+    The following is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API
 
-    @param _is_user_input    Whether the keys are generated from user input.
-                            If `True` strings get converted to `int`/`slice` if needed.
     @param _traverse_string  Whether to traverse into objects as strings.
                             If `True`, any non-compatible object will first be
                             converted into a string and then traversed into.
@@ -6322,7 +6327,6 @@ def traverse_obj(obj, *paths, **kwargs):
     expected_type = kwargs.get('expected_type')
     get_all = kwargs.get('get_all', True)
     casesense = kwargs.get('casesense', True)
-    _is_user_input = kwargs.get('_is_user_input', False)
     _traverse_string = kwargs.get('_traverse_string', False)
 
     # instant compat
@@ -6336,10 +6340,8 @@ def traverse_obj(obj, *paths, **kwargs):
         type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
 
     def lookup_or_none(v, k, getter=None):
-        try:
+        with compat_contextlib_suppress(LookupError):
             return getter(v, k) if getter else v[k]
-        except IndexError:
-            return None
 
     def from_iterable(iterables):
         # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
@@ -6361,12 +6363,13 @@ def traverse_obj(obj, *paths, **kwargs):
             result = obj
 
         elif isinstance(key, set):
-            assert len(key) == 1, 'Set should only be used to wrap a single item'
-            item = next(iter(key))
-            if isinstance(item, type):
-                result = obj if isinstance(obj, item) else None
+            assert len(key) >= 1, 'At least one item is required in a `set` key'
+            if all(isinstance(item, type) for item in key):
+                result = obj if isinstance(obj, tuple(key)) else None
             else:
-                result = try_call(item, args=(obj,))
+                item = next(iter(key))
+                assert len(key) == 1, 'Multiple items in a `set` key must all be types'
+                result = try_call(item, args=(obj,)) if not isinstance(item, type) else None
 
         elif isinstance(key, (list, tuple)):
             branching = True
@@ -6375,9 +6378,11 @@ def traverse_obj(obj, *paths, **kwargs):
 
         elif key is Ellipsis:
             branching = True
+            if isinstance(obj, compat_http_cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             if isinstance(obj, compat_collections_abc.Mapping):
                 result = obj.values()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)):
                 result = obj
             elif isinstance(obj, compat_re_Match):
                 result = obj.groups()
@@ -6389,9 +6394,11 @@ def traverse_obj(obj, *paths, **kwargs):
 
         elif callable(key):
             branching = True
+            if isinstance(obj, compat_http_cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             if isinstance(obj, compat_collections_abc.Mapping):
                 iter_obj = obj.items()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)):
                 iter_obj = enumerate(obj)
             elif isinstance(obj, compat_re_Match):
                 iter_obj = itertools.chain(
@@ -6413,6 +6420,8 @@ def traverse_obj(obj, *paths, **kwargs):
                           if v is not None or default is not NO_DEFAULT) or None
 
         elif isinstance(obj, compat_collections_abc.Mapping):
+            if isinstance(obj, compat_http_cookies.Morsel):
+                obj = dict(obj, key=obj.key, value=obj.value)
             result = (try_call(obj.get, args=(key,))
                       if casesense or try_call(obj.__contains__, args=(key,))
                       else next((v for k, v in obj.items() if casefold(k) == key), None))
@@ -6430,12 +6439,40 @@ def traverse_obj(obj, *paths, **kwargs):
         else:
             result = None
             if isinstance(key, (int, slice)):
-                if is_iterable_like(obj, compat_collections_abc.Sequence):
+                if is_iterable_like(obj, (compat_collections_abc.Sequence, compat_etree_Element)):
                     branching = isinstance(key, slice)
                     result = lookup_or_none(obj, key)
                 elif _traverse_string:
                     result = lookup_or_none(str(obj), key)
 
+            elif isinstance(obj, compat_etree_Element) and isinstance(key, str):
+                xpath, _, special = key.rpartition('/')
+                if not special.startswith('@') and not special.endswith('()'):
+                    xpath = key
+                    special = None
+
+                # Allow abbreviations of relative paths, absolute paths error
+                if xpath.startswith('/'):
+                    xpath = '.' + xpath
+                elif xpath and not xpath.startswith('./'):
+                    xpath = './' + xpath
+
+                def apply_specials(element):
+                    if special is None:
+                        return element
+                    if special == '@':
+                        return element.attrib
+                    if special.startswith('@'):
+                        return try_call(element.attrib.get, args=(special[1:],))
+                    if special == 'text()':
+                        return element.text
+                    raise SyntaxError('apply_specials is missing case for {0!r}'.format(special))
+
+                if xpath:
+                    result = list(map(apply_specials, compat_etree_iterfind(obj, xpath)))
+                else:
+                    result = apply_specials(obj)
+
         return branching, result if branching else (result,)
 
     def lazy_last(iterable):
@@ -6456,17 +6493,18 @@ def traverse_obj(obj, *paths, **kwargs):
 
         key = None
         for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
-            if _is_user_input and isinstance(key, str):
-                if key == ':':
-                    key = Ellipsis
-                elif ':' in key:
-                    key = slice(*map(int_or_none, key.split(':')))
-                elif int_or_none(key) is not None:
-                    key = int(key)
-
             if not casesense and isinstance(key, str):
                 key = compat_casefold(key)
 
+            if key in (any, all):
+                has_branched = False
+                filtered_objs = (obj for obj in objs if obj not in (None, {}))
+                if key is any:
+                    objs = (next(filtered_objs, None),)
+                else:
+                    objs = (list(filtered_objs),)
+                continue
+
             if __debug__ and callable(key):
                 # Verify function signature
                 _try_bind_args(key, None, None)
@@ -6505,9 +6543,9 @@ def traverse_obj(obj, *paths, **kwargs):
     return None if default is NO_DEFAULT else default
 
 
-def T(x):
-    """ For use in yt-dl instead of {type} or set((type,)) """
-    return set((x,))
+def T(*x):
+    """ For use in yt-dl instead of {type, ...} or set((type, ...)) """
+    return set(x)
 
 
 def get_first(obj, keys, **kwargs):