diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/npyio.py | 92 |
1 files changed, 50 insertions, 42 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 7c73d9655..983e2615c 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -5,7 +5,7 @@ import itertools import warnings import weakref import contextlib -from operator import itemgetter, index as opindex +from operator import itemgetter, index as opindex, methodcaller from collections.abc import Mapping import numpy as np @@ -728,41 +728,42 @@ def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None): zipf.close() +def _floatconv(x): + try: + return float(x) # The fastest path. + except ValueError: + if '0x' in x: # Don't accidentally convert "a" ("0xa") to 10. + try: + return float.fromhex(x) + except ValueError: + pass + raise # Raise the original exception, which makes more sense. + + +_CONVERTERS = [ + (np.bool_, lambda x: bool(int(x))), + (np.uint64, np.uint64), + (np.int64, np.int64), + (np.integer, lambda x: int(float(x))), + (np.longdouble, np.longdouble), + (np.floating, _floatconv), + (complex, lambda x: complex(asstr(x).replace('+-', '-'))), + (np.bytes_, asbytes), + (np.unicode_, asunicode), +] + + def _getconv(dtype): - """ Find the correct dtype converter. Adapted from matplotlib """ + """ + Find the correct dtype converter. Adapted from matplotlib. - def floatconv(x): - try: - return float(x) # The fastest path. - except ValueError: - if '0x' in x: # Don't accidentally convert "a" ("0xa") to 10. - try: - return float.fromhex(x) - except ValueError: - pass - raise # Raise the original exception, which makes more sense. - - typ = dtype.type - if issubclass(typ, np.bool_): - return lambda x: bool(int(x)) - if issubclass(typ, np.uint64): - return np.uint64 - if issubclass(typ, np.int64): - return np.int64 - if issubclass(typ, np.integer): - return lambda x: int(float(x)) - elif issubclass(typ, np.longdouble): - return np.longdouble - elif issubclass(typ, np.floating): - return floatconv - elif issubclass(typ, complex): - return lambda x: complex(asstr(x).replace('+-', '-')) - elif issubclass(typ, np.bytes_): - return asbytes - elif issubclass(typ, np.unicode_): - return asunicode - else: - return asstr + Even when a lambda is returned, it is defined at the toplevel, to allow + testing for equality and enabling optimization for single-type data. + """ + for base, conv in _CONVERTERS: + if issubclass(dtype.type, base): + return conv + return asstr # _loadtxt_flatten_dtype_internal and _loadtxt_pack_items are loadtxt helpers @@ -1011,12 +1012,9 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, line_num = i + skiprows + 1 raise ValueError("Wrong number of columns at line %d" % line_num) - - # Convert each value according to its column and store - items = [conv(val) for (conv, val) in zip(converters, vals)] - - # Then pack it according to the dtype's nesting - items = packer(items) + # Convert each value according to its column, then pack it + # according to the dtype's nesting + items = packer(convert_row(vals)) X.append(items) if len(X) > chunk_size: yield X @@ -1154,8 +1152,18 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, else: converters[i] = conv - converters = [conv if conv is not bytes else - lambda x: x.encode(fencoding) for conv in converters] + fencode = methodcaller("encode", fencoding) + converters = [conv if conv is not bytes else fencode + for conv in converters] + if len(set(converters)) == 1: + # Optimize single-type data. Note that this is only reached if + # `_getconv` returns equal callables (i.e. not local lambdas) on + # equal dtypes. + def convert_row(vals, _conv=converters[0]): + return [*map(_conv, vals)] + else: + def convert_row(vals): + return [conv(val) for conv, val in zip(converters, vals)] # read data in chunks and fill it into an array via resize # over-allocating and shrinking the array later may be faster but is |