summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-10-17 20:41:39 -0700
committerEric Wieser <wieser.eric@gmail.com>2017-10-17 20:41:39 -0700
commit3856a73ec59d9dcf8252a2d38b43191eacacbd2e (patch)
tree1f234ad1ec5a45546d6d011efa3a0845647661eb
parent77f9540f277de3e727d696eb5987a0505ed8cdf9 (diff)
downloadnumpy-3856a73ec59d9dcf8252a2d38b43191eacacbd2e.tar.gz
BUG: count_nonzero treats empty axis tuples strangely
Fixes #9728 This bug was introduced with the `axis` keyword in #7177, as a misguided optimization.
-rw-r--r--doc/release/1.14.0-notes.rst7
-rw-r--r--numpy/core/numeric.py2
-rw-r--r--numpy/core/tests/test_numeric.py4
3 files changed, 12 insertions, 1 deletions
diff --git a/doc/release/1.14.0-notes.rst b/doc/release/1.14.0-notes.rst
index 5dbd9b44c..f576923b2 100644
--- a/doc/release/1.14.0-notes.rst
+++ b/doc/release/1.14.0-notes.rst
@@ -122,6 +122,13 @@ passed, despite not doing so under the simple cases::
This change affects only ``float32`` and ``float16`` arrays.
+``count_nonzero(arr, axis=())`` now counts over no axes, not all axes
+---------------------------------------------------------------------
+Elsewhere, ``axis==()`` is always understood as "no axes", but
+`count_nonzero` had a special case to treat this as "all axes". This was
+inconsistent and surprising. The correct way to count over all axes has always
+been to pass ``axis == None``.
+
``__init__.py`` files added to test directories
-----------------------------------------------
This is for pytest compatibility in the case of duplicate test file names in
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 6d29785da..1f8f5d43e 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -406,7 +406,7 @@ def count_nonzero(a, axis=None):
array([2, 3])
"""
- if axis is None or (isinstance(axis, tuple) and axis == ()):
+ if axis is None:
return multiarray.count_nonzero(a)
a = asanyarray(a)
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index f1133b8c9..7afdc01ca 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1126,6 +1126,10 @@ class TestNonzero(object):
np.count_nonzero(n, axis=perm),
err_msg=msg % (perm,))
+ def test_countnonzero_axis_empty(self):
+ a = np.array([[0, 0, 1], [1, 0, 1]])
+ assert_equal(np.count_nonzero(a, axis=()), a.astype(bool))
+
def test_array_method(self):
# Tests that the array method
# call to nonzero works