diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/tests/test_io.py | 33 | ||||
-rw-r--r-- | numpy/testing/tests/test_utils.py | 37 | ||||
-rw-r--r-- | numpy/testing/utils.py | 31 |
3 files changed, 75 insertions, 26 deletions
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index bffc5c63e..45ee0a477 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -19,7 +19,7 @@ from numpy.ma.testutils import assert_equal from numpy.testing import ( TestCase, run_module_suite, assert_warns, assert_, assert_raises_regex, assert_raises, assert_allclose, - assert_array_equal, + assert_array_equal,temppath ) from numpy.testing.utils import tempdir @@ -259,26 +259,17 @@ class TestSavezLoad(RoundtripTest, TestCase): def test_not_closing_opened_fid(self): # Test that issue #2178 is fixed: # verify could seek on 'loaded' file - - fd, tmp = mkstemp(suffix='.npz') - os.close(fd) - try: - fp = open(tmp, 'wb') - np.savez(fp, data='LOVELY LOAD') - fp.close() - - fp = open(tmp, 'rb', 10000) - fp.seek(0) - assert_(not fp.closed) - np.load(fp)['data'] - # fp must not get closed by .load - assert_(not fp.closed) - fp.seek(0) - assert_(not fp.closed) - - finally: - fp.close() - os.remove(tmp) + with temppath(suffix='.npz') as tmp: + with open(tmp, 'wb') as fp: + np.savez(fp, data='LOVELY LOAD') + with open(tmp, 'rb', 10000) as fp: + fp.seek(0) + assert_(not fp.closed) + np.load(fp)['data'] + # fp must not get closed by .load + assert_(not fp.closed) + fp.seek(0) + assert_(not fp.closed) def test_closing_fid(self): # Test that issue #1517 (too many opened files) remains closed diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 13aeffe02..23bd491bc 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -2,6 +2,7 @@ from __future__ import division, absolute_import, print_function import warnings import sys +import os import numpy as np from numpy.testing import ( @@ -10,7 +11,7 @@ from numpy.testing import ( assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp, clear_and_catch_warnings, run_module_suite, - assert_string_equal + assert_string_equal, assert_, tempdir, temppath, ) import unittest @@ -780,6 +781,40 @@ def test_clear_and_catch_warnings(): assert_warn_len_equal(my_mod, 2) +def test_tempdir(): + with tempdir() as tdir: + fpath = os.path.join(tdir, 'tmp') + with open(fpath, 'w'): + pass + assert_(not os.path.isdir(tdir)) + + raised = False + try: + with tempdir() as tdir: + raise ValueError() + except ValueError: + raised = True + assert_(raised) + assert_(not os.path.isdir(tdir)) + + + +def test_temppath(): + with temppath() as fpath: + with open(fpath, 'w') as f: + pass + assert_(not os.path.isfile(fpath)) + + raised = False + try: + with temppath() as fpath: + raise ValueError() + except ValueError: + raised = True + assert_(raised) + assert_(not os.path.isfile(fpath)) + + class my_cacw(clear_and_catch_warnings): class_modules = (sys.modules[__name__],) diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index 49d249339..0c4ebe1b9 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -12,7 +12,7 @@ import warnings from functools import partial import shutil import contextlib -from tempfile import mkdtemp +from tempfile import mkdtemp, mkstemp from .nosetester import import_nose from numpy.core import float32, empty, arange, array_repr, ndarray @@ -30,7 +30,7 @@ __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings', - 'SkipTest', 'KnownFailureException'] + 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir'] class KnownFailureException(Exception): @@ -1810,8 +1810,31 @@ def tempdir(*args, **kwargs): """ tmpdir = mkdtemp(*args, **kwargs) - yield tmpdir - shutil.rmtree(tmpdir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) class clear_and_catch_warnings(warnings.catch_warnings): |