summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorJim Crist <crist042@umn.edu>2016-11-15 15:28:12 -0600
committerJim Crist <crist042@umn.edu>2017-01-23 12:21:18 -0600
commit6cbb46e45dde71be76ec6c3f4e8e6fccb4227976 (patch)
tree2bb3621c4b73895476f2347743d8d25134cb116b /numpy
parent1c8ecc783c43269631b1fdc8f9e06355d231aec6 (diff)
downloadnumpy-6cbb46e45dde71be76ec6c3f4e8e6fccb4227976.tar.gz
BUG: bool(dtype) is True
Previously `bool(dtype(...))` would fallback to the default implementation of `__nonzero__`, which checks if `len(object) > 0`. Since `dtype` objects implement `__len__` as the number of record fields, `bool` of scalar dtypes like `bool(dtype('i8'))` would evaluate as `False`. This fixes that by implementing `__nonzero__` to always return True. Fixes #6294.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/descriptor.c27
-rw-r--r--numpy/core/tests/test_dtype.py7
2 files changed, 33 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 9fa9ad088..613c4ca1c 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -3651,6 +3651,31 @@ arraydescr_richcompare(PyArray_Descr *self, PyObject *other, int cmp_op)
return result;
}
+static int
+descr_nonzero(PyObject *self)
+{
+ /* `bool(np.dtype(...)) == True` for all dtypes. Needed to override default
+ * nonzero implementation, which checks if `len(object) > 0`. */
+ return 1;
+}
+
+static PyNumberMethods descr_as_number = {
+ (binaryfunc)0, /* nb_add */
+ (binaryfunc)0, /* nb_subtract */
+ (binaryfunc)0, /* nb_multiply */
+ #if defined(NPY_PY3K)
+ #else
+ (binaryfunc)0, /* nb_divide */
+ #endif
+ (binaryfunc)0, /* nb_remainder */
+ (binaryfunc)0, /* nb_divmod */
+ (ternaryfunc)0, /* nb_power */
+ (unaryfunc)0, /* nb_negative */
+ (unaryfunc)0, /* nb_positive */
+ (unaryfunc)0, /* nb_absolute */
+ (inquiry)descr_nonzero, /* nb_nonzero */
+};
+
/*************************************************************************
**************** Implement Mapping Protocol ***************************
*************************************************************************/
@@ -3800,7 +3825,7 @@ NPY_NO_EXPORT PyTypeObject PyArrayDescr_Type = {
0, /* tp_compare */
#endif
(reprfunc)arraydescr_repr, /* tp_repr */
- 0, /* tp_as_number */
+ &descr_as_number, /* tp_as_number */
&descr_as_sequence, /* tp_as_sequence */
&descr_as_mapping, /* tp_as_mapping */
0, /* tp_hash */
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index c52d480a7..c0143aae3 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -620,5 +620,12 @@ def test_rational_dtype():
assert_equal(np.array([x,x]).dtype, np.dtype(rational))
+def test_dtypes_are_true():
+ # test for gh-6294
+ assert bool(np.dtype('f8'))
+ assert bool(np.dtype('i8'))
+ assert bool(np.dtype([('a', 'i8'), ('b', 'f4')]))
+
+
if __name__ == "__main__":
run_module_suite()