summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2019-05-23 06:41:18 -0700
committerEric Wieser <wieser.eric@gmail.com>2019-09-05 22:07:06 -0700
commitb6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (patch)
treee10b5be4cbd0c9064060a183572ad22c6a6ee044 /numpy/core/numeric.py
parente3f4c536a0014789dbd0321926b4f62c39d73719 (diff)
downloadnumpy-b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2.tar.gz
ENH: Always produce a consistent shape in the result of `argwhere`
Previously this would return 1d indices even though the array is zero-d. Note that using atleast1d inside numeric required an import change to avoid a circular import.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 8ada87b9f..c395b1348 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -26,6 +26,7 @@ if sys.version_info[0] < 3:
from . import overrides
from . import umath
+from . import shape_base
from .overrides import set_module
from .umath import (multiply, invert, sin, PINF, NAN)
from . import numerictypes
@@ -545,8 +546,10 @@ def argwhere(a):
Returns
-------
- index_array : ndarray
+ index_array : (N, a.ndim) ndarray
Indices of elements that are non-zero. Indices are grouped by element.
+ This array will have shape ``(N, a.ndim)`` where ``N`` is the number of
+ non-zero items.
See Also
--------
@@ -554,7 +557,8 @@ def argwhere(a):
Notes
-----
- ``np.argwhere(a)`` is the same as ``np.transpose(np.nonzero(a))``.
+ ``np.argwhere(a)`` is almost the same as ``np.transpose(np.nonzero(a))``,
+ but produces a result of the correct shape for a 0D array.
The output of ``argwhere`` is not suitable for indexing arrays.
For this purpose use ``nonzero(a)`` instead.
@@ -572,6 +576,11 @@ def argwhere(a):
[1, 2]])
"""
+ # nonzero does not behave well on 0d, so promote to 1d
+ if np.ndim(a) == 0:
+ a = shape_base.atleast_1d(a)
+ # then remove the added dimension
+ return argwhere(a)[:,:0]
return transpose(nonzero(a))