summary refs log tree commit diff
diff options
context:
space:
mode:
authorPhilipp Hagemeister <phihag@phihag.de>2014-08-21 13:01:13 +0200
committerPhilipp Hagemeister <phihag@phihag.de>2014-08-21 13:01:13 +0200
commit181c8655c798562c85ae2af06f1ece7b01632ea9 (patch)
treea5778b80939a41f9ceebbb010ff47e46e827462a
parent3b95347bb6a7cf97bff7107a6f22f3ce858231a2 (diff)
downloadyoutube-dl-181c8655c798562c85ae2af06f1ece7b01632ea9.tar.gz
youtube-dl-181c8655c798562c85ae2af06f1ece7b01632ea9.tar.xz
youtube-dl-181c8655c798562c85ae2af06f1ece7b01632ea9.zip
[utils] Make JSON file writes atomic (Fixes #3549)
-rw-r--r--youtube_dl/utils.py41
1 files changed, 30 insertions, 11 deletions
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 1081a9368..d11e46c80 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -24,6 +24,7 @@ import socket
 import struct
 import subprocess
 import sys
+import tempfile
 import traceback
 import xml.etree.ElementTree
 import zlib
@@ -228,18 +229,36 @@ else:
         assert type(s) == type(u'')
         print(s)
 
-# In Python 2.x, json.dump expects a bytestream.
-# In Python 3.x, it writes to a character stream
-if sys.version_info < (3,0):
-    def write_json_file(obj, fn):
-        with open(fn, 'wb') as f:
-            json.dump(obj, f)
-else:
-    def write_json_file(obj, fn):
-        with open(fn, 'w', encoding='utf-8') as f:
-            json.dump(obj, f)
 
-if sys.version_info >= (2,7):
+def write_json_file(obj, fn):
+    """ Encode obj as JSON and write it to fn, atomically """
+
+    # In Python 2.x, json.dump expects a bytestream.
+    # In Python 3.x, it writes to a character stream
+    if sys.version_info < (3, 0):
+        mode = 'wb'
+        encoding = None
+    else:
+        mode = 'w'
+        encoding = 'utf-8'
+    tf = tempfile.NamedTemporaryFile(
+        suffix='.tmp', prefix=os.path.basename(fn) + '.',
+        dir=os.path.dirname(fn),
+        delete=False)
+
+    try:
+        with tf:
+            json.dump(obj, tf)
+        os.rename(tf.name, fn)
+    except:
+        try:
+            os.remove(tf.name)
+        except OSError:
+            pass
+        raise
+
+
+if sys.version_info >= (2, 7):
     def find_xpath_attr(node, xpath, key, val):
         """ Find the xpath xpath[@key=val] """
         assert re.match(r'^[a-zA-Z-]+$', key)