summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-10-11 00:56:26 -0700
committerEric Wieser <wieser.eric@gmail.com>2017-10-12 00:02:16 -0700
commitfb168b8a5ee222ff352d20bfc1efab9009d68347 (patch)
treef78dca9ca806f1e10990c5c39095aaf3ab46eb32 /numpy/core/numeric.py
parente64699dcaca9eb0dd97deabae01ffc2884cacbb0 (diff)
downloadnumpy-fb168b8a5ee222ff352d20bfc1efab9009d68347.tar.gz
MAINT: Fix all special-casing of dtypes in `count_nonzero`
A quick profile reveals that `int.astype(bool)` is faster than `int == 0` anyway
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py32
1 files changed, 6 insertions, 26 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 5b10361fe..6d29785da 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -411,33 +411,13 @@ def count_nonzero(a, axis=None):
a = asanyarray(a)
- if a.dtype == bool:
- return a.sum(axis=axis, dtype=np.intp)
-
- if issubdtype(a.dtype, np.number):
- return (a != 0).sum(axis=axis, dtype=np.intp)
-
- if issubdtype(a.dtype, np.character):
- nullstr = a.dtype.type('')
- return (a != nullstr).sum(axis=axis, dtype=np.intp)
-
- axis = asarray(normalize_axis_tuple(axis, a.ndim))
- counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
-
- if axis.size == 1:
- return counts.astype(np.intp, copy=False)
+ # TODO: this works around .astype(bool) not working properly (gh-9847)
+ if np.issubdtype(a.dtype, np.character):
+ a_bool = a != a.dtype.type()
else:
- # for subsequent axis numbers, that number decreases
- # by one in this new 'counts' array if it was larger
- # than the first axis upon which 'count_nonzero' was
- # applied but remains unchanged if that number was
- # smaller than that first axis
- #
- # this trick enables us to perform counts on object-like
- # elements across multiple axes very quickly because integer
- # addition is very well optimized
- return counts.sum(axis=tuple(axis[1:] - (
- axis[1:] > axis[0])), dtype=np.intp)
+ a_bool = a.astype(np.bool_, copy=False)
+
+ return a_bool.sum(axis=axis, dtype=np.intp)
def asarray(a, dtype=None, order=None):