about summary refs log tree commit diff
diff options
context:
space:
mode:
authorYen Chi Hsuan <yan12125@gmail.com>2016-07-06 20:02:52 +0800
committerYen Chi Hsuan <yan12125@gmail.com>2016-07-06 20:02:52 +0800
commit84c237fb8a2afa06fd3c36f7da9517682e63480e (patch)
tree4bc87148026bb3bd6ed279d3f354eaf0004e97e4
parentab49d7a9fae08763de549f85ba138b22f9122a99 (diff)
downloadyoutube-dl-84c237fb8a2afa06fd3c36f7da9517682e63480e.tar.gz
youtube-dl-84c237fb8a2afa06fd3c36f7da9517682e63480e.tar.xz
youtube-dl-84c237fb8a2afa06fd3c36f7da9517682e63480e.zip
[utils] Add get_element_by_class
For #9950
-rw-r--r--test/test_utils.py9
-rw-r--r--youtube_dl/utils.py12
2 files changed, 19 insertions, 2 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 85928dbc2..afd273a65 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -33,6 +33,7 @@ from youtube_dl.utils import (
     ExtractorError,
     find_xpath_attr,
     fix_xml_ampersands,
+    get_element_by_class,
     InAdvancePagedList,
     intlist_to_bytes,
     is_html,
@@ -991,5 +992,13 @@ The first line
         self.assertEqual(urshift(3, 1), 1)
         self.assertEqual(urshift(-3, 1), 2147483646)
 
+    def test_get_element_by_class(self):
+        html = '''
+            <span class="foo bar">nice</span>
+        '''
+
+        self.assertEqual(get_element_by_class('foo', html), 'nice')
+        self.assertEqual(get_element_by_class('no-such-class', html), None)
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 36d5b6c0f..3498697b6 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -310,9 +310,17 @@ def get_element_by_id(id, html):
     return get_element_by_attribute('id', id, html)
 
 
-def get_element_by_attribute(attribute, value, html):
+def get_element_by_class(class_name, html):
+    return get_element_by_attribute(
+        'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
+        html, escape_value=False)
+
+
+def get_element_by_attribute(attribute, value, html, escape_value=True):
     """Return the content of the tag with the specified attribute in the passed HTML document"""
 
+    value = re.escape(value) if escape_value else value
+
     m = re.search(r'''(?xs)
         <([a-zA-Z0-9:._-]+)
          (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
@@ -321,7 +329,7 @@ def get_element_by_attribute(attribute, value, html):
         \s*>
         (?P<content>.*?)
         </\1>
-    ''' % (re.escape(attribute), re.escape(value)), html)
+    ''' % (re.escape(attribute), value), html)
 
     if not m:
         return None