summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py29
1 files changed, 20 insertions, 9 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 5e8151e68..83d985a7c 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -384,12 +384,12 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):
return res
-def _count_nonzero_dispatcher(a, axis=None):
+def _count_nonzero_dispatcher(a, axis=None, *, keepdims=None):
return (a,)
@array_function_dispatch(_count_nonzero_dispatcher)
-def count_nonzero(a, axis=None):
+def count_nonzero(a, axis=None, *, keepdims=False):
"""
Counts the number of non-zero values in the array ``a``.
@@ -414,6 +414,13 @@ def count_nonzero(a, axis=None):
.. versionadded:: 1.12.0
+ keepdims : bool, optional
+ If this is set to True, the axes that are counted are left
+ in the result as dimensions with size one. With this option,
+ the result will broadcast correctly against the input array.
+
+ .. versionadded:: 1.19.0
+
Returns
-------
count : int or array of int
@@ -429,15 +436,19 @@ def count_nonzero(a, axis=None):
--------
>>> np.count_nonzero(np.eye(4))
4
- >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
+ >>> a = np.array([[0, 1, 7, 0],
+ ... [3, 0, 2, 19]])
+ >>> np.count_nonzero(a)
5
- >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=0)
- array([1, 1, 1, 1, 1])
- >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=1)
+ >>> np.count_nonzero(a, axis=0)
+ array([1, 1, 2, 1])
+ >>> np.count_nonzero(a, axis=1)
array([2, 3])
-
+ >>> np.count_nonzero(a, axis=1, keepdims=True)
+ array([[2],
+ [3]])
"""
- if axis is None:
+ if axis is None and not keepdims:
return multiarray.count_nonzero(a)
a = asanyarray(a)
@@ -448,7 +459,7 @@ def count_nonzero(a, axis=None):
else:
a_bool = a.astype(np.bool_, copy=False)
- return a_bool.sum(axis=axis, dtype=np.intp)
+ return a_bool.sum(axis=axis, dtype=np.intp, keepdims=keepdims)
@set_module('numpy')