diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2019-05-23 06:41:18 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2019-09-05 22:07:06 -0700 |
commit | b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (patch) | |
tree | e10b5be4cbd0c9064060a183572ad22c6a6ee044 /numpy/core/numeric.py | |
parent | e3f4c536a0014789dbd0321926b4f62c39d73719 (diff) | |
download | numpy-b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2.tar.gz |
ENH: Always produce a consistent shape in the result of `argwhere`
Previously this would return 1d indices even though the array is zero-d.
Note that using atleast1d inside numeric required an import change to avoid a circular import.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 8ada87b9f..c395b1348 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -26,6 +26,7 @@ if sys.version_info[0] < 3: from . import overrides from . import umath +from . import shape_base from .overrides import set_module from .umath import (multiply, invert, sin, PINF, NAN) from . import numerictypes @@ -545,8 +546,10 @@ def argwhere(a): Returns ------- - index_array : ndarray + index_array : (N, a.ndim) ndarray Indices of elements that are non-zero. Indices are grouped by element. + This array will have shape ``(N, a.ndim)`` where ``N`` is the number of + non-zero items. See Also -------- @@ -554,7 +557,8 @@ def argwhere(a): Notes ----- - ``np.argwhere(a)`` is the same as ``np.transpose(np.nonzero(a))``. + ``np.argwhere(a)`` is almost the same as ``np.transpose(np.nonzero(a))``, + but produces a result of the correct shape for a 0D array. The output of ``argwhere`` is not suitable for indexing arrays. For this purpose use ``nonzero(a)`` instead. @@ -572,6 +576,11 @@ def argwhere(a): [1, 2]]) """ + # nonzero does not behave well on 0d, so promote to 1d + if np.ndim(a) == 0: + a = shape_base.atleast_1d(a) + # then remove the added dimension + return argwhere(a)[:,:0] return transpose(nonzero(a)) |