summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorPieter Eendebak <pieter.eendebak@gmail.com>2023-02-15 16:59:07 +0100
committerGitHub <noreply@github.com>2023-02-15 16:59:07 +0100
commit4aca866c292eec899f75475ff14f7d5c1025e394 (patch)
tree7a17644cc618f6a88a4c116c685e4d78bdbe4423 /numpy
parentda6cf855a5fa332781f1637775ab3cff1e12b86a (diff)
downloadnumpy-4aca866c292eec899f75475ff14f7d5c1025e394.tar.gz
ENH: Improve performance of finfo and _commonType (#23088)
The finfo contains a cache for dtypes, but the np.complex128 dtype does not end up in the cache. The reason is that the np.complex128 is converted to np.float64 which is in the cache. Performance improvement for finfo(np.complex128): Main: 2.07 µs ± 75 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) Pr: 324 ns ± 28.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) Improve performance of finfo by making the cache check the first action in the __new__ Improve performance of _commonType by re-using the expression for a.dtype.type and eliminating variables The finfo and _commonType was part of the computatation time in lstsq when using scikit-rf. Since these methods are used in various other methods performance can improve there slightly as well.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/getlimits.py21
-rw-r--r--numpy/linalg/linalg.py18
2 files changed, 25 insertions, 14 deletions
diff --git a/numpy/core/getlimits.py b/numpy/core/getlimits.py
index f848af085..da9e1d7f3 100644
--- a/numpy/core/getlimits.py
+++ b/numpy/core/getlimits.py
@@ -482,6 +482,10 @@ class finfo:
_finfo_cache = {}
def __new__(cls, dtype):
+ obj = cls._finfo_cache.get(dtype) # most common path
+ if obj is not None:
+ return obj
+
if dtype is None:
# Deprecated in NumPy 1.25, 2023-01-16
warnings.warn(
@@ -497,7 +501,7 @@ class finfo:
# In case a float instance was given
dtype = numeric.dtype(type(dtype))
- obj = cls._finfo_cache.get(dtype, None)
+ obj = cls._finfo_cache.get(dtype)
if obj is not None:
return obj
dtypes = [dtype]
@@ -507,17 +511,24 @@ class finfo:
dtype = newdtype
if not issubclass(dtype, numeric.inexact):
raise ValueError("data type %r not inexact" % (dtype))
- obj = cls._finfo_cache.get(dtype, None)
+ obj = cls._finfo_cache.get(dtype)
if obj is not None:
return obj
if not issubclass(dtype, numeric.floating):
newdtype = _convert_to_float[dtype]
if newdtype is not dtype:
+ # dtype changed, for example from complex128 to float64
dtypes.append(newdtype)
dtype = newdtype
- obj = cls._finfo_cache.get(dtype, None)
- if obj is not None:
- return obj
+
+ obj = cls._finfo_cache.get(dtype, None)
+ if obj is not None:
+ # the original dtype was not in the cache, but the new
+ # dtype is in the cache. we add the original dtypes to
+ # the cache and return the result
+ for dt in dtypes:
+ cls._finfo_cache[dt] = obj
+ return obj
obj = object.__new__(cls)._init(dtype)
for dt in dtypes:
cls._finfo_cache[dt] = obj
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 92878228d..255de94e5 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -138,24 +138,24 @@ def _commonType(*arrays):
result_type = single
is_complex = False
for a in arrays:
- if issubclass(a.dtype.type, inexact):
- if isComplexType(a.dtype.type):
+ type_ = a.dtype.type
+ if issubclass(type_, inexact):
+ if isComplexType(type_):
is_complex = True
- rt = _realType(a.dtype.type, default=None)
- if rt is None:
+ rt = _realType(type_, default=None)
+ if rt is double:
+ result_type = double
+ elif rt is None:
# unsupported inexact scalar
raise TypeError("array type %s is unsupported in linalg" %
(a.dtype.name,))
else:
- rt = double
- if rt is double:
result_type = double
if is_complex:
- t = cdouble
result_type = _complex_types_map[result_type]
+ return cdouble, result_type
else:
- t = double
- return t, result_type
+ return double, result_type
def _to_native_byte_order(*arrays):