diff options
author | gfyoung <gfyoung17@gmail.com> | 2016-01-29 03:25:53 +0000 |
---|---|---|
committer | gfyoung <gfyoung17@gmail.com> | 2016-08-04 22:20:45 -0400 |
commit | 0fc9e4520b1d00b58a77f28936da2fec2672de83 (patch) | |
tree | eec3c9c072572b93bc8c4d2ee81fd44b8243e462 /numpy/core/numeric.py | |
parent | bfd91d9e91ec5ea1c1d77b27b09952a11a24e19e (diff) | |
download | numpy-0fc9e4520b1d00b58a77f28936da2fec2672de83.tar.gz |
ENH: added axis param for np.count_nonzero
Closes gh-391.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 89 |
1 files changed, 87 insertions, 2 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b3eed9714..8db4e1302 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -6,6 +6,7 @@ import operator import sys import warnings +import numpy as np from . import multiarray from .multiarray import ( _fastCopyAndTranspose as fastCopyAndTranspose, ALLOW_THREADS, @@ -376,6 +377,89 @@ def extend_all(module): __all__.append(a) +def count_nonzero(a, axis=None): + """ + Counts the number of non-zero values in the array ``a``. + + The word "non-zero" is in reference to the Python 2.x + built-in method ``__nonzero__()`` (renamed ``__bool__()`` + in Python 3.x) of Python objects that tests an object's + "truthfulness". For example, any number is considered + truthful if it is nonzero, whereas any string is considered + truthful if it is not the empty string. Thus, this function + (recursively) counts how many elements in ``a`` (and in + sub-arrays thereof) have their ``__nonzero__()`` or ``__bool__()`` + method evaluated to ``True``. + + Parameters + ---------- + a : array_like + The array for which to count non-zeros. + axis : int or tuple, optional + Axis or tuple of axes along which to count non-zeros. + Default is None, meaning that non-zeros will be counted + along a flattened version of ``a``. + + .. versionadded:: 1.12.0 + + Returns + ------- + count : int or array of int + Number of non-zero values in the array along a given axis. + Otherwise, the total number of non-zero values in the array + is returned. + + See Also + -------- + nonzero : Return the coordinates of all the non-zero values. + + Examples + -------- + >>> np.count_nonzero(np.eye(4)) + 4 + >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]]) + 5 + >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=0) + array([1, 1, 1, 1, 1]) + >>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=1) + array([2, 3]) + + """ + if axis is None or axis == (): + return multiarray.count_nonzero(a) + + a = asanyarray(a) + + if a.dtype == bool: + return a.sum(axis=axis, dtype=np.intp) + + if issubdtype(a.dtype, np.number): + return (a != 0).sum(axis=axis, dtype=np.intp) + + if (issubdtype(a.dtype, np.string_) or + issubdtype(a.dtype, np.unicode_)): + nullstr = a.dtype.type('') + return (a != nullstr).sum(axis=axis, dtype=np.intp) + + axis = asarray(_validate_axis(axis, a.ndim, 'axis')) + counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) + + if axis.size == 1: + return counts + else: + # for subsequent axis numbers, that number decreases + # by one in this new 'counts' array if it was larger + # than the first axis upon which 'count_nonzero' was + # applied but remains unchanged if that number was + # smaller than that first axis + # + # this trick enables us to perform counts on object-like + # elements across multiple axes very quickly because integer + # addition is very well optimized + return counts.sum(axis=tuple(axis[1:] - ( + axis[1:] > axis[0])), dtype=np.intp) + + def asarray(a, dtype=None, order=None): """Convert the input to an array. @@ -891,7 +975,7 @@ def correlate(a, v, mode='valid'): return multiarray.correlate2(a, v, mode) -def convolve(a,v,mode='full'): +def convolve(a, v, mode='full'): """ Returns the discrete, linear convolution of two one-dimensional sequences. @@ -1752,7 +1836,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): return rollaxis(cp, -1, axisc) -#Use numarray's printing function +# Use numarray's printing function from .arrayprint import array2string, get_printoptions, set_printoptions @@ -2283,6 +2367,7 @@ def load(file): # These are all essentially abbreviations # These might wind up in a special abbreviations module + def _maketup(descr, val): dt = dtype(descr) # Place val in all scalar tuples: |