diff options
| author | Pieter Eendebak <pieter.eendebak@gmail.com> | 2023-02-15 16:59:07 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-15 16:59:07 +0100 |
| commit | 4aca866c292eec899f75475ff14f7d5c1025e394 (patch) | |
| tree | 7a17644cc618f6a88a4c116c685e4d78bdbe4423 /numpy | |
| parent | da6cf855a5fa332781f1637775ab3cff1e12b86a (diff) | |
| download | numpy-4aca866c292eec899f75475ff14f7d5c1025e394.tar.gz | |
ENH: Improve performance of finfo and _commonType (#23088)
The finfo contains a cache for dtypes, but the np.complex128 dtype does not end up in the cache. The reason is that the np.complex128 is converted to np.float64 which is in the cache.
Performance improvement for finfo(np.complex128):
Main: 2.07 µs ± 75 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
Pr: 324 ns ± 28.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
Improve performance of finfo by making the cache check the first action in the __new__
Improve performance of _commonType by re-using the expression for a.dtype.type and eliminating variables
The finfo and _commonType was part of the computatation time in lstsq when using scikit-rf. Since these methods are used in various other methods performance can improve there slightly as well.
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/getlimits.py | 21 | ||||
| -rw-r--r-- | numpy/linalg/linalg.py | 18 |
2 files changed, 25 insertions, 14 deletions
diff --git a/numpy/core/getlimits.py b/numpy/core/getlimits.py index f848af085..da9e1d7f3 100644 --- a/numpy/core/getlimits.py +++ b/numpy/core/getlimits.py @@ -482,6 +482,10 @@ class finfo: _finfo_cache = {} def __new__(cls, dtype): + obj = cls._finfo_cache.get(dtype) # most common path + if obj is not None: + return obj + if dtype is None: # Deprecated in NumPy 1.25, 2023-01-16 warnings.warn( @@ -497,7 +501,7 @@ class finfo: # In case a float instance was given dtype = numeric.dtype(type(dtype)) - obj = cls._finfo_cache.get(dtype, None) + obj = cls._finfo_cache.get(dtype) if obj is not None: return obj dtypes = [dtype] @@ -507,17 +511,24 @@ class finfo: dtype = newdtype if not issubclass(dtype, numeric.inexact): raise ValueError("data type %r not inexact" % (dtype)) - obj = cls._finfo_cache.get(dtype, None) + obj = cls._finfo_cache.get(dtype) if obj is not None: return obj if not issubclass(dtype, numeric.floating): newdtype = _convert_to_float[dtype] if newdtype is not dtype: + # dtype changed, for example from complex128 to float64 dtypes.append(newdtype) dtype = newdtype - obj = cls._finfo_cache.get(dtype, None) - if obj is not None: - return obj + + obj = cls._finfo_cache.get(dtype, None) + if obj is not None: + # the original dtype was not in the cache, but the new + # dtype is in the cache. we add the original dtypes to + # the cache and return the result + for dt in dtypes: + cls._finfo_cache[dt] = obj + return obj obj = object.__new__(cls)._init(dtype) for dt in dtypes: cls._finfo_cache[dt] = obj diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 92878228d..255de94e5 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -138,24 +138,24 @@ def _commonType(*arrays): result_type = single is_complex = False for a in arrays: - if issubclass(a.dtype.type, inexact): - if isComplexType(a.dtype.type): + type_ = a.dtype.type + if issubclass(type_, inexact): + if isComplexType(type_): is_complex = True - rt = _realType(a.dtype.type, default=None) - if rt is None: + rt = _realType(type_, default=None) + if rt is double: + result_type = double + elif rt is None: # unsupported inexact scalar raise TypeError("array type %s is unsupported in linalg" % (a.dtype.name,)) else: - rt = double - if rt is double: result_type = double if is_complex: - t = cdouble result_type = _complex_types_map[result_type] + return cdouble, result_type else: - t = double - return t, result_type + return double, result_type def _to_native_byte_order(*arrays): |
