diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2010-12-17 18:54:25 -0800 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-01-09 01:54:59 -0800 |
commit | 3f5de70b2d7841c32547b4a87c16dc6e10658335 (patch) | |
tree | d67ae73d046361a7a181e86afed75b0f2234c586 | |
parent | 6db2223b7c8e6ff0ba338c96a0ac382430930472 (diff) | |
download | numpy-3f5de70b2d7841c32547b4a87c16dc6e10658335.tar.gz |
ENH: Add some utility functions for modifying the iterator
-rw-r--r-- | numpy/core/src/multiarray/new_iterator.c.src | 55 | ||||
-rw-r--r-- | numpy/core/src/multiarray/new_iterator.h | 5 | ||||
-rw-r--r-- | numpy/core/src/multiarray/new_iterator_pywrap.c | 119 | ||||
-rw-r--r-- | numpy/core/tests/test_new_iterator.py | 29 |
4 files changed, 170 insertions, 38 deletions
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src index c4b80b7d7..62c5929ca 100644 --- a/numpy/core/src/multiarray/new_iterator.c.src +++ b/numpy/core/src/multiarray/new_iterator.c.src @@ -191,7 +191,7 @@ static void npyiter_shrink_ndim(NpyIter *iter, npy_intp new_ndim); static PyArray_Descr * -npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, +npyiter_get_common_dtype(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, char *op_itflags, PyArray_Descr **op_dtype, int only_inputs); static int @@ -522,7 +522,7 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags, PyArray_Descr *dtype; int only_inputs = !(flags&NPY_ITER_COMMON_DATA_TYPE); - dtype = npyiter_get_common_type(niter, op, op_ndim, + dtype = npyiter_get_common_dtype(niter, op, op_ndim, op_itflags, op_dtype, only_inputs); if (dtype == NULL) { @@ -715,6 +715,55 @@ int NpyIter_Deallocate(NpyIter *iter) return NPY_SUCCEED; } +/* Removes coords support from the iterator */ +int NpyIter_RemoveCoords(NpyIter *iter) +{ + npy_uint32 itflags; + + /* Make sure the iterator is reset */ + NpyIter_Reset(iter); + + itflags = NIT_ITFLAGS(iter); + if (itflags&NPY_ITFLAG_HASCOORDS) { + NIT_ITFLAGS(iter) = itflags & ~NPY_ITFLAG_HASCOORDS; + npyiter_coalesce_axes(iter); + } + + return NPY_SUCCEED; +} + +/* Removes the inner loop handling (adds NPY_ITER_NO_INNER_ITERATION) */ +int NpyIter_RemoveInnerLoop(NpyIter *iter) +{ + npy_uint32 itflags = NIT_ITFLAGS(iter);; + npy_intp ndim = NIT_NDIM(iter); + npy_intp niter = NIT_NITER(iter); + + char *axisdata; + + /* Check conditions under which this can be done */ + if (itflags&(NPY_ITFLAG_HASINDEX|NPY_ITFLAG_HASCOORDS)) { + PyErr_SetString(PyExc_ValueError, + "Iterator flag NO_INNER_ITERATION cannot be used " + "if coords or an index is being tracked"); + return NPY_FAIL; + } + /* Set the flag */ + if (!(itflags&NPY_ITFLAG_NOINNER)) { + itflags |= NPY_ITFLAG_NOINNER; + NIT_ITFLAGS(iter) = itflags; + + /* Adjust ITERSIZE */ + axisdata = NIT_AXISDATA(iter); + NIT_ITERSIZE(iter) /= NAD_SHAPE(axisdata); + } + + /* Reset the iterator */ + NpyIter_Reset(iter); + + return NPY_SUCCEED; +} + /* Resets the iterator to its initial state */ void NpyIter_Reset(NpyIter *iter) { @@ -2743,7 +2792,7 @@ npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype, * are not read from out of the calculation. */ static PyArray_Descr * -npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, +npyiter_get_common_dtype(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim, char *op_itflags, PyArray_Descr **op_dtype, int only_inputs) { diff --git a/numpy/core/src/multiarray/new_iterator.h b/numpy/core/src/multiarray/new_iterator.h index 969f49d8f..63351471a 100644 --- a/numpy/core/src/multiarray/new_iterator.h +++ b/numpy/core/src/multiarray/new_iterator.h @@ -24,6 +24,11 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags, npy_uint32 *op_flags, PyArray_Descr **op_request_dtypes, npy_intp oa_ndim, npy_intp **op_axes); +/* Removes coords support from an iterator */ +int NpyIter_RemoveCoords(NpyIter *iter); +/* Removes the inner loop handling (adds NPY_ITER_NO_INNER_ITERATION) */ +int NpyIter_RemoveInnerLoop(NpyIter *iter); + /* Deallocate an iterator */ int NpyIter_Deallocate(NpyIter* iter); diff --git a/numpy/core/src/multiarray/new_iterator_pywrap.c b/numpy/core/src/multiarray/new_iterator_pywrap.c index ea0ebfd0f..3ec9008fb 100644 --- a/numpy/core/src/multiarray/new_iterator_pywrap.c +++ b/numpy/core/src/multiarray/new_iterator_pywrap.c @@ -30,6 +30,38 @@ struct NewNpyArrayIterObject_tag { char writeflags[NPY_MAXARGS]; }; +void npyiter_cache_values(NewNpyArrayIterObject *self) +{ + NpyIter *iter = self->iter; + + /* iternext and getcoords functions */ + self->iternext = NpyIter_GetIterNext(iter); + if (NpyIter_HasCoords(iter)) { + self->getcoords = NpyIter_GetGetCoords(iter); + } + else { + self->getcoords = NULL; + } + + /* Internal data pointers */ + self->dataptrs = NpyIter_GetDataPtrArray(iter); + self->dtypes = NpyIter_GetDescrArray(iter); + self->objects = NpyIter_GetObjectArray(iter); + + if (NpyIter_HasInnerLoop(iter)) { + self->innerstrides = NULL; + self->innerloopsizeptr = NULL; + } + else { + self->innerstrides = NpyIter_GetInnerStrideArray(iter); + self->innerloopsizeptr = NpyIter_GetInnerLoopSizePtr(iter); + } + + /* The read/write settings */ + NpyIter_GetReadFlags(iter, self->readflags); + NpyIter_GetWriteFlags(iter, self->writeflags); +} + static PyObject * npyiter_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds) { @@ -506,30 +538,7 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds) } /* Cache some values for the member functions to use */ - self->iternext = NpyIter_GetIterNext(self->iter); - if (NpyIter_HasCoords(self->iter)) { - self->getcoords = NpyIter_GetGetCoords(self->iter); - } - else { - self->getcoords = NULL; - } - - self->dataptrs = NpyIter_GetDataPtrArray(self->iter); - self->dtypes = NpyIter_GetDescrArray(self->iter); - self->objects = NpyIter_GetObjectArray(self->iter); - - if (NpyIter_HasInnerLoop(self->iter)) { - self->innerstrides = NULL; - self->innerloopsizeptr = NULL; - } - else { - self->innerstrides = NpyIter_GetInnerStrideArray(self->iter); - self->innerloopsizeptr = NpyIter_GetInnerLoopSizePtr(self->iter); - } - - /* Get the read/write settings */ - NpyIter_GetReadFlags(self->iter, self->readflags); - NpyIter_GetWriteFlags(self->iter, self->writeflags); + npyiter_cache_values(self); /* Release the references we got to the ops and dtypes */ for (iiter = 0; iiter < niter; ++iiter) { @@ -560,11 +569,15 @@ npyiter_dealloc(NewNpyArrayIterObject *self) static PyObject * npyiter_reset(NewNpyArrayIterObject *self) { - if (self->iter) { - NpyIter_Reset(self->iter); - self->finished = 0; + if (self->iter == NULL) { + PyErr_SetString(PyExc_ValueError, + "Iterator was not constructed correctly"); + return NULL; } + NpyIter_Reset(self->iter); + self->finished = 0; + Py_RETURN_NONE; } @@ -581,6 +594,42 @@ npyiter_iternext(NewNpyArrayIterObject *self) } static PyObject * +npyiter_remove_coords(NewNpyArrayIterObject *self) +{ + if (self->iter == NULL) { + PyErr_SetString(PyExc_ValueError, + "Iterator was not constructed correctly"); + return NULL; + } + + NpyIter_RemoveCoords(self->iter); + /* RemoveCoords invalidates cached values */ + npyiter_cache_values(self); + /* RemoveCoords also resets the iterator */ + self->finished = 0; + + Py_RETURN_NONE; +} + +static PyObject * +npyiter_remove_inner_loop(NewNpyArrayIterObject *self) +{ + if (self->iter == NULL) { + PyErr_SetString(PyExc_ValueError, + "Iterator was not constructed correctly"); + return NULL; + } + + NpyIter_RemoveInnerLoop(self->iter); + /* RemoveInnerLoop invalidates cached values */ + npyiter_cache_values(self); + /* RemoveInnerLoop also resets the iterator */ + self->finished = 0; + + Py_RETURN_NONE; +} + +static PyObject * npyiter_debug_print(NewNpyArrayIterObject *self) { if (self->iter != NULL) { @@ -1081,10 +1130,6 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v) return -1; } niter = NpyIter_GetNIter(self->iter); - /* Python negative indexing */ - if (i < 0) { - i += niter; - } if (i < 0 || i >= niter) { PyErr_Format(PyExc_IndexError, "Iterator operand index %d is out of bounds", (int)i); @@ -1100,16 +1145,17 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v) dtype = self->dtypes[i]; object = self->objects[i]; - /* - * TODO: When buffering is enabled for an operand, the object won't - * correspond to the data, so that will have to be accounted for - */ if (NpyIter_HasInnerLoop(self->iter)) { + /* + * TODO: When buffering is enabled for an operand, the object won't + * correspond to the data, so that will have to be accounted for + */ return dtype->f->setitem(v, dataptr, object); } else { PyArrayObject *tmp; int ret; Py_INCREF(dtype); + /* TODO - there should be a better way than this... */ tmp = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type, dtype, 1, self->innerloopsizeptr, &self->innerstrides[i], dataptr, @@ -1127,6 +1173,9 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v) static PyMethodDef npyiter_methods[] = { {"reset", (PyCFunction)npyiter_reset, METH_NOARGS, NULL}, {"iternext", (PyCFunction)npyiter_iternext, METH_NOARGS, NULL}, + {"remove_coords", (PyCFunction)npyiter_remove_coords, METH_NOARGS, NULL}, + {"remove_inner_loop", (PyCFunction)npyiter_remove_inner_loop, + METH_NOARGS, NULL}, {"debug_print", (PyCFunction)npyiter_debug_print, METH_NOARGS, NULL}, {NULL, NULL, 0, NULL}, }; diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py index f82127112..9002d231e 100644 --- a/numpy/core/tests/test_new_iterator.py +++ b/numpy/core/tests/test_new_iterator.py @@ -810,8 +810,11 @@ def test_iter_common_data_type(): [['readonly','copy','same_kind_casts']]*2) assert_equal(i.dtypes[0], np.dtype('f4')); assert_equal(i.dtypes[1], np.dtype('f4')); + # TODO # This case is weird - the scalar/array combination produces a cast # classified as unsafe. I think this NumPy rule needs to be revisited. + # For example, when the scalar is writeable, a negative value could + # be written during iteration, invalidating the scalar kind assumed! i = newiter([array([3],dtype='u4'),array(0,dtype='i4')], ['common_data_type'], [['readonly','copy','unsafe_casts']]*2) @@ -1062,5 +1065,31 @@ def test_iter_allocate_output_errors(): op_dtypes=[None,np.dtype('f4')], op_axes=[None,[0,2,1,0]]) +def test_iter_remove_coords_inner_loop(): + # Check that removing coords support works + + a = arange(24).reshape(2,3,4) + + i = newiter(a,['coords']) + assert_equal(i.ndim, 3) + assert_equal(i.shape, (2,3,4)) + assert_equal(i.itviews[0].shape, (2,3,4)) + + # Removing coords causes all dimensions to coalesce + before = [x for x in i] + i.remove_coords() + after = [x for x in i] + + assert_equal(before, after) + assert_equal(i.ndim, 1) + assert_raises(ValueError, lambda i:i.shape, i) + assert_equal(i.itviews[0].shape, (24,)) + + # Removing the inner loop means there's just one iteration + assert_equal(i.itersize, 24) + i.remove_inner_loop() + assert_equal(i.itersize, 1) + assert_equal(i.value, arange(24)) + if __name__ == "__main__": run_module_suite() |