summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorgfyoung <gfyoung17@gmail.com>2016-01-29 03:25:53 +0000
committergfyoung <gfyoung17@gmail.com>2016-08-04 22:20:45 -0400
commit0fc9e4520b1d00b58a77f28936da2fec2672de83 (patch)
treeeec3c9c072572b93bc8c4d2ee81fd44b8243e462 /numpy/core/numeric.py
parentbfd91d9e91ec5ea1c1d77b27b09952a11a24e19e (diff)
downloadnumpy-0fc9e4520b1d00b58a77f28936da2fec2672de83.tar.gz
ENH: added axis param for np.count_nonzero
Closes gh-391.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py89
1 files changed, 87 insertions, 2 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index b3eed9714..8db4e1302 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -6,6 +6,7 @@ import operator
import sys
import warnings
+import numpy as np
from . import multiarray
from .multiarray import (
_fastCopyAndTranspose as fastCopyAndTranspose, ALLOW_THREADS,
@@ -376,6 +377,89 @@ def extend_all(module):
__all__.append(a)
+def count_nonzero(a, axis=None):
+ """
+ Counts the number of non-zero values in the array ``a``.
+
+ The word "non-zero" is in reference to the Python 2.x
+ built-in method ``__nonzero__()`` (renamed ``__bool__()``
+ in Python 3.x) of Python objects that tests an object's
+ "truthfulness". For example, any number is considered
+ truthful if it is nonzero, whereas any string is considered
+ truthful if it is not the empty string. Thus, this function
+ (recursively) counts how many elements in ``a`` (and in
+ sub-arrays thereof) have their ``__nonzero__()`` or ``__bool__()``
+ method evaluated to ``True``.
+
+ Parameters
+ ----------
+ a : array_like
+ The array for which to count non-zeros.
+ axis : int or tuple, optional
+ Axis or tuple of axes along which to count non-zeros.
+ Default is None, meaning that non-zeros will be counted
+ along a flattened version of ``a``.
+
+ .. versionadded:: 1.12.0
+
+ Returns
+ -------
+ count : int or array of int
+ Number of non-zero values in the array along a given axis.
+ Otherwise, the total number of non-zero values in the array
+ is returned.
+
+ See Also
+ --------
+ nonzero : Return the coordinates of all the non-zero values.
+
+ Examples
+ --------
+ >>> np.count_nonzero(np.eye(4))
+ 4
+ >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
+ 5
+ >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=0)
+ array([1, 1, 1, 1, 1])
+ >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=1)
+ array([2, 3])
+
+ """
+ if axis is None or axis == ():
+ return multiarray.count_nonzero(a)
+
+ a = asanyarray(a)
+
+ if a.dtype == bool:
+ return a.sum(axis=axis, dtype=np.intp)
+
+ if issubdtype(a.dtype, np.number):
+ return (a != 0).sum(axis=axis, dtype=np.intp)
+
+ if (issubdtype(a.dtype, np.string_) or
+ issubdtype(a.dtype, np.unicode_)):
+ nullstr = a.dtype.type('')
+ return (a != nullstr).sum(axis=axis, dtype=np.intp)
+
+ axis = asarray(_validate_axis(axis, a.ndim, 'axis'))
+ counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
+
+ if axis.size == 1:
+ return counts
+ else:
+ # for subsequent axis numbers, that number decreases
+ # by one in this new 'counts' array if it was larger
+ # than the first axis upon which 'count_nonzero' was
+ # applied but remains unchanged if that number was
+ # smaller than that first axis
+ #
+ # this trick enables us to perform counts on object-like
+ # elements across multiple axes very quickly because integer
+ # addition is very well optimized
+ return counts.sum(axis=tuple(axis[1:] - (
+ axis[1:] > axis[0])), dtype=np.intp)
+
+
def asarray(a, dtype=None, order=None):
"""Convert the input to an array.
@@ -891,7 +975,7 @@ def correlate(a, v, mode='valid'):
return multiarray.correlate2(a, v, mode)
-def convolve(a,v,mode='full'):
+def convolve(a, v, mode='full'):
"""
Returns the discrete, linear convolution of two one-dimensional sequences.
@@ -1752,7 +1836,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
return rollaxis(cp, -1, axisc)
-#Use numarray's printing function
+# Use numarray's printing function
from .arrayprint import array2string, get_printoptions, set_printoptions
@@ -2283,6 +2367,7 @@ def load(file):
# These are all essentially abbreviations
# These might wind up in a special abbreviations module
+
def _maketup(descr, val):
dt = dtype(descr)
# Place val in all scalar tuples: