summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2009-07-08 20:47:05 +0000
committerStefan van der Walt <stefan@sun.ac.za>2009-07-08 20:47:05 +0000
commit8ea0648357ae54483f9eed69b55b418b80acdb1f (patch)
treed82755607529330a91cf9c4ae04d8ad5eb18c5cd
parent8ba97d80c6b0e2834b92e1d85bb63430917bcc7b (diff)
downloadnumpy-8ea0648357ae54483f9eed69b55b418b80acdb1f.tar.gz
Fix argwhere for masked arrays.
-rw-r--r--numpy/core/numeric.py2
-rw-r--r--numpy/core/tests/test_numeric.py4
2 files changed, 5 insertions, 1 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index b2ac65cda..b497e4898 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -535,7 +535,7 @@ def argwhere(a):
[1, 2]])
"""
- return transpose(asarray(a).nonzero())
+ return transpose(asanyarray(a).nonzero())
def flatnonzero(a):
"""
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index b0227ba12..18f879cb4 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -947,5 +947,9 @@ class TestArgwhere:
def test_list(self):
assert_equal(np.argwhere([4, 0, 2, 1, 3]), [[0], [2], [3], [4]])
+ def test_masked_array(self):
+ a = np.ma.array([0, 1, 2, 3], mask=[0, 0, 1, 0])
+ assert_equal(np.argwhere(a), [[1], [3]])
+
if __name__ == "__main__":
run_module_suite()