diff options
-rw-r--r-- | numpy/lib/format.py | 48 | ||||
-rw-r--r-- | numpy/lib/tests/test_format.py | 56 |
2 files changed, 86 insertions, 18 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 4ac1427b4..4cfbbe05d 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -138,6 +138,7 @@ from __future__ import division, absolute_import, print_function import numpy import sys +import io from numpy.lib.utils import safe_eval from numpy.compat import asbytes, isfileobj, long, basestring @@ -187,10 +188,7 @@ def read_magic(fp): major : int minor : int """ - magic_str = fp.read(MAGIC_LEN) - if len(magic_str) != MAGIC_LEN: - msg = "could not read %d characters for the magic string; got %r" - raise ValueError(msg % (MAGIC_LEN, magic_str)) + magic_str = _read_bytes(fp, MAGIC_LEN, "magic string") if magic_str[:-2] != MAGIC_PREFIX: msg = "the magic string is not correct; expected %r, got %r" raise ValueError(msg % (MAGIC_PREFIX, magic_str[:-2])) @@ -322,14 +320,9 @@ def read_array_header_1_0(fp): # Read an unsigned, little-endian short int which has the length of the # header. import struct - hlength_str = fp.read(2) - if len(hlength_str) != 2: - msg = "EOF at %s before reading array header length" - raise ValueError(msg % fp.tell()) + hlength_str = _read_bytes(fp, 2, "array header length") header_length = struct.unpack('<H', hlength_str)[0] - header = fp.read(header_length) - if len(header) != header_length: - raise ValueError("EOF at %s before reading array header" % fp.tell()) + header = _read_bytes(fp, header_length, "array header") # The header is a pretty-printed string representation of a literal Python # dictionary with trailing newlines padded to a 16-byte boundary. 
The keys @@ -476,11 +469,10 @@ def read_array(fp): max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize) array = numpy.empty(count, dtype=dtype) - for i in range(0, count, max_read_count): read_count = min(max_read_count, count - i) - - data = fp.read(int(read_count * dtype.itemsize)) + read_size = int(read_count * dtype.itemsize) + data = _read_bytes(fp, read_size, "array data") array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype, count=read_count) @@ -601,3 +593,31 @@ def open_memmap(filename, mode='r+', dtype=None, shape=None, mode=mode, offset=offset) return marray + + +def _read_bytes(fp, size, error_template="ran out of data"): + """ + Read from file-like object until size bytes are read. + Raises ValueError if EOF is encountered before size bytes are read. + Non-blocking objects are only supported if they derive from io objects. + + Required as e.g. ZipExtFile in python 2.6 can return less data than + requested. + """ + data = bytes() + while True: + # io files (default in python3) return None or raise on would-block, + # python2 file will truncate, probably nothing can be done about that. 
+ # note that regular files can't be non-blocking + try: + r = fp.read(size - len(data)) + data += r + if len(r) == 0 or len(data) == size: + break + except io.BlockingIOError: + pass + if len(data) != size: + msg = "EOF: reading %s, expected %d bytes got %d" + raise ValueError(msg %(error_template, size, len(data))) + else: + return data diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py index deec2e4eb..b9be643c8 100644 --- a/numpy/lib/tests/test_format.py +++ b/numpy/lib/tests/test_format.py @@ -325,7 +325,7 @@ basic_arrays = [] for scalar in scalars: for endian in '<>': dtype = np.dtype(scalar).newbyteorder(endian) - basic = np.arange(15).astype(dtype) + basic = np.arange(1500).astype(dtype) basic_arrays.extend([ # Empty np.array([], dtype=dtype), @@ -334,11 +334,11 @@ for scalar in scalars: # 1-D basic, # 2-D C-contiguous - basic.reshape((3, 5)), + basic.reshape((30, 50)), # 2-D F-contiguous - basic.reshape((3, 5)).T, + basic.reshape((30, 50)).T, # 2-D non-contiguous - basic.reshape((3, 5))[::-1, ::2], + basic.reshape((30, 50))[::-1, ::2], ]) # More complicated record arrays. 
@@ -411,6 +411,14 @@ record_arrays = [ ] +#BytesIO that reads a random number of bytes at a time +class BytesIOSRandomSize(BytesIO): + def read(self, size=None): + import random + size = random.randint(1, size) + return super(BytesIOSRandomSize, self).read(size) + + def roundtrip(arr): f = BytesIO() format.write_array(f, arr) @@ -419,6 +427,23 @@ def roundtrip(arr): return arr2 +def roundtrip_randsize(arr): + f = BytesIO() + format.write_array(f, arr) + f2 = BytesIOSRandomSize(f.getvalue()) + arr2 = format.read_array(f2) + return arr2 + + +def roundtrip_truncated(arr): + f = BytesIO() + format.write_array(f, arr) + #BytesIO is one byte short + f2 = BytesIO(f.getvalue()[0:-1]) + arr2 = format.read_array(f2) + return arr2 + + def assert_equal(o1, o2): assert_(o1 == o2) @@ -428,12 +453,27 @@ def test_roundtrip(): arr2 = roundtrip(arr) yield assert_array_equal, arr, arr2 + +def test_roundtrip_randsize(): + for arr in basic_arrays + record_arrays: + if arr.dtype != object: + arr2 = roundtrip_randsize(arr) + yield assert_array_equal, arr, arr2 + + +def test_roundtrip_truncated(): + for arr in basic_arrays: + if arr.dtype != object: + yield assert_raises, ValueError, roundtrip_truncated, arr + + def test_long_str(): # check items larger than internal buffer size, gh-4027 long_str_arr = np.ones(1, dtype=np.dtype((str, format.BUFFER_SIZE + 1))) long_str_arr2 = roundtrip(long_str_arr) assert_array_equal(long_str_arr, long_str_arr2) + @dec.slow def test_memmap_roundtrip(): # XXX: test crashes nose on windows. Fix this @@ -473,6 +513,14 @@ def test_memmap_roundtrip(): del ma +def test_compressed_roundtrip(): + arr = np.random.rand(200, 200) + npz_file = os.path.join(tempdir, 'compressed.npz') + np.savez_compressed(npz_file, arr=arr) + arr1 = np.load(npz_file)['arr'] + assert_array_equal(arr, arr1) + + def test_write_version_1_0(): f = BytesIO() arr = np.arange(1) |