diff options
-rw-r--r-- | numpy/core/src/arrayobject.c | 51 | ||||
-rw-r--r-- | numpy/core/src/scalartypes.inc.src | 1 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 21 | ||||
-rw-r--r-- | numpy/lib/index_tricks.py | 2 |
4 files changed, 54 insertions, 21 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 03fd4d445..3b99cf550 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2175,7 +2175,7 @@ fancy_indexing_check(PyObject *args) for (i=0; i<n; i++) { obj = PyTuple_GET_ITEM(args,i); if (PyArray_Check(obj)) { - if (PyArray_ISINTEGER(obj)) + if (PyArray_ISINTEGER(obj) || PyArray_ISBOOL(obj)) retval = SOBJ_ISFANCY; else { retval = SOBJ_BADARRAY; @@ -5429,7 +5429,6 @@ static PyObject * array_struct_get(PyArrayObject *self) { PyArrayInterface *inter; - PyObject *tup; inter = (PyArrayInterface *)_pya_malloc(sizeof(PyArrayInterface)); inter->dummy = 2; @@ -6844,7 +6843,6 @@ static PyObject * PyArray_FromStructInterface(PyObject *input) { PyArray_Descr *thetype=NULL; - PyObject *obj; char buf[40]; PyArrayInterface *inter; PyObject *attr, *r; @@ -8256,6 +8254,8 @@ static PyTypeObject PyArrayIter_Type = { * and so that indexing can be set up ahead of time * */ + +static int _nonzero_indices(PyObject *myBool, PyArrayIterObject **iters); /* convert an indexing object to an INTP indexing array iterator if possible -- otherwise, it is a Slice or Ellipsis object and has to be interpreted on bind to a particular @@ -8268,7 +8268,10 @@ _convert_obj(PyObject *obj, PyArrayIterObject **iter) PyObject *arr; if (PySlice_Check(obj) || (obj == Py_Ellipsis)) - *iter = NULL; + return 0; + else if (PyArray_Check(obj) && PyArray_ISBOOL(obj)) { + return _nonzero_indices(obj, iter); + } else { indtype = PyArray_DescrFromType(PyArray_INTP); arr = PyArray_FromAny(obj, indtype, 0, 0, FORCECAST, NULL); @@ -8277,7 +8280,7 @@ _convert_obj(PyObject *obj, PyArrayIterObject **iter) Py_DECREF(arr); if (*iter == NULL) return -1; } - return 0; + return 1; } /* Adjust dimensionality and strides for index object iterators @@ -8698,7 +8701,7 @@ PyArray_MapIterNew(PyObject *indexobj, int oned, int fancy) if (fancy == SOBJ_BADARRAY) { PyErr_SetString(PyExc_IndexError, \ "arrays used as indices must be of " \ - "integer type"); + "integer (or boolean) type"); return NULL; } if (fancy == SOBJ_TOOMANY) { @@ -8773,32 +8776,52 @@ PyArray_MapIterNew(PyObject *indexobj, int oned, int fancy) } else { /* must be a tuple */ PyObject *obj; - PyArrayIterObject *iter; + PyArrayIterObject **iterp; PyObject *new; + int numiters, j, n2; /* Make a copy of the tuple -- we will be replacing index objects with 0's */ n = PyTuple_GET_SIZE(indexobj); - new = PyTuple_New(n); + n2 = n; + new = PyTuple_New(n2); if (new == NULL) goto fail; started = 0; nonindex = 0; + j = 0; for (i=0; i<n; i++) { obj = PyTuple_GET_ITEM(indexobj,i); - if (_convert_obj(obj, &iter) < 0) { + iterp = mit->iters + mit->numiter; + if ((numiters=_convert_obj(obj, iterp)) < 0) { Py_DECREF(new); goto fail; } - if (iter!= NULL) { + if (numiters > 0) { started = 1; if (nonindex) mit->consec = 0; - mit->iters[(mit->numiter)++] = iter; - PyTuple_SET_ITEM(new,i, - PyInt_FromLong(0)); + mit->numiter += numiters; + if (numiters == 1) { + PyTuple_SET_ITEM(new,j++, + PyInt_FromLong(0)); + } + else { /* we need to grow the + new indexing object and fill + it with 0s for each of the iterators + produced */ + int k; + n2 += numiters - 1; + if (_PyTuple_Resize(&new, n2) < 0) + goto fail; + for (k=0;k<numiters;k++) { + PyTuple_SET_ITEM \ + (new,j++, + PyInt_FromLong(0)); + } + } } else { if (started) nonindex = 1; Py_INCREF(obj); - PyTuple_SET_ITEM(new,i,obj); + PyTuple_SET_ITEM(new,j++,obj); } } Py_DECREF(mit->indexobj); diff --git a/numpy/core/src/scalartypes.inc.src b/numpy/core/src/scalartypes.inc.src index 03ab93da3..53cd6e883 100644 --- a/numpy/core/src/scalartypes.inc.src +++ b/numpy/core/src/scalartypes.inc.src @@ -726,7 +726,6 @@ gentype_struct_get(PyObject *self) { PyArrayObject *arr; PyArrayInterface *inter; - PyObject *tup; arr = (PyArrayObject *)PyArray_FromScalar(self, NULL); inter = (PyArrayInterface *)_pya_malloc(sizeof(PyArrayInterface)); diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 6c3fef8b3..9f6118f19 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1,9 +1,9 @@ from numpy.core import * -from numpy.random import rand +from numpy.random import rand, randint from numpy.testing import * from numpy.core.multiarray import dot as dot_ -class test_dot(ScipyTestCase): +class test_dot(NumpyTestCase): def setUp(self): self.A = rand(10,8) self.b1 = rand(8,1) @@ -104,7 +104,7 @@ class test_dot(ScipyTestCase): assert_almost_equal(c1, c2, decimal=self.N) -class test_bool_scalar(ScipyTestCase): +class test_bool_scalar(NumpyTestCase): def test_logical(self): f = False_ t = True_ @@ -137,7 +137,7 @@ class test_bool_scalar(ScipyTestCase): self.failUnless((f ^ f) is f) -class test_seterr(ScipyTestCase): +class test_seterr(NumpyTestCase): def test_set(self): err = seterr() old = seterr(divide='warn') @@ -161,7 +161,7 @@ class test_seterr(ScipyTestCase): array([1.]) / array([0.]) -class test_fromiter(ScipyTestCase): +class test_fromiter(NumpyTestCase): def makegen(self): for x in xrange(24): @@ -195,7 +195,16 @@ class test_fromiter(ScipyTestCase): self.failUnless(alltrue(a == expected)) self.failUnless(alltrue(a20 == expected[:20])) -class test_binary_repr(ScipyTestCase): +class test_index(NumpyTestCase): + def test_boolean(self): + a = rand(3,5,8) + V = rand(5,8) + g1 = randint(0,5,size=15) + g2 = randint(0,8,size=15) + V[g1,g2] = -V[g1,g2] + assert (array([a[0][V>0],a[1][V>0],a[2][V>0]]) == a[:,V>0]).all() + +class test_binary_repr(NumpyTestCase): def test_zero(self): assert_equal(binary_repr(0),'0') diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c659c7584..6f06bab78 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -75,6 +75,8 @@ def ix_(*args): new = _nx.array(args[k]) if (new.ndim != 1): raise ValueError, "Cross index must be 1 dimensional" + if issubclass(new.dtype.type, _nx.bool_): + new = new.nonzero()[0] baseshape[k] = len(new) new.shape = tuple(baseshape) out.append(new) |