diff options
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r-- | numpy/lib/npyio.py | 164 |
1 files changed, 66 insertions, 98 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index e57a6dd47..0db2e6897 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -1,5 +1,3 @@ -from __future__ import division, absolute_import, print_function - import sys import os import re @@ -9,6 +7,7 @@ import warnings import weakref import contextlib from operator import itemgetter, index as opindex +from collections.abc import Mapping import numpy as np from . import format @@ -24,16 +23,10 @@ from ._iotools import ( ) from numpy.compat import ( - asbytes, asstr, asunicode, bytes, basestring, os_fspath, os_PathLike, + asbytes, asstr, asunicode, bytes, os_fspath, os_PathLike, pickle, contextlib_nullcontext ) -if sys.version_info[0] >= 3: - from collections.abc import Mapping -else: - from future_builtins import map - from collections import Mapping - @set_module('numpy') def loads(*args, **kwargs): @@ -55,7 +48,7 @@ array_function_dispatch = functools.partial( overrides.array_function_dispatch, module='numpy') -class BagObj(object): +class BagObj: """ BagObj(obj) @@ -69,7 +62,7 @@ class BagObj(object): Examples -------- >>> from numpy.lib.npyio import BagObj as BO - >>> class BagDemo(object): + >>> class BagDemo: ... def __getitem__(self, key): # An instance of BagObj(BagDemo) ... # will call this method when any ... # attribute look-up is required @@ -266,26 +259,25 @@ class NpzFile(Mapping): raise KeyError("%s is not a file in the archive" % key) - if sys.version_info.major == 3: - # deprecate the python 2 dict apis that we supported by accident in - # python 3. We forgot to implement itervalues() at all in earlier - # versions of numpy, so no need to deprecated it here. + # deprecate the python 2 dict apis that we supported by accident in + # python 3. We forgot to implement itervalues() at all in earlier + # versions of numpy, so no need to deprecated it here. - def iteritems(self): - # Numpy 1.15, 2018-02-20 - warnings.warn( - "NpzFile.iteritems is deprecated in python 3, to match the " - "removal of dict.itertems. Use .items() instead.", - DeprecationWarning, stacklevel=2) - return self.items() + def iteritems(self): + # Numpy 1.15, 2018-02-20 + warnings.warn( + "NpzFile.iteritems is deprecated in python 3, to match the " + "removal of dict.itertems. Use .items() instead.", + DeprecationWarning, stacklevel=2) + return self.items() - def iterkeys(self): - # Numpy 1.15, 2018-02-20 - warnings.warn( - "NpzFile.iterkeys is deprecated in python 3, to match the " - "removal of dict.iterkeys. Use .keys() instead.", - DeprecationWarning, stacklevel=2) - return self.keys() + def iterkeys(self): + # Numpy 1.15, 2018-02-20 + warnings.warn( + "NpzFile.iterkeys is deprecated in python 3, to match the " + "removal of dict.iterkeys. Use .keys() instead.", + DeprecationWarning, stacklevel=2) + return self.keys() @set_module('numpy') @@ -414,21 +406,16 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, # result can similarly silently corrupt numerical data. raise ValueError("encoding must be 'ASCII', 'latin1', or 'bytes'") - if sys.version_info[0] >= 3: - pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports) - else: - # Nothing to do on Python 2 - pickle_kwargs = {} + pickle_kwargs = dict(encoding=encoding, fix_imports=fix_imports) - # TODO: Use contextlib.ExitStack once we drop Python 2 - if hasattr(file, 'read'): - fid = file - own_fid = False - else: - fid = open(os_fspath(file), "rb") - own_fid = True + with contextlib.ExitStack() as stack: + if hasattr(file, 'read'): + fid = file + own_fid = False + else: + fid = stack.enter_context(open(os_fspath(file), "rb")) + own_fid = True - try: # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = b'PK\x03\x04' _ZIP_SUFFIX = b'PK\x05\x06' # empty zip files start with this @@ -439,10 +426,10 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, fid.seek(-min(N, len(magic)), 1) # back-up if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX): # zip-file (assume .npz) - # Transfer file ownership to NpzFile + # Potentially transfer file ownership to NpzFile + stack.pop_all() ret = NpzFile(fid, own_fid=own_fid, allow_pickle=allow_pickle, pickle_kwargs=pickle_kwargs) - own_fid = False return ret elif magic == format.MAGIC_PREFIX: # .npy file @@ -461,9 +448,6 @@ def load(file, mmap_mode=None, allow_pickle=False, fix_imports=True, except Exception: raise IOError( "Failed to interpret file %s as a pickle" % repr(file)) - finally: - if own_fid: - fid.close() def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None): @@ -480,7 +464,7 @@ def save(file, arr, allow_pickle=True, fix_imports=True): file : file, str, or pathlib.Path File or filename to which the data is saved. If file is a file-object, then the filename is unchanged. If file is a string or Path, a ``.npy`` - extension will be appended to the file name if it does not already + extension will be appended to the filename if it does not already have one. arr : array_like Array data to be saved. @@ -506,9 +490,9 @@ def save(file, arr, allow_pickle=True, fix_imports=True): Notes ----- For a description of the ``.npy`` format, see :py:mod:`numpy.lib.format`. - - Any data saved to the file is appended to the end of the file. - + + Any data saved to the file is appended to the end of the file. + Examples -------- >>> from tempfile import TemporaryFile @@ -524,49 +508,35 @@ def save(file, arr, allow_pickle=True, fix_imports=True): >>> with open('test.npy', 'wb') as f: ... np.save(f, np.array([1, 2])) - ... np.save(f, np.array([1, 3])) + ... np.save(f, np.array([1, 3])) >>> with open('test.npy', 'rb') as f: ... a = np.load(f) ... b = np.load(f) >>> print(a, b) # [1 2] [1 3] """ - own_fid = False if hasattr(file, 'write'): - fid = file + file_ctx = contextlib_nullcontext(file) else: file = os_fspath(file) if not file.endswith('.npy'): file = file + '.npy' - fid = open(file, "wb") - own_fid = True + file_ctx = open(file, "wb") - if sys.version_info[0] >= 3: - pickle_kwargs = dict(fix_imports=fix_imports) - else: - # Nothing to do on Python 2 - pickle_kwargs = None - - try: + with file_ctx as fid: arr = np.asanyarray(arr) format.write_array(fid, arr, allow_pickle=allow_pickle, - pickle_kwargs=pickle_kwargs) - finally: - if own_fid: - fid.close() + pickle_kwargs=dict(fix_imports=fix_imports)) def _savez_dispatcher(file, *args, **kwds): - for a in args: - yield a - for v in kwds.values(): - yield v + yield from args + yield from kwds.values() @array_function_dispatch(_savez_dispatcher) def savez(file, *args, **kwds): - """ - Save several arrays into a single file in uncompressed ``.npz`` format. + """Save several arrays into a single file in uncompressed ``.npz`` format. If arguments are passed in with no keywords, the corresponding variable names, in the ``.npz`` file, are 'arr_0', 'arr_1', etc. If keyword @@ -576,9 +546,9 @@ def savez(file, *args, **kwds): Parameters ---------- file : str or file - Either the file name (string) or an open file (file-like object) + Either the filename (string) or an open file (file-like object) where the data will be saved. If file is a string or a Path, the - ``.npz`` extension will be appended to the file name if it is not + ``.npz`` extension will be appended to the filename if it is not already there. args : Arguments, optional Arrays to save to the file. Since it is not possible for Python to @@ -611,6 +581,10 @@ def savez(file, *args, **kwds): its list of arrays (with the ``.files`` attribute), and for the arrays themselves. + When saving dictionaries, the dictionary keys become filenames + inside the ZIP archive. Therefore, keys should be valid filenames. + E.g., avoid keys that begin with ``/`` or contain ``.``. + Examples -------- >>> from tempfile import TemporaryFile @@ -638,16 +612,13 @@ def savez(file, *args, **kwds): ['x', 'y'] >>> npzfile['x'] array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - """ _savez(file, args, kwds, False) def _savez_compressed_dispatcher(file, *args, **kwds): - for a in args: - yield a - for v in kwds.values(): - yield v + yield from args + yield from kwds.values() @array_function_dispatch(_savez_compressed_dispatcher) @@ -656,15 +627,15 @@ def savez_compressed(file, *args, **kwds): Save several arrays into a single file in compressed ``.npz`` format. If keyword arguments are given, then filenames are taken from the keywords. - If arguments are passed in with no keywords, then stored file names are + If arguments are passed in with no keywords, then stored filenames are arr_0, arr_1, etc. Parameters ---------- file : str or file - Either the file name (string) or an open file (file-like object) + Either the filename (string) or an open file (file-like object) where the data will be saved. If file is a string or a Path, the - ``.npz`` extension will be appended to the file name if it is not + ``.npz`` extension will be appended to the filename if it is not already there. args : Arguments, optional Arrays to save to the file. Since it is not possible for Python to @@ -691,7 +662,7 @@ def savez_compressed(file, *args, **kwds): The ``.npz`` file format is a zipped archive of files named after the variables they contain. The archive is compressed with ``zipfile.ZIP_DEFLATED`` and each file in the archive contains one variable - in ``.npy`` format. For a description of the ``.npy`` format, see + in ``.npy`` format. For a description of the ``.npy`` format, see :py:mod:`numpy.lib.format`. @@ -831,7 +802,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, fname : file, str, or pathlib.Path File, filename, or generator to read. If the filename extension is ``.gz`` or ``.bz2``, the file is first decompressed. Note that - generators should return byte strings for Python 3k. + generators should return byte strings. dtype : data-type, optional Data-type of the resulting array; default: float. If this is a structured data-type, the resulting array will be 1-dimensional, and @@ -934,7 +905,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, """ # Type conversions for Py3 convenience if comments is not None: - if isinstance(comments, (basestring, bytes)): + if isinstance(comments, (str, bytes)): comments = [comments] comments = [_decode_line(x) for x in comments] # Compile regex for comments beforehand @@ -1334,8 +1305,8 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', fmt = asstr(fmt) delimiter = asstr(delimiter) - class WriteWrap(object): - """Convert to unicode in py2 or to bytes on bytestream inputs. + class WriteWrap: + """Convert to bytes on bytestream inputs. """ def __init__(self, fh, encoding): @@ -1375,9 +1346,6 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', open(fname, 'wt').close() fh = np.lib._datasource.open(fname, 'wt', encoding=encoding) own_fh = True - # need to convert str to unicode for text io output - if sys.version_info[0] == 2: - fh = WriteWrap(fh, encoding or 'latin1') elif hasattr(fname, 'write'): # wrap to handle byte output streams fh = WriteWrap(fname, encoding or 'latin1') @@ -1410,7 +1378,7 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', if len(fmt) != ncol: raise AttributeError('fmt has wrong shape. %s' % str(fmt)) format = asstr(delimiter).join(map(asstr, fmt)) - elif isinstance(fmt, basestring): + elif isinstance(fmt, str): n_fmt_chars = fmt.count('%') error = ValueError('fmt has wrong number of %% formats: %s' % fmt) if n_fmt_chars == 1: @@ -1469,7 +1437,7 @@ def fromregex(file, regexp, dtype, encoding=None): Parameters ---------- file : str or file - File name or file object to read. + Filename or file object to read. regexp : str or regexp Regular expression used to parse the file. Groups in the regular expression correspond to fields in the dtype. @@ -1527,9 +1495,9 @@ def fromregex(file, regexp, dtype, encoding=None): dtype = np.dtype(dtype) content = file.read() - if isinstance(content, bytes) and isinstance(regexp, np.unicode): + if isinstance(content, bytes) and isinstance(regexp, np.compat.unicode): regexp = asbytes(regexp) - elif isinstance(content, np.unicode) and isinstance(regexp, bytes): + elif isinstance(content, np.compat.unicode) and isinstance(regexp, bytes): regexp = asstr(regexp) if not hasattr(regexp, 'match'): @@ -1576,7 +1544,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, fname : file, str, pathlib.Path, list of str, generator File, filename, list, or generator to read. If the filename extension is `.gz` or `.bz2`, the file is first decompressed. Note - that generators must return byte strings in Python 3k. The strings + that generators must return byte strings. The strings in a list or produced by a generator are treated as lines. dtype : dtype, optional Data type of the resulting array. @@ -1766,7 +1734,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, try: if isinstance(fname, os_PathLike): fname = os_fspath(fname) - if isinstance(fname, basestring): + if isinstance(fname, str): fid = np.lib._datasource.open(fname, 'rt', encoding=encoding) fid_ctx = contextlib.closing(fid) else: @@ -1908,7 +1876,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, if value not in entry: entry.append(value) # We have a string : apply it to all entries - elif isinstance(user_missing_values, basestring): + elif isinstance(user_missing_values, str): user_value = user_missing_values.split(",") for entry in missing_values: entry.extend(user_value) |