summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2008-01-23 22:13:10 +0000
committerStefan van der Walt <stefan@sun.ac.za>2008-01-23 22:13:10 +0000
commite71a28f3c36b3a2be86d9d6ed472b4d27e63bf8c (patch)
tree77ad712c7517ca5f8df40263f510d1a7496c4775 /numpy
parent1447cc9d29e98ea7debf88a8442f6a7fa0835103 (diff)
downloadnumpy-e71a28f3c36b3a2be86d9d6ed472b4d27e63bf8c.tar.gz
Add 'compress'.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py10
-rw-r--r--numpy/ma/tests/test_core.py6
2 files changed, 14 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index bb7e339fc..2ce1a2bd5 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -26,8 +26,8 @@ __all__ = ['MAError', 'MaskType', 'MaskedArray',
'arctanh', 'argmax', 'argmin', 'argsort', 'around',
'array', 'asarray','asanyarray',
'bitwise_and', 'bitwise_or', 'bitwise_xor',
- 'ceil', 'choose', 'compressed', 'concatenate', 'conjugate',
- 'cos', 'cosh', 'count',
+ 'ceil', 'choose', 'compress', 'compressed', 'concatenate',
+ 'conjugate', 'cos', 'cosh', 'count',
'default_fill_value', 'diagonal', 'divide', 'dump', 'dumps',
'empty', 'empty_like', 'equal', 'exp',
'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid',
@@ -3099,6 +3099,12 @@ def choose (indices, t, out=None, mode='raise'):
m = make_mask(mask_or(m, getmask(indices)), copy=0, shrink=True)
return masked_array(d, mask=m)
+def compress(a, condition):
+ """Return a where condition is True.
+
+ """
+ return a[condition]
+
def round_(a, decimals=0, out=None):
"""Return a copy of a, rounded to 'decimals' places.
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 2dc758513..9b4f19284 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1388,6 +1388,12 @@ class TestArrayMethods(NumpyTestCase):
putmask(mxx, mask, values)
assert_equal(mxx, [1,2,30,4,5,60])
+ def test_compress(self):
+ a = array([1,2,3],mask=[True,False,False])
+ b = compress(a,a<3)
+ assert_equal(b,[1,2])
+ assert_equal(b.mask,[True,False])
+
#..............................................................................