summaryrefslogtreecommitdiff
path: root/numpy/core/tests
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2019-05-23 06:41:18 -0700
committerEric Wieser <wieser.eric@gmail.com>2019-09-05 22:07:06 -0700
commitb6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (patch)
treee10b5be4cbd0c9064060a183572ad22c6a6ee044 /numpy/core/tests
parente3f4c536a0014789dbd0321926b4f62c39d73719 (diff)
downloadnumpy-b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2.tar.gz
ENH: Always produce a consistent shape in the result of `argwhere`
Previously this would return 1d indices even though the array is zero-d. Note that using atleast1d inside numeric required an import change to avoid a circular import.
Diffstat (limited to 'numpy/core/tests')
-rw-r--r--numpy/core/tests/test_numeric.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index c479a0f6d..1358b45e9 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2583,6 +2583,30 @@ class TestConvolve(object):
class TestArgwhere(object):
+
+ @pytest.mark.parametrize('nd', [0, 1, 2])
+ def test_nd(self, nd):
+ # get an nd array with multiple elements in every dimension
+ x = np.empty((2,)*nd, bool)
+
+ # none
+ x[...] = False
+ assert_equal(np.argwhere(x).shape, (0, nd))
+
+ # only one
+ x[...] = False
+ x.flat[0] = True
+ assert_equal(np.argwhere(x).shape, (1, nd))
+
+ # all but one
+ x[...] = True
+ x.flat[0] = False
+ assert_equal(np.argwhere(x).shape, (x.size - 1, nd))
+
+ # all
+ x[...] = True
+ assert_equal(np.argwhere(x).shape, (x.size, nd))
+
def test_2D(self):
x = np.arange(6).reshape((2, 3))
assert_array_equal(np.argwhere(x > 1),