summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/lib/format.py48
-rw-r--r--numpy/lib/tests/test_format.py56
2 files changed, 86 insertions, 18 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py
index 4ac1427b4..4cfbbe05d 100644
--- a/numpy/lib/format.py
+++ b/numpy/lib/format.py
@@ -138,6 +138,7 @@ from __future__ import division, absolute_import, print_function
import numpy
import sys
+import io
from numpy.lib.utils import safe_eval
from numpy.compat import asbytes, isfileobj, long, basestring
@@ -187,10 +188,7 @@ def read_magic(fp):
major : int
minor : int
"""
- magic_str = fp.read(MAGIC_LEN)
- if len(magic_str) != MAGIC_LEN:
- msg = "could not read %d characters for the magic string; got %r"
- raise ValueError(msg % (MAGIC_LEN, magic_str))
+ magic_str = _read_bytes(fp, MAGIC_LEN, "magic string")
if magic_str[:-2] != MAGIC_PREFIX:
msg = "the magic string is not correct; expected %r, got %r"
raise ValueError(msg % (MAGIC_PREFIX, magic_str[:-2]))
@@ -322,14 +320,9 @@ def read_array_header_1_0(fp):
# Read an unsigned, little-endian short int which has the length of the
# header.
import struct
- hlength_str = fp.read(2)
- if len(hlength_str) != 2:
- msg = "EOF at %s before reading array header length"
- raise ValueError(msg % fp.tell())
+ hlength_str = _read_bytes(fp, 2, "array header length")
header_length = struct.unpack('<H', hlength_str)[0]
- header = fp.read(header_length)
- if len(header) != header_length:
- raise ValueError("EOF at %s before reading array header" % fp.tell())
+ header = _read_bytes(fp, header_length, "array header")
# The header is a pretty-printed string representation of a literal Python
# dictionary with trailing newlines padded to a 16-byte boundary. The keys
@@ -476,11 +469,10 @@ def read_array(fp):
max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)
array = numpy.empty(count, dtype=dtype)
-
for i in range(0, count, max_read_count):
read_count = min(max_read_count, count - i)
-
- data = fp.read(int(read_count * dtype.itemsize))
+ read_size = int(read_count * dtype.itemsize)
+ data = _read_bytes(fp, read_size, "array data")
array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,
count=read_count)
@@ -601,3 +593,31 @@ def open_memmap(filename, mode='r+', dtype=None, shape=None,
mode=mode, offset=offset)
return marray
+
+
+def _read_bytes(fp, size, error_template="ran out of data"):
+ """
+ Read from file-like object until size bytes are read.
+    Raises ValueError if EOF is encountered before size bytes have been read.
+ Non-blocking objects only supported if they derive from io objects.
+
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
+ requested.
+ """
+ data = bytes()
+ while True:
+ # io files (default in python3) return None or raise on would-block,
+ # python2 file will truncate, probably nothing can be done about that.
+ # note that regular files can't be non-blocking
+ try:
+ r = fp.read(size - len(data))
+ data += r
+ if len(r) == 0 or len(data) == size:
+ break
+ except io.BlockingIOError:
+ pass
+ if len(data) != size:
+ msg = "EOF: reading %s, expected %d bytes got %d"
+ raise ValueError(msg %(error_template, size, len(data)))
+ else:
+ return data
diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py
index deec2e4eb..b9be643c8 100644
--- a/numpy/lib/tests/test_format.py
+++ b/numpy/lib/tests/test_format.py
@@ -325,7 +325,7 @@ basic_arrays = []
for scalar in scalars:
for endian in '<>':
dtype = np.dtype(scalar).newbyteorder(endian)
- basic = np.arange(15).astype(dtype)
+ basic = np.arange(1500).astype(dtype)
basic_arrays.extend([
# Empty
np.array([], dtype=dtype),
@@ -334,11 +334,11 @@ for scalar in scalars:
# 1-D
basic,
# 2-D C-contiguous
- basic.reshape((3, 5)),
+ basic.reshape((30, 50)),
# 2-D F-contiguous
- basic.reshape((3, 5)).T,
+ basic.reshape((30, 50)).T,
# 2-D non-contiguous
- basic.reshape((3, 5))[::-1, ::2],
+ basic.reshape((30, 50))[::-1, ::2],
])
# More complicated record arrays.
@@ -411,6 +411,14 @@ record_arrays = [
]
+#BytesIO that reads a random number of bytes at a time
+class BytesIOSRandomSize(BytesIO):
+ def read(self, size=None):
+ import random
+ size = random.randint(1, size)
+ return super(BytesIOSRandomSize, self).read(size)
+
+
def roundtrip(arr):
f = BytesIO()
format.write_array(f, arr)
@@ -419,6 +427,23 @@ def roundtrip(arr):
return arr2
+def roundtrip_randsize(arr):
+ f = BytesIO()
+ format.write_array(f, arr)
+ f2 = BytesIOSRandomSize(f.getvalue())
+ arr2 = format.read_array(f2)
+ return arr2
+
+
+def roundtrip_truncated(arr):
+ f = BytesIO()
+ format.write_array(f, arr)
+ #BytesIO is one byte short
+ f2 = BytesIO(f.getvalue()[0:-1])
+ arr2 = format.read_array(f2)
+ return arr2
+
+
def assert_equal(o1, o2):
assert_(o1 == o2)
@@ -428,12 +453,27 @@ def test_roundtrip():
arr2 = roundtrip(arr)
yield assert_array_equal, arr, arr2
+
+def test_roundtrip_randsize():
+ for arr in basic_arrays + record_arrays:
+ if arr.dtype != object:
+ arr2 = roundtrip_randsize(arr)
+ yield assert_array_equal, arr, arr2
+
+
+def test_roundtrip_truncated():
+ for arr in basic_arrays:
+ if arr.dtype != object:
+ yield assert_raises, ValueError, roundtrip_truncated, arr
+
+
def test_long_str():
# check items larger than internal buffer size, gh-4027
long_str_arr = np.ones(1, dtype=np.dtype((str, format.BUFFER_SIZE + 1)))
long_str_arr2 = roundtrip(long_str_arr)
assert_array_equal(long_str_arr, long_str_arr2)
+
@dec.slow
def test_memmap_roundtrip():
# XXX: test crashes nose on windows. Fix this
@@ -473,6 +513,14 @@ def test_memmap_roundtrip():
del ma
+def test_compressed_roundtrip():
+ arr = np.random.rand(200, 200)
+ npz_file = os.path.join(tempdir, 'compressed.npz')
+ np.savez_compressed(npz_file, arr=arr)
+ arr1 = np.load(npz_file)['arr']
+ assert_array_equal(arr, arr1)
+
+
def test_write_version_1_0():
f = BytesIO()
arr = np.arange(1)