summaryrefslogtreecommitdiff
path: root/numpy/lib/npyio.py
diff options
context:
space:
mode:
authorSeth Troisi <sethtroisi@google.com>2020-01-23 18:47:34 -0800
committerSeth Troisi <sethtroisi@google.com>2020-01-24 15:25:12 -0800
commitaa276641d0943d17432d4ba02f8a00fdd6572237 (patch)
tree5c08a9afff0460944e66e347efb56594cd45878b /numpy/lib/npyio.py
parent68224f43d09393c1981bb83ee3c13a5158d2817c (diff)
downloadnumpy-aa276641d0943d17432d4ba02f8a00fdd6572237.tar.gz
ENH: Make use of ExitStack in npyio.py
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r--numpy/lib/npyio.py33
1 files changed, 12 insertions, 21 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index f43fcf0c0..50f309938 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -408,15 +408,14 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True,
pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports)
- # TODO: Use contextlib.ExitStack once we drop Python 2
- if hasattr(file, 'read'):
- fid = file
- own_fid = False
- else:
- fid = open(os_fspath(file), "rb")
- own_fid = True
+ with contextlib.ExitStack() as stack:
+ if hasattr(file, 'read'):
+ fid = file
+ own_fid = False
+ else:
+ fid = stack.enter_context(open(os_fspath(file), "rb"))
+ own_fid = True
- try:
# Code to distinguish from NumPy binary files and pickles.
_ZIP_PREFIX = b'PK\x03\x04'
_ZIP_SUFFIX = b'PK\x05\x06' # empty zip files start with this
@@ -427,10 +426,10 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True,
fid.seek(-min(N, len(magic)), 1) # back-up
if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX):
# zip-file (assume .npz)
- # Transfer file ownership to NpzFile
+ # Potentially transfer file ownership to NpzFile
+ stack.pop_all()
ret = NpzFile(fid, own_fid=own_fid, allow_pickle=allow_pickle,
pickle_kwargs=pickle_kwargs)
- own_fid = False
return ret
elif magic == format.MAGIC_PREFIX:
# .npy file
@@ -449,9 +448,6 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True,
except Exception:
raise IOError(
"Failed to interpret file %s as a pickle" % repr(file))
- finally:
- if own_fid:
- fid.close()
def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None):
@@ -519,23 +515,18 @@ def save(file, arr, allow_pickle=True, fix_imports=True):
>>> print(a, b)
# [1 2] [1 3]
"""
- own_fid = False
if hasattr(file, 'write'):
- fid = file
+ file_ctx = contextlib_nullcontext(file)
else:
file = os_fspath(file)
if not file.endswith('.npy'):
file = file + '.npy'
- fid = open(file, "wb")
- own_fid = True
+ file_ctx = open(file, "wb")
- try:
+ with file_ctx as fid:
arr = np.asanyarray(arr)
format.write_array(fid, arr, allow_pickle=allow_pickle,
pickle_kwargs=dict(fix_imports=fix_imports))
- finally:
- if own_fid:
- fid.close()
def _savez_dispatcher(file, *args, **kwds):