diff options
author | Seth Troisi <sethtroisi@google.com> | 2020-01-23 18:47:34 -0800 |
---|---|---|
committer | Seth Troisi <sethtroisi@google.com> | 2020-01-24 15:25:12 -0800 |
commit | aa276641d0943d17432d4ba02f8a00fdd6572237 (patch) | |
tree | 5c08a9afff0460944e66e347efb56594cd45878b /numpy/lib/npyio.py | |
parent | 68224f43d09393c1981bb83ee3c13a5158d2817c (diff) | |
download | numpy-aa276641d0943d17432d4ba02f8a00fdd6572237.tar.gz |
ENH: Make use of ExitStack in npyio.py
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r-- | numpy/lib/npyio.py | 33 |
1 files changed, 12 insertions, 21 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index f43fcf0c0..50f309938 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -408,15 +408,14 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports) - # TODO: Use contextlib.ExitStack once we drop Python 2 - if hasattr(file, 'read'): - fid = file - own_fid = False - else: - fid = open(os_fspath(file), "rb") - own_fid = True + with contextlib.ExitStack() as stack: + if hasattr(file, 'read'): + fid = file + own_fid = False + else: + fid = stack.enter_context(open(os_fspath(file), "rb")) + own_fid = True - try: # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = b'PK\x03\x04' _ZIP_SUFFIX = b'PK\x05\x06' # empty zip files start with this @@ -427,10 +426,10 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, fid.seek(-min(N, len(magic)), 1) # back-up if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX): # zip-file (assume .npz) - # Transfer file ownership to NpzFile + # Potentially transfer file ownership to NpzFile + stack.pop_all() ret = NpzFile(fid, own_fid=own_fid, allow_pickle=allow_pickle, pickle_kwargs=pickle_kwargs) - own_fid = False return ret elif magic == format.MAGIC_PREFIX: # .npy file @@ -449,9 +448,6 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, except Exception: raise IOError( "Failed to interpret file %s as a pickle" % repr(file)) - finally: - if own_fid: - fid.close() def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None): @@ -519,23 +515,18 @@ def save(file, arr, allow_pickle=True, fix_imports=True): >>> print(a, b) # [1 2] [1 3] """ - own_fid = False if hasattr(file, 'write'): - fid = file + file_ctx = contextlib_nullcontext(file) else: file = os_fspath(file) if not file.endswith('.npy'): file = file + '.npy' - fid = open(file, "wb") - own_fid = True + file_ctx = open(file, "wb") - try: + with file_ctx as fid: arr = np.asanyarray(arr) format.write_array(fid, arr, allow_pickle=allow_pickle, pickle_kwargs=dict(fix_imports=fix_imports)) - finally: - if own_fid: - fid.close() def _savez_dispatcher(file, *args, **kwds): |