summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2019-06-04 23:07:26 -0700
committerEric Wieser <wieser.eric@gmail.com>2019-06-04 23:26:43 -0700
commitb0399b905d5813529d068f6cb9a6cd8376b684bd (patch)
tree1530435c3a8e5d428fa95f99ca7cb1d7408571fa /numpy
parent06a32ea32a0a69990e6ca936c810bceab808ebd0 (diff)
downloadnumpy-b0399b905d5813529d068f6cb9a6cd8376b684bd.tar.gz
BUG: Ensure that np.core.records.fromfile closes its file if something goes wrong
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/records.py58
1 files changed, 29 insertions, 29 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py
index 2adcdae61..659ffa42b 100644
--- a/numpy/core/records.py
+++ b/numpy/core/records.py
@@ -42,7 +42,9 @@ from collections import Counter, OrderedDict
from . import numeric as sb
from . import numerictypes as nt
-from numpy.compat import isfileobj, bytes, long, unicode, os_fspath
+from numpy.compat import (
+ isfileobj, bytes, long, unicode, os_fspath, contextlib_nullcontext
+)
from numpy.core.overrides import set_module
from .arrayprint import get_printoptions
@@ -777,44 +779,42 @@ def fromfile(fd, dtype=None, shape=None, offset=0, formats=None,
if isfileobj(fd):
# file already opened
- name = 0
+ ctx = contextlib_nullcontext(fd)
else:
# open file
- fd = open(os_fspath(fd), 'rb')
- name = 1
+ ctx = open(os_fspath(fd), 'rb')
- if (offset > 0):
- fd.seek(offset, 1)
- size = get_remaining_size(fd)
+ with ctx as fd:
+ if (offset > 0):
+ fd.seek(offset, 1)
+ size = get_remaining_size(fd)
- if dtype is not None:
- descr = sb.dtype(dtype)
- else:
- descr = format_parser(formats, names, titles, aligned, byteorder)._descr
+ if dtype is not None:
+ descr = sb.dtype(dtype)
+ else:
+ descr = format_parser(formats, names, titles, aligned, byteorder)._descr
- itemsize = descr.itemsize
+ itemsize = descr.itemsize
- shapeprod = sb.array(shape).prod(dtype=nt.intp)
- shapesize = shapeprod * itemsize
- if shapesize < 0:
- shape = list(shape)
- shape[shape.index(-1)] = size // -shapesize
- shape = tuple(shape)
shapeprod = sb.array(shape).prod(dtype=nt.intp)
+ shapesize = shapeprod * itemsize
+ if shapesize < 0:
+ shape = list(shape)
+ shape[shape.index(-1)] = size // -shapesize
+ shape = tuple(shape)
+ shapeprod = sb.array(shape).prod(dtype=nt.intp)
- nbytes = shapeprod * itemsize
+ nbytes = shapeprod * itemsize
- if nbytes > size:
- raise ValueError(
- "Not enough bytes left in file for specified shape and type")
+ if nbytes > size:
+ raise ValueError(
+ "Not enough bytes left in file for specified shape and type")
- # create the array
- _array = recarray(shape, descr)
- nbytesread = fd.readinto(_array.data)
- if nbytesread != nbytes:
- raise IOError("Didn't read as many bytes as expected")
- if name:
- fd.close()
+ # create the array
+ _array = recarray(shape, descr)
+ nbytesread = fd.readinto(_array.data)
+ if nbytesread != nbytes:
+ raise IOError("Didn't read as many bytes as expected")
return _array