summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorTravis E. Oliphant <teoliphant@gmail.com>2012-06-21 04:56:49 -0500
committerTravis E. Oliphant <teoliphant@gmail.com>2012-06-21 04:56:49 -0500
commitdb1701baf57965e8533293426a0c3b403e81f39a (patch)
tree70d09f276cba5725eb8d46e6b703065a0ff6d906 /numpy
parent134174c9265dd87ea802c89cac7a89478e3184f4 (diff)
downloadnumpy-db1701baf57965e8533293426a0c3b403e81f39a.tar.gz
BUG: Fix boolean indexing to previous behavior by adding an additional check before using the new code path. Add tests.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/mapping.c7
-rw-r--r--numpy/core/tests/test_multiarray.py35
2 files changed, 40 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c
index 769fb653e..7abb3ff7f 100644
--- a/numpy/core/src/multiarray/mapping.c
+++ b/numpy/core/src/multiarray/mapping.c
@@ -1023,8 +1023,10 @@ array_subscript(PyArrayObject *self, PyObject *op)
}
/* Boolean indexing special case */
+ /* The SIZE check might be overly cautious */
if (PyArray_Check(op) && (PyArray_TYPE((PyArrayObject *)op) == NPY_BOOL)
- && (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)op))) {
+ && (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)op))
+ && (PyArray_SIZE((PyArrayObject *)op) == PyArray_SIZE(self))) {
return (PyObject *)array_boolean_subscript(self,
(PyArrayObject *)op, NPY_CORDER);
}
@@ -1269,7 +1271,8 @@ array_ass_sub(PyArrayObject *self, PyObject *ind, PyObject *op)
/* Boolean indexing special case */
if (PyArray_Check(ind) &&
(PyArray_TYPE((PyArrayObject *)ind) == NPY_BOOL) &&
- (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)ind))) {
+ (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)ind)) &&
+ (PyArray_SIZE(self) == PyArray_SIZE((PyArrayObject *)ind))) {
int retcode;
PyArrayObject *op_arr;
PyArray_Descr *dtype = NULL;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index bab83031b..82ead8958 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1058,6 +1058,41 @@ class TestFancyIndexing(TestCase):
x[:,:,(0,)] = 2.0
assert_array_equal(x, array([[[2.0]]]))
+ def test_mask(self):
+ x = array([1,2,3,4])
+ m = array([0,1],bool)
+ assert_array_equal(x[m], array([2]))
+
+ def test_mask2(self):
+ x = array([[1,2,3,4],[5,6,7,8]])
+ m = array([0,1],bool)
+ m2 = array([[0,1],[1,0]], bool)
+ m3 = array([[0,1]], bool)
+ assert_array_equal(x[m], array([[5,6,7,8]]))
+ assert_array_equal(x[m2], array([2,5]))
+ assert_array_equal(x[m3], array([2]))
+
+ def test_assign_mask(self):
+ x = array([1,2,3,4])
+ m = array([0,1],bool)
+ x[m] = 5
+ assert_array_equal(x, array([1,5,3,4]))
+
+ def test_assign_mask(self):
+ xorig = array([[1,2,3,4],[5,6,7,8]])
+ m = array([0,1],bool)
+ m2 = array([[0,1],[1,0]],bool)
+ m3 = array([[0,1]], bool)
+ x = xorig.copy()
+ x[m] = 10
+ assert_array_equal(x, array([[1,2,3,4],[10,10,10,10]]))
+ x = xorig.copy()
+ x[m2] = 10
+ assert_array_equal(x, array([[1,10,3,4],[10,6,7,8]]))
+ x = xorig.copy()
+ x[m3] = 10
+ assert_array_equal(x, array([[1,10,3,4],[5,6,7,8]]))
+
class TestStringCompare(TestCase):
def test_string(self):