diff options
author | Jim Crist <crist042@umn.edu> | 2016-11-15 15:28:12 -0600 |
---|---|---|
committer | Jim Crist <crist042@umn.edu> | 2017-01-23 12:21:18 -0600 |
commit | 6cbb46e45dde71be76ec6c3f4e8e6fccb4227976 (patch) | |
tree | 2bb3621c4b73895476f2347743d8d25134cb116b /numpy | |
parent | 1c8ecc783c43269631b1fdc8f9e06355d231aec6 (diff) | |
download | numpy-6cbb46e45dde71be76ec6c3f4e8e6fccb4227976.tar.gz |
BUG: bool(dtype) is True
Previously `bool(dtype(...))` would fallback to the default
implementation of `__nonzero__`, which checks if `len(object) > 0`.
Since `dtype` objects implement `__len__` as the number of record
fields, `bool` of scalar dtypes like `bool(dtype('i8'))` would evaluate
as `False`. This fixes that by implementing `__nonzero__` to always
return True.
Fixes #6294.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 27 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 7 |
2 files changed, 33 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 9fa9ad088..613c4ca1c 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -3651,6 +3651,31 @@ arraydescr_richcompare(PyArray_Descr *self, PyObject *other, int cmp_op) return result; } +static int +descr_nonzero(PyObject *self) +{ + /* `bool(np.dtype(...)) == True` for all dtypes. Needed to override default + * nonzero implementation, which checks if `len(object) > 0`. */ + return 1; +} + +static PyNumberMethods descr_as_number = { + (binaryfunc)0, /* nb_add */ + (binaryfunc)0, /* nb_subtract */ + (binaryfunc)0, /* nb_multiply */ + #if defined(NPY_PY3K) + #else + (binaryfunc)0, /* nb_divide */ + #endif + (binaryfunc)0, /* nb_remainder */ + (binaryfunc)0, /* nb_divmod */ + (ternaryfunc)0, /* nb_power */ + (unaryfunc)0, /* nb_negative */ + (unaryfunc)0, /* nb_positive */ + (unaryfunc)0, /* nb_absolute */ + (inquiry)descr_nonzero, /* nb_nonzero */ +}; + /************************************************************************* **************** Implement Mapping Protocol *************************** *************************************************************************/ @@ -3800,7 +3825,7 @@ NPY_NO_EXPORT PyTypeObject PyArrayDescr_Type = { 0, /* tp_compare */ #endif (reprfunc)arraydescr_repr, /* tp_repr */ - 0, /* tp_as_number */ + &descr_as_number, /* tp_as_number */ &descr_as_sequence, /* tp_as_sequence */ &descr_as_mapping, /* tp_as_mapping */ 0, /* tp_hash */ diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index c52d480a7..c0143aae3 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -620,5 +620,12 @@ def test_rational_dtype(): assert_equal(np.array([x,x]).dtype, np.dtype(rational)) +def test_dtypes_are_true(): + # test for gh-6294 + assert bool(np.dtype('f8')) + assert bool(np.dtype('i8')) + assert bool(np.dtype([('a', 'i8'), ('b', 'f4')])) + + if __name__ == "__main__": run_module_suite() |