summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/arrayobject.c51
-rw-r--r--numpy/core/src/scalartypes.inc.src1
-rw-r--r--numpy/core/tests/test_numeric.py21
-rw-r--r--numpy/lib/index_tricks.py2
4 files changed, 54 insertions, 21 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 03fd4d445..3b99cf550 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2175,7 +2175,7 @@ fancy_indexing_check(PyObject *args)
for (i=0; i<n; i++) {
obj = PyTuple_GET_ITEM(args,i);
if (PyArray_Check(obj)) {
- if (PyArray_ISINTEGER(obj))
+ if (PyArray_ISINTEGER(obj) || PyArray_ISBOOL(obj))
retval = SOBJ_ISFANCY;
else {
retval = SOBJ_BADARRAY;
@@ -5429,7 +5429,6 @@ static PyObject *
array_struct_get(PyArrayObject *self)
{
PyArrayInterface *inter;
- PyObject *tup;
inter = (PyArrayInterface *)_pya_malloc(sizeof(PyArrayInterface));
inter->dummy = 2;
@@ -6844,7 +6843,6 @@ static PyObject *
PyArray_FromStructInterface(PyObject *input)
{
PyArray_Descr *thetype=NULL;
- PyObject *obj;
char buf[40];
PyArrayInterface *inter;
PyObject *attr, *r;
@@ -8256,6 +8254,8 @@ static PyTypeObject PyArrayIter_Type = {
* and so that indexing can be set up ahead of time *
*/
+
+static int _nonzero_indices(PyObject *myBool, PyArrayIterObject **iters);
/* convert an indexing object to an INTP indexing array iterator
if possible -- otherwise, it is a Slice or Ellipsis object
and has to be interpreted on bind to a particular
@@ -8268,7 +8268,10 @@ _convert_obj(PyObject *obj, PyArrayIterObject **iter)
PyObject *arr;
if (PySlice_Check(obj) || (obj == Py_Ellipsis))
- *iter = NULL;
+ return 0;
+ else if (PyArray_Check(obj) && PyArray_ISBOOL(obj)) {
+ return _nonzero_indices(obj, iter);
+ }
else {
indtype = PyArray_DescrFromType(PyArray_INTP);
arr = PyArray_FromAny(obj, indtype, 0, 0, FORCECAST, NULL);
@@ -8277,7 +8280,7 @@ _convert_obj(PyObject *obj, PyArrayIterObject **iter)
Py_DECREF(arr);
if (*iter == NULL) return -1;
}
- return 0;
+ return 1;
}
/* Adjust dimensionality and strides for index object iterators
@@ -8698,7 +8701,7 @@ PyArray_MapIterNew(PyObject *indexobj, int oned, int fancy)
if (fancy == SOBJ_BADARRAY) {
PyErr_SetString(PyExc_IndexError, \
"arrays used as indices must be of " \
- "integer type");
+ "integer (or boolean) type");
return NULL;
}
if (fancy == SOBJ_TOOMANY) {
@@ -8773,32 +8776,52 @@ PyArray_MapIterNew(PyObject *indexobj, int oned, int fancy)
}
else { /* must be a tuple */
PyObject *obj;
- PyArrayIterObject *iter;
+ PyArrayIterObject **iterp;
PyObject *new;
+ int numiters, j, n2;
/* Make a copy of the tuple -- we will be replacing
index objects with 0's */
n = PyTuple_GET_SIZE(indexobj);
- new = PyTuple_New(n);
+ n2 = n;
+ new = PyTuple_New(n2);
if (new == NULL) goto fail;
started = 0;
nonindex = 0;
+ j = 0;
for (i=0; i<n; i++) {
obj = PyTuple_GET_ITEM(indexobj,i);
- if (_convert_obj(obj, &iter) < 0) {
+ iterp = mit->iters + mit->numiter;
+ if ((numiters=_convert_obj(obj, iterp)) < 0) {
Py_DECREF(new);
goto fail;
}
- if (iter!= NULL) {
+ if (numiters > 0) {
started = 1;
if (nonindex) mit->consec = 0;
- mit->iters[(mit->numiter)++] = iter;
- PyTuple_SET_ITEM(new,i,
- PyInt_FromLong(0));
+ mit->numiter += numiters;
+ if (numiters == 1) {
+ PyTuple_SET_ITEM(new,j++,
+ PyInt_FromLong(0));
+ }
+ else { /* we need to grow the
+ new indexing object and fill
+ it with 0s for each of the iterators
+ produced */
+ int k;
+ n2 += numiters - 1;
+ if (_PyTuple_Resize(&new, n2) < 0)
+ goto fail;
+ for (k=0;k<numiters;k++) {
+ PyTuple_SET_ITEM \
+ (new,j++,
+ PyInt_FromLong(0));
+ }
+ }
}
else {
if (started) nonindex = 1;
Py_INCREF(obj);
- PyTuple_SET_ITEM(new,i,obj);
+ PyTuple_SET_ITEM(new,j++,obj);
}
}
Py_DECREF(mit->indexobj);
diff --git a/numpy/core/src/scalartypes.inc.src b/numpy/core/src/scalartypes.inc.src
index 03ab93da3..53cd6e883 100644
--- a/numpy/core/src/scalartypes.inc.src
+++ b/numpy/core/src/scalartypes.inc.src
@@ -726,7 +726,6 @@ gentype_struct_get(PyObject *self)
{
PyArrayObject *arr;
PyArrayInterface *inter;
- PyObject *tup;
arr = (PyArrayObject *)PyArray_FromScalar(self, NULL);
inter = (PyArrayInterface *)_pya_malloc(sizeof(PyArrayInterface));
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 6c3fef8b3..9f6118f19 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1,9 +1,9 @@
from numpy.core import *
-from numpy.random import rand
+from numpy.random import rand, randint
from numpy.testing import *
from numpy.core.multiarray import dot as dot_
-class test_dot(ScipyTestCase):
+class test_dot(NumpyTestCase):
def setUp(self):
self.A = rand(10,8)
self.b1 = rand(8,1)
@@ -104,7 +104,7 @@ class test_dot(ScipyTestCase):
assert_almost_equal(c1, c2, decimal=self.N)
-class test_bool_scalar(ScipyTestCase):
+class test_bool_scalar(NumpyTestCase):
def test_logical(self):
f = False_
t = True_
@@ -137,7 +137,7 @@ class test_bool_scalar(ScipyTestCase):
self.failUnless((f ^ f) is f)
-class test_seterr(ScipyTestCase):
+class test_seterr(NumpyTestCase):
def test_set(self):
err = seterr()
old = seterr(divide='warn')
@@ -161,7 +161,7 @@ class test_seterr(ScipyTestCase):
array([1.]) / array([0.])
-class test_fromiter(ScipyTestCase):
+class test_fromiter(NumpyTestCase):
def makegen(self):
for x in xrange(24):
@@ -195,7 +195,16 @@ class test_fromiter(ScipyTestCase):
self.failUnless(alltrue(a == expected))
self.failUnless(alltrue(a20 == expected[:20]))
-class test_binary_repr(ScipyTestCase):
+class test_index(NumpyTestCase):
+ def test_boolean(self):
+ a = rand(3,5,8)
+ V = rand(5,8)
+ g1 = randint(0,5,size=15)
+ g2 = randint(0,8,size=15)
+ V[g1,g2] = -V[g1,g2]
+ assert (array([a[0][V>0],a[1][V>0],a[2][V>0]]) == a[:,V>0]).all()
+
+class test_binary_repr(NumpyTestCase):
def test_zero(self):
assert_equal(binary_repr(0),'0')
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index c659c7584..6f06bab78 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -75,6 +75,8 @@ def ix_(*args):
new = _nx.array(args[k])
if (new.ndim != 1):
raise ValueError, "Cross index must be 1 dimensional"
+ if issubclass(new.dtype.type, _nx.bool_):
+ new = new.nonzero()[0]
baseshape[k] = len(new)
new.shape = tuple(baseshape)
out.append(new)