diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/__init__.pyi | 2 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 4 | ||||
| -rw-r--r-- | numpy/core/tests/test_dtype.py | 18 |
3 files changed, 21 insertions, 3 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 6e24f6bff..7efe73010 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -1065,8 +1065,6 @@ class dtype(Generic[_DTypeScalar_co]): # literals as of mypy 0.800. Set the return-type to `Any` for now. def __rmul__(self, value: int) -> Any: ... - def __eq__(self, other: DTypeLike) -> bool: ... - def __ne__(self, other: DTypeLike) -> bool: ... def __gt__(self, other: DTypeLike) -> bool: ... def __ge__(self, other: DTypeLike) -> bool: ... def __lt__(self, other: DTypeLike) -> bool: ... diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index f0dfac55d..b8b477e5d 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -3228,7 +3228,9 @@ arraydescr_richcompare(PyArray_Descr *self, PyObject *other, int cmp_op) { PyArray_Descr *new = _convert_from_any(other, 0); if (new == NULL) { - return NULL; + /* Cannot convert `other` to dtype */ + PyErr_Clear(); + Py_RETURN_NOTIMPLEMENTED; } npy_bool ret; diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 8a6b7dcd5..3d15009ea 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -88,6 +88,24 @@ class TestBuiltin: assert_raises(TypeError, np.dtype, 'q8') assert_raises(TypeError, np.dtype, 'Q8') + def test_richcompare_invalid_dtype_equality(self): + # Make sure objects that cannot be converted to valid + # dtypes results in False/True when compared to valid dtypes. + # Here 7 cannot be converted to dtype. No exceptions should be raised + + assert not np.dtype(np.int32) == 7, "dtype richcompare failed for ==" + assert np.dtype(np.int32) != 7, "dtype richcompare failed for !=" + + @pytest.mark.parametrize( + 'operation', + [operator.le, operator.lt, operator.ge, operator.gt]) + def test_richcompare_invalid_dtype_comparison(self, operation): + # Make sure TypeError is raised for comparison operators + # for invalid dtypes. Here 7 is an invalid dtype. + + with pytest.raises(TypeError): + operation(np.dtype(np.int32), 7) + @pytest.mark.parametrize("dtype", ['Bool', 'Complex32', 'Complex64', 'Float16', 'Float32', 'Float64', 'Int8', 'Int16', 'Int32', 'Int64', 'Object0', 'Timedelta64', |
