diff options
Diffstat (limited to 'numpy/core/src')
-rw-r--r-- | numpy/core/src/arrayobject.c | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 582421ee2..9de78f996 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2720,6 +2720,25 @@ array_subscript(PyArrayObject *self, PyObject *op) return NULL; return add_new_axes_0d(self, nd); } + /* Allow Boolean mask selection also */ + if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) || + (PyArray_Check(op) && (PyArray_DIMS(op)==0) && + PyArray_ISBOOL(op))) { + if (PyObject_IsTrue(op)) { + Py_INCREF(self); + return (PyObject *)self; + } + else { + oned = 0; + Py_INCREF(self->descr); + return PyArray_NewFromDescr(self->ob_type, + self->descr, + 1, &oned, + NULL, NULL, + NPY_DEFAULT, + NULL); + } + } PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed."); return NULL; @@ -2886,10 +2905,27 @@ array_ass_sub(PyArrayObject *self, PyObject *index, PyObject *op) } if (self->nd == 0) { + /* Several different exceptions to the 0-d no-indexing rule + + 1) ellipses + 2) empty tuple + 3) Using newaxis (None) + 4) Boolean mask indexing + */ if (index == Py_Ellipsis || index == Py_None || \ (PyTuple_Check(index) && (0 == PyTuple_GET_SIZE(index) || \ count_new_axes_0d(index) > 0))) return self->descr->f->setitem(op, self->data, self); + if (PyBool_Check(index) || PyArray_IsScalar(index, Bool) || + (PyArray_Check(index) && (PyArray_DIMS(index)==0) && + PyArray_ISBOOL(index))) { + if (PyObject_IsTrue(index)) { + return self->descr->f->setitem(op, self->data, self); + } + else { /* don't do anything */ + return 0; + } + } PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed."); return -1; @@ -3008,6 +3044,10 @@ array_subscript_nice(PyArrayObject *self, PyObject *op) Bool noellipses = TRUE; if (op == Py_Ellipsis) noellipses = FALSE; + else if (PyBool_Check(op) || PyArray_IsScalar(op, Bool) || + (PyArray_Check(op) && (PyArray_DIMS(op)==0) && + PyArray_ISBOOL(op))) + noellipses = FALSE; else if (PySequence_Check(op)) { int n, i; PyObject *temp; |