From 3cc4c43bfd2226f76f35054a99e8f2d2a3ac466a Mon Sep 17 00:00:00 2001 From: Stefan van der Walt Date: Sat, 29 Nov 2008 12:07:54 +0000 Subject: Add test for load's mmap_mode. --- numpy/lib/tests/test_io.py | 87 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 26 deletions(-) (limited to 'numpy/lib/tests/test_io.py') diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index e78fd0579..9fce79a4a 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -2,51 +2,86 @@ from numpy.testing import * import numpy as np import StringIO +from tempfile import NamedTemporaryFile class RoundtripTest: + def roundtrip(self, save_func, *args, **kwargs): + """ + save_func : callable + Function used to save arrays to file. + file_on_disk : bool + If true, store the file on disk, instead of in a + string buffer. + save_kwds : dict + Parameters passed to `save_func`. + load_kwds : dict + Parameters passed to `numpy.load`. + args : tuple of arrays + Arrays stored to file. + + """ + save_kwds = kwargs.get('save_kwds', {}) + load_kwds = kwargs.get('load_kwds', {}) + file_on_disk = kwargs.get('file_on_disk', False) + + if file_on_disk: + target_file = NamedTemporaryFile() + load_file = target_file.name + else: + target_file = StringIO.StringIO() + load_file = target_file + + arr = args + + save_func(target_file, *arr, **save_kwds) + target_file.flush() + target_file.seek(0) + + arr_reloaded = np.load(load_file, **load_kwds) + + self.arr = arr + self.arr_reloaded = arr_reloaded + def test_array(self): - a = np.array( [[1,2],[3,4]], float) - self.do(a) + a = np.array([[1, 2], [3, 4]], float) + self.roundtrip(a) - a = np.array( [[1,2],[3,4]], int) - self.do(a) + a = np.array([[1, 2], [3, 4]], int) + self.roundtrip(a) - a = np.array( [[1+5j,2+6j],[3+7j,4+8j]], dtype=np.csingle) - self.do(a) + a = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.csingle) + self.roundtrip(a) - a = np.array( [[1+5j,2+6j],[3+7j,4+8j]], dtype=np.cdouble) - self.do(a) + a = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.cdouble) + self.roundtrip(a) def test_1D(self): - a = np.array([1,2,3,4], int) - self.do(a) + a = np.array([1, 2, 3, 4], int) + self.roundtrip(a) + + def test_mmap(self): + a = np.array([[1, 2.5], [4, 7.3]]) + self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'}) def test_record(self): a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')]) - self.do(a) + self.roundtrip(a) class TestSaveLoad(RoundtripTest, TestCase): - def do(self, a): - c = StringIO.StringIO() - np.save(c, a) - c.seek(0) - a_reloaded = np.load(c) - assert_equal(a, a_reloaded) - + def roundtrip(self, *args, **kwargs): + RoundtripTest.roundtrip(self, np.save, *args, **kwargs) + assert_equal(self.arr[0], self.arr_reloaded) class TestSavezLoad(RoundtripTest, TestCase): - def do(self, *arrays): - c = StringIO.StringIO() - np.savez(c, *arrays) - c.seek(0) - l = np.load(c) - for n, a in enumerate(arrays): - assert_equal(a, l['arr_%d' % n]) + def roundtrip(self, *args, **kwargs): + RoundtripTest.roundtrip(self, np.savez, *args, **kwargs) + for n, arr in enumerate(self.arr): + assert_equal(arr, self.arr_reloaded['arr_%d' % n]) def test_multiple_arrays(self): a = np.array( [[1,2],[3,4]], float) b = np.array( [[1+2j,2+7j],[3-6j,4+12j]], complex) - self.do(a,b) + self.roundtrip(a,b) def test_named_arrays(self): a = np.array( [[1,2],[3,4]], float) -- cgit v1.2.1