diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2021-05-04 17:43:26 -0500 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2021-05-04 19:44:31 -0500 |
commit | ad2a73c18dcff95d844c382c94ab7f73b5571cf3 (patch) | |
tree | 5894e417352dc57086d954784b7ef433e587adcb /numpy | |
parent | a59973902c6706cfc4e6958e69e9a9b33db333fb (diff) | |
download | numpy-ad2a73c18dcff95d844c382c94ab7f73b5571cf3.tar.gz |
MAINT: Adjust NumPy float hashing to Python's slightly changed hash
This is necessary because NumPy uses Python's double hash, and the
semi-private function that computes it (`_Py_HashDouble`) has a new
signature so that it can return the identity hash when the value is NaN.
closes gh-18833, gh-18907
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/common/npy_pycompat.h | 16 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_scalarmath.py | 34 |
3 files changed, 57 insertions, 6 deletions
diff --git a/numpy/core/src/common/npy_pycompat.h b/numpy/core/src/common/npy_pycompat.h index aa0b5c122..9e94a9710 100644 --- a/numpy/core/src/common/npy_pycompat.h +++ b/numpy/core/src/common/npy_pycompat.h @@ -3,4 +3,20 @@ #include "numpy/npy_3kcompat.h" + +/* + * In Python 3.10a7 (or b1), python started using the identity for the hash + * when a value is NaN. See https://bugs.python.org/issue43475 + */ +#if PY_VERSION_HEX > 0x030a00a6 +#define Npy_HashDouble _Py_HashDouble +#else +static NPY_INLINE Py_hash_t +Npy_HashDouble(PyObject *NPY_UNUSED(identity), double val) +{ + return _Py_HashDouble(val); +} +#endif + + #endif /* _NPY_COMPAT_H_ */ diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index a001500b0..9930f7791 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -3172,7 +3172,7 @@ static npy_hash_t static npy_hash_t @lname@_arrtype_hash(PyObject *obj) { - return _Py_HashDouble((double) PyArrayScalar_VAL(obj, @name@)); + return Npy_HashDouble(obj, (double)PyArrayScalar_VAL(obj, @name@)); } /* borrowed from complex_hash */ @@ -3180,14 +3180,14 @@ static npy_hash_t c@lname@_arrtype_hash(PyObject *obj) { npy_hash_t hashreal, hashimag, combined; - hashreal = _Py_HashDouble((double) - PyArrayScalar_VAL(obj, C@name@).real); + hashreal = Npy_HashDouble( + obj, (double)PyArrayScalar_VAL(obj, C@name@).real); if (hashreal == -1) { return -1; } - hashimag = _Py_HashDouble((double) - PyArrayScalar_VAL(obj, C@name@).imag); + hashimag = Npy_HashDouble( + obj, (double)PyArrayScalar_VAL(obj, C@name@).imag); if (hashimag == -1) { return -1; } @@ -3202,7 +3202,8 @@ c@lname@_arrtype_hash(PyObject *obj) static npy_hash_t half_arrtype_hash(PyObject *obj) { - return _Py_HashDouble(npy_half_to_double(PyArrayScalar_VAL(obj, Half))); + return Npy_HashDouble( + obj, npy_half_to_double(PyArrayScalar_VAL(obj, Half))); } static npy_hash_t diff --git 
a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index d91b4a391..09a734284 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -712,6 +712,40 @@ class TestBitShifts: assert_equal(res_arr, res_scl) +class TestHash: + @pytest.mark.parametrize("type_code", np.typecodes['AllInteger']) + def test_integer_hashes(self, type_code): + scalar = np.dtype(type_code).type + for i in range(128): + assert hash(i) == hash(scalar(i)) + + @pytest.mark.parametrize("type_code", np.typecodes['AllFloat']) + def test_float_and_complex_hashes(self, type_code): + scalar = np.dtype(type_code).type + for val in [np.pi, np.inf, 3, 6.]: + numpy_val = scalar(val) + # Cast back to Python, in case the NumPy scalar has less precision + if numpy_val.dtype.kind == 'c': + val = complex(numpy_val) + else: + val = float(numpy_val) + assert val == numpy_val + print(repr(numpy_val), repr(val)) + assert hash(val) == hash(numpy_val) + + if hash(float(np.nan)) != hash(float(np.nan)): + # If Python distinguises different NaNs we do so too (gh-18833) + assert hash(scalar(np.nan)) != hash(scalar(np.nan)) + + @pytest.mark.parametrize("type_code", np.typecodes['Complex']) + def test_complex_hashes(self, type_code): + # Test some complex valued hashes specifically: + scalar = np.dtype(type_code).type + for val in [np.pi+1j, np.inf-3j, 3j, 6.+1j]: + numpy_val = scalar(val) + assert hash(complex(numpy_val)) == hash(numpy_val) + + @contextlib.contextmanager def recursionlimit(n): o = sys.getrecursionlimit() |