summaryrefslogtreecommitdiff
path: root/numpy/core/tests
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2019-09-06 09:45:21 -0500
committerGitHub <noreply@github.com>2019-09-06 09:45:21 -0500
commit94ae1c59d78e64a2eda219ade28f1180e3c2d9af (patch)
treea90ef89986b30b1ae8098fb7e30f186d355003e3 /numpy/core/tests
parente4e12cba59134e7ec2c4e98b98a7fa162b590f66 (diff)
parentb6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (diff)
downloadnumpy-94ae1c59d78e64a2eda219ade28f1180e3c2d9af.tar.gz
Merge pull request #13610 from eric-wieser/argwhere
ENH: Always produce a consistent shape in the result of `argwhere`
Diffstat (limited to 'numpy/core/tests')
-rw-r--r--numpy/core/tests/test_numeric.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index c479a0f6d..1358b45e9 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2583,6 +2583,30 @@ class TestConvolve(object):
class TestArgwhere(object):
+
+ @pytest.mark.parametrize('nd', [0, 1, 2])
+ def test_nd(self, nd):
+ # get an nd array with multiple elements in every dimension
+ x = np.empty((2,)*nd, bool)
+
+ # none
+ x[...] = False
+ assert_equal(np.argwhere(x).shape, (0, nd))
+
+ # only one
+ x[...] = False
+ x.flat[0] = True
+ assert_equal(np.argwhere(x).shape, (1, nd))
+
+ # all but one
+ x[...] = True
+ x.flat[0] = False
+ assert_equal(np.argwhere(x).shape, (x.size - 1, nd))
+
+ # all
+ x[...] = True
+ assert_equal(np.argwhere(x).shape, (x.size, nd))
+
def test_2D(self):
x = np.arange(6).reshape((2, 3))
assert_array_equal(np.argwhere(x > 1),