diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2019-09-06 09:45:21 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-09-06 09:45:21 -0500 |
commit | 94ae1c59d78e64a2eda219ade28f1180e3c2d9af (patch) | |
tree | a90ef89986b30b1ae8098fb7e30f186d355003e3 /numpy/core/tests | |
parent | e4e12cba59134e7ec2c4e98b98a7fa162b590f66 (diff) | |
parent | b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (diff) | |
download | numpy-94ae1c59d78e64a2eda219ade28f1180e3c2d9af.tar.gz |
Merge pull request #13610 from eric-wieser/argwhere
ENH: Always produce a consistent shape in the result of `argwhere`
Diffstat (limited to 'numpy/core/tests')
-rw-r--r-- | numpy/core/tests/test_numeric.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index c479a0f6d..1358b45e9 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2583,6 +2583,30 @@ class TestConvolve(object): class TestArgwhere(object): + + @pytest.mark.parametrize('nd', [0, 1, 2]) + def test_nd(self, nd): + # get an nd array with multiple elements in every dimension + x = np.empty((2,)*nd, bool) + + # none + x[...] = False + assert_equal(np.argwhere(x).shape, (0, nd)) + + # only one + x[...] = False + x.flat[0] = True + assert_equal(np.argwhere(x).shape, (1, nd)) + + # all but one + x[...] = True + x.flat[0] = False + assert_equal(np.argwhere(x).shape, (x.size - 1, nd)) + + # all + x[...] = True + assert_equal(np.argwhere(x).shape, (x.size, nd)) + def test_2D(self): x = np.arange(6).reshape((2, 3)) assert_array_equal(np.argwhere(x > 1), |