diff options
author | Travis E. Oliphant <teoliphant@gmail.com> | 2012-06-21 04:56:49 -0500 |
---|---|---|
committer | Travis E. Oliphant <teoliphant@gmail.com> | 2012-06-21 04:56:49 -0500 |
commit | db1701baf57965e8533293426a0c3b403e81f39a (patch) | |
tree | 70d09f276cba5725eb8d46e6b703065a0ff6d906 /numpy | |
parent | 134174c9265dd87ea802c89cac7a89478e3184f4 (diff) | |
download | numpy-db1701baf57965e8533293426a0c3b403e81f39a.tar.gz |
BUG: Fix boolean indexing to previous behavior by adding an additional check before using the new code path. Add tests.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/mapping.c | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 35 |
2 files changed, 40 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c index 769fb653e..7abb3ff7f 100644 --- a/numpy/core/src/multiarray/mapping.c +++ b/numpy/core/src/multiarray/mapping.c @@ -1023,8 +1023,10 @@ array_subscript(PyArrayObject *self, PyObject *op) } /* Boolean indexing special case */ + /* The SIZE check might be overly cautious */ if (PyArray_Check(op) && (PyArray_TYPE((PyArrayObject *)op) == NPY_BOOL) - && (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)op))) { + && (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)op)) + && (PyArray_SIZE((PyArrayObject *)op) == PyArray_SIZE(self))) { return (PyObject *)array_boolean_subscript(self, (PyArrayObject *)op, NPY_CORDER); } @@ -1269,7 +1271,8 @@ array_ass_sub(PyArrayObject *self, PyObject *ind, PyObject *op) /* Boolean indexing special case */ if (PyArray_Check(ind) && (PyArray_TYPE((PyArrayObject *)ind) == NPY_BOOL) && - (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)ind))) { + (PyArray_NDIM(self) == PyArray_NDIM((PyArrayObject *)ind)) && + (PyArray_SIZE(self) == PyArray_SIZE((PyArrayObject *)ind))) { int retcode; PyArrayObject *op_arr; PyArray_Descr *dtype = NULL; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index bab83031b..82ead8958 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1058,6 +1058,41 @@ class TestFancyIndexing(TestCase): x[:,:,(0,)] = 2.0 assert_array_equal(x, array([[[2.0]]])) + def test_mask(self): + x = array([1,2,3,4]) + m = array([0,1],bool) + assert_array_equal(x[m], array([2])) + + def test_mask2(self): + x = array([[1,2,3,4],[5,6,7,8]]) + m = array([0,1],bool) + m2 = array([[0,1],[1,0]], bool) + m3 = array([[0,1]], bool) + assert_array_equal(x[m], array([[5,6,7,8]])) + assert_array_equal(x[m2], array([2,5])) + assert_array_equal(x[m3], array([2])) + + def test_assign_mask(self): + x = array([1,2,3,4]) + m = array([0,1],bool) + x[m] = 5 + assert_array_equal(x, array([1,5,3,4])) + + def test_assign_mask(self): + xorig = array([[1,2,3,4],[5,6,7,8]]) + m = array([0,1],bool) + m2 = array([[0,1],[1,0]],bool) + m3 = array([[0,1]], bool) + x = xorig.copy() + x[m] = 10 + assert_array_equal(x, array([[1,2,3,4],[10,10,10,10]])) + x = xorig.copy() + x[m2] = 10 + assert_array_equal(x, array([[1,10,3,4],[10,6,7,8]])) + x = xorig.copy() + x[m3] = 10 + assert_array_equal(x, array([[1,10,3,4],[5,6,7,8]])) + class TestStringCompare(TestCase): def test_string(self): |