diff options
Diffstat (limited to 'numpy/lib/npyio.py')
-rw-r--r-- | numpy/lib/npyio.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 62fc9c5b3..733795671 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -12,6 +12,7 @@ import numpy as np from . import format from ._datasource import DataSource from numpy.core.multiarray import packbits, unpackbits +from numpy.core.overrides import array_function_dispatch from numpy.core._internal import recursive from ._iotools import ( LineSplitter, NameValidator, StringConverter, ConverterError, @@ -447,6 +448,11 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, fid.close() +def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None): + return (arr,) + + +@array_function_dispatch(_save_dispatcher) def save(file, arr, allow_pickle=True, fix_imports=True): """ Save an array to a binary file in NumPy ``.npy`` format. @@ -525,6 +531,14 @@ def save(file, arr, allow_pickle=True, fix_imports=True): fid.close() +def _savez_dispatcher(file, *args, **kwds): + for a in args: + yield a + for v in kwds.values(): + yield v + + +@array_function_dispatch(_savez_dispatcher) def savez(file, *args, **kwds): """ Save several arrays into a single file in uncompressed ``.npz`` format. @@ -604,6 +618,14 @@ def savez(file, *args, **kwds): _savez(file, args, kwds, False) +def _savez_compressed_dispatcher(file, *args, **kwds): + for a in args: + yield a + for v in kwds.values(): + yield v + + +@array_function_dispatch(_savez_compressed_dispatcher) def savez_compressed(file, *args, **kwds): """ Save several arrays into a single file in compressed ``.npz`` format. @@ -1154,6 +1176,13 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, return X +def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None, + header=None, footer=None, comments=None, + encoding=None): + return (X,) + + +@array_function_dispatch(_savetxt_dispatcher) def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', footer='', comments='# ', encoding=None): """ |