summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/tests/test_io.py33
-rw-r--r--numpy/testing/tests/test_utils.py37
-rw-r--r--numpy/testing/utils.py31
3 files changed, 75 insertions, 26 deletions
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index bffc5c63e..45ee0a477 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -19,7 +19,7 @@ from numpy.ma.testutils import assert_equal
from numpy.testing import (
TestCase, run_module_suite, assert_warns, assert_,
assert_raises_regex, assert_raises, assert_allclose,
- assert_array_equal,
+ assert_array_equal,temppath
)
from numpy.testing.utils import tempdir
@@ -259,26 +259,17 @@ class TestSavezLoad(RoundtripTest, TestCase):
def test_not_closing_opened_fid(self):
# Test that issue #2178 is fixed:
# verify could seek on 'loaded' file
-
- fd, tmp = mkstemp(suffix='.npz')
- os.close(fd)
- try:
- fp = open(tmp, 'wb')
- np.savez(fp, data='LOVELY LOAD')
- fp.close()
-
- fp = open(tmp, 'rb', 10000)
- fp.seek(0)
- assert_(not fp.closed)
- np.load(fp)['data']
- # fp must not get closed by .load
- assert_(not fp.closed)
- fp.seek(0)
- assert_(not fp.closed)
-
- finally:
- fp.close()
- os.remove(tmp)
+ with temppath(suffix='.npz') as tmp:
+ with open(tmp, 'wb') as fp:
+ np.savez(fp, data='LOVELY LOAD')
+ with open(tmp, 'rb', 10000) as fp:
+ fp.seek(0)
+ assert_(not fp.closed)
+ np.load(fp)['data']
+ # fp must not get closed by .load
+ assert_(not fp.closed)
+ fp.seek(0)
+ assert_(not fp.closed)
def test_closing_fid(self):
# Test that issue #1517 (too many opened files) remains closed
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 13aeffe02..23bd491bc 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -2,6 +2,7 @@ from __future__ import division, absolute_import, print_function
import warnings
import sys
+import os
import numpy as np
from numpy.testing import (
@@ -10,7 +11,7 @@ from numpy.testing import (
assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal,
assert_array_almost_equal_nulp, assert_array_max_ulp,
clear_and_catch_warnings, run_module_suite,
- assert_string_equal
+ assert_string_equal, assert_, tempdir, temppath,
)
import unittest
@@ -780,6 +781,40 @@ def test_clear_and_catch_warnings():
assert_warn_len_equal(my_mod, 2)
+def test_tempdir():
+ with tempdir() as tdir:
+ fpath = os.path.join(tdir, 'tmp')
+ with open(fpath, 'w'):
+ pass
+ assert_(not os.path.isdir(tdir))
+
+ raised = False
+ try:
+ with tempdir() as tdir:
+ raise ValueError()
+ except ValueError:
+ raised = True
+ assert_(raised)
+ assert_(not os.path.isdir(tdir))
+
+
+
+def test_temppath():
+ with temppath() as fpath:
+ with open(fpath, 'w') as f:
+ pass
+ assert_(not os.path.isfile(fpath))
+
+ raised = False
+ try:
+ with temppath() as fpath:
+ raise ValueError()
+ except ValueError:
+ raised = True
+ assert_(raised)
+ assert_(not os.path.isfile(fpath))
+
+
class my_cacw(clear_and_catch_warnings):
class_modules = (sys.modules[__name__],)
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index 49d249339..0c4ebe1b9 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -12,7 +12,7 @@ import warnings
from functools import partial
import shutil
import contextlib
-from tempfile import mkdtemp
+from tempfile import mkdtemp, mkstemp
from .nosetester import import_nose
from numpy.core import float32, empty, arange, array_repr, ndarray
@@ -30,7 +30,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
- 'SkipTest', 'KnownFailureException']
+ 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir']
class KnownFailureException(Exception):
@@ -1810,8 +1810,31 @@ def tempdir(*args, **kwargs):
"""
tmpdir = mkdtemp(*args, **kwargs)
- yield tmpdir
- shutil.rmtree(tmpdir)
+ try:
+ yield tmpdir
+ finally:
+ shutil.rmtree(tmpdir)
+
+@contextlib.contextmanager
+def temppath(*args, **kwargs):
+ """Context manager for temporary files.
+
+ Context manager that returns the path to a closed temporary file. Its
+ parameters are the same as for tempfile.mkstemp and are passed directly
+ to that function. The underlying file is removed when the context is
+ exited, so it should be closed at that time.
+
+ Windows does not allow a temporary file to be opened if it is already
+ open, so the underlying file must be closed after opening before it
+ can be opened again.
+
+ """
+ fd, path = mkstemp(*args, **kwargs)
+ os.close(fd)
+ try:
+ yield path
+ finally:
+ os.remove(path)
class clear_and_catch_warnings(warnings.catch_warnings):