summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi2
-rw-r--r--numpy/core/src/multiarray/descriptor.c4
-rw-r--r--numpy/core/tests/test_dtype.py18
3 files changed, 21 insertions, 3 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 6e24f6bff..7efe73010 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1065,8 +1065,6 @@ class dtype(Generic[_DTypeScalar_co]):
# literals as of mypy 0.800. Set the return-type to `Any` for now.
def __rmul__(self, value: int) -> Any: ...
- def __eq__(self, other: DTypeLike) -> bool: ...
- def __ne__(self, other: DTypeLike) -> bool: ...
def __gt__(self, other: DTypeLike) -> bool: ...
def __ge__(self, other: DTypeLike) -> bool: ...
def __lt__(self, other: DTypeLike) -> bool: ...
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index f0dfac55d..b8b477e5d 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -3228,7 +3228,9 @@ arraydescr_richcompare(PyArray_Descr *self, PyObject *other, int cmp_op)
{
PyArray_Descr *new = _convert_from_any(other, 0);
if (new == NULL) {
- return NULL;
+ /* Cannot convert `other` to dtype */
+ PyErr_Clear();
+ Py_RETURN_NOTIMPLEMENTED;
}
npy_bool ret;
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 8a6b7dcd5..3d15009ea 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -88,6 +88,24 @@ class TestBuiltin:
assert_raises(TypeError, np.dtype, 'q8')
assert_raises(TypeError, np.dtype, 'Q8')
+ def test_richcompare_invalid_dtype_equality(self):
+ # Make sure objects that cannot be converted to valid
+ # dtypes results in False/True when compared to valid dtypes.
+ # Here 7 cannot be converted to dtype. No exceptions should be raised
+
+ assert not np.dtype(np.int32) == 7, "dtype richcompare failed for =="
+ assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="
+
+ @pytest.mark.parametrize(
+ 'operation',
+ [operator.le, operator.lt, operator.ge, operator.gt])
+ def test_richcompare_invalid_dtype_comparison(self, operation):
+ # Make sure TypeError is raised for comparison operators
+ # for invalid dtypes. Here 7 is an invalid dtype.
+
+ with pytest.raises(TypeError):
+ operation(np.dtype(np.int32), 7)
+
@pytest.mark.parametrize("dtype",
['Bool', 'Complex32', 'Complex64', 'Float16', 'Float32', 'Float64',
'Int8', 'Int16', 'Int32', 'Int64', 'Object0', 'Timedelta64',