diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2019-05-23 06:41:18 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2019-09-05 22:07:06 -0700 |
commit | b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (patch) | |
tree | e10b5be4cbd0c9064060a183572ad22c6a6ee044 /numpy/core/tests | |
parent | e3f4c536a0014789dbd0321926b4f62c39d73719 (diff) | |
download | numpy-b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2.tar.gz |
ENH: Always produce a consistent shape in the result of `argwhere`
Previously this would return 1d indices even though the array is zero-d.
Note that using atleast1d inside numeric required an import change to avoid a circular import.
Diffstat (limited to 'numpy/core/tests')
-rw-r--r-- | numpy/core/tests/test_numeric.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index c479a0f6d..1358b45e9 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2583,6 +2583,30 @@ class TestConvolve(object): class TestArgwhere(object): + + @pytest.mark.parametrize('nd', [0, 1, 2]) + def test_nd(self, nd): + # get an nd array with multiple elements in every dimension + x = np.empty((2,)*nd, bool) + + # none + x[...] = False + assert_equal(np.argwhere(x).shape, (0, nd)) + + # only one + x[...] = False + x.flat[0] = True + assert_equal(np.argwhere(x).shape, (1, nd)) + + # all but one + x[...] = True + x.flat[0] = False + assert_equal(np.argwhere(x).shape, (x.size - 1, nd)) + + # all + x[...] = True + assert_equal(np.argwhere(x).shape, (x.size, nd)) + def test_2D(self): x = np.arange(6).reshape((2, 3)) assert_array_equal(np.argwhere(x > 1), |