summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2010-12-17 18:54:25 -0800
committerMark Wiebe <mwwiebe@gmail.com>2011-01-09 01:54:59 -0800
commit3f5de70b2d7841c32547b4a87c16dc6e10658335 (patch)
treed67ae73d046361a7a181e86afed75b0f2234c586
parent6db2223b7c8e6ff0ba338c96a0ac382430930472 (diff)
downloadnumpy-3f5de70b2d7841c32547b4a87c16dc6e10658335.tar.gz
ENH: Add some utility functions for modifying the iterator
-rw-r--r--numpy/core/src/multiarray/new_iterator.c.src55
-rw-r--r--numpy/core/src/multiarray/new_iterator.h5
-rw-r--r--numpy/core/src/multiarray/new_iterator_pywrap.c119
-rw-r--r--numpy/core/tests/test_new_iterator.py29
4 files changed, 170 insertions, 38 deletions
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src
index c4b80b7d7..62c5929ca 100644
--- a/numpy/core/src/multiarray/new_iterator.c.src
+++ b/numpy/core/src/multiarray/new_iterator.c.src
@@ -191,7 +191,7 @@ static void
npyiter_shrink_ndim(NpyIter *iter, npy_intp new_ndim);
static PyArray_Descr *
-npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
+npyiter_get_common_dtype(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
char *op_itflags, PyArray_Descr **op_dtype,
int only_inputs);
static int
@@ -522,7 +522,7 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
PyArray_Descr *dtype;
int only_inputs = !(flags&NPY_ITER_COMMON_DATA_TYPE);
- dtype = npyiter_get_common_type(niter, op, op_ndim,
+ dtype = npyiter_get_common_dtype(niter, op, op_ndim,
op_itflags, op_dtype,
only_inputs);
if (dtype == NULL) {
@@ -715,6 +715,55 @@ int NpyIter_Deallocate(NpyIter *iter)
return NPY_SUCCEED;
}
+/* Removes coords support from the iterator */
+int NpyIter_RemoveCoords(NpyIter *iter)
+{
+ npy_uint32 itflags;
+
+ /* Make sure the iterator is reset */
+ NpyIter_Reset(iter);
+
+ itflags = NIT_ITFLAGS(iter);
+ if (itflags&NPY_ITFLAG_HASCOORDS) {
+ NIT_ITFLAGS(iter) = itflags & ~NPY_ITFLAG_HASCOORDS;
+ npyiter_coalesce_axes(iter);
+ }
+
+ return NPY_SUCCEED;
+}
+
+/* Removes the inner loop handling (adds NPY_ITER_NO_INNER_ITERATION) */
+int NpyIter_RemoveInnerLoop(NpyIter *iter)
+{
+ npy_uint32 itflags = NIT_ITFLAGS(iter);;
+ npy_intp ndim = NIT_NDIM(iter);
+ npy_intp niter = NIT_NITER(iter);
+
+ char *axisdata;
+
+ /* Check conditions under which this can be done */
+ if (itflags&(NPY_ITFLAG_HASINDEX|NPY_ITFLAG_HASCOORDS)) {
+ PyErr_SetString(PyExc_ValueError,
+ "Iterator flag NO_INNER_ITERATION cannot be used "
+ "if coords or an index is being tracked");
+ return NPY_FAIL;
+ }
+ /* Set the flag */
+ if (!(itflags&NPY_ITFLAG_NOINNER)) {
+ itflags |= NPY_ITFLAG_NOINNER;
+ NIT_ITFLAGS(iter) = itflags;
+
+ /* Adjust ITERSIZE */
+ axisdata = NIT_AXISDATA(iter);
+ NIT_ITERSIZE(iter) /= NAD_SHAPE(axisdata);
+ }
+
+ /* Reset the iterator */
+ NpyIter_Reset(iter);
+
+ return NPY_SUCCEED;
+}
+
/* Resets the iterator to its initial state */
void NpyIter_Reset(NpyIter *iter)
{
@@ -2743,7 +2792,7 @@ npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype,
* are not read from out of the calculation.
*/
static PyArray_Descr *
-npyiter_get_common_type(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
+npyiter_get_common_dtype(npy_intp niter, PyArrayObject **op, npy_intp *op_ndim,
char *op_itflags, PyArray_Descr **op_dtype,
int only_inputs)
{
diff --git a/numpy/core/src/multiarray/new_iterator.h b/numpy/core/src/multiarray/new_iterator.h
index 969f49d8f..63351471a 100644
--- a/numpy/core/src/multiarray/new_iterator.h
+++ b/numpy/core/src/multiarray/new_iterator.h
@@ -24,6 +24,11 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
npy_uint32 *op_flags, PyArray_Descr **op_request_dtypes,
npy_intp oa_ndim, npy_intp **op_axes);
+/* Removes coords support from an iterator */
+int NpyIter_RemoveCoords(NpyIter *iter);
+/* Removes the inner loop handling (adds NPY_ITER_NO_INNER_ITERATION) */
+int NpyIter_RemoveInnerLoop(NpyIter *iter);
+
/* Deallocate an iterator */
int NpyIter_Deallocate(NpyIter* iter);
diff --git a/numpy/core/src/multiarray/new_iterator_pywrap.c b/numpy/core/src/multiarray/new_iterator_pywrap.c
index ea0ebfd0f..3ec9008fb 100644
--- a/numpy/core/src/multiarray/new_iterator_pywrap.c
+++ b/numpy/core/src/multiarray/new_iterator_pywrap.c
@@ -30,6 +30,38 @@ struct NewNpyArrayIterObject_tag {
char writeflags[NPY_MAXARGS];
};
+void npyiter_cache_values(NewNpyArrayIterObject *self)
+{
+ NpyIter *iter = self->iter;
+
+ /* iternext and getcoords functions */
+ self->iternext = NpyIter_GetIterNext(iter);
+ if (NpyIter_HasCoords(iter)) {
+ self->getcoords = NpyIter_GetGetCoords(iter);
+ }
+ else {
+ self->getcoords = NULL;
+ }
+
+ /* Internal data pointers */
+ self->dataptrs = NpyIter_GetDataPtrArray(iter);
+ self->dtypes = NpyIter_GetDescrArray(iter);
+ self->objects = NpyIter_GetObjectArray(iter);
+
+ if (NpyIter_HasInnerLoop(iter)) {
+ self->innerstrides = NULL;
+ self->innerloopsizeptr = NULL;
+ }
+ else {
+ self->innerstrides = NpyIter_GetInnerStrideArray(iter);
+ self->innerloopsizeptr = NpyIter_GetInnerLoopSizePtr(iter);
+ }
+
+ /* The read/write settings */
+ NpyIter_GetReadFlags(iter, self->readflags);
+ NpyIter_GetWriteFlags(iter, self->writeflags);
+}
+
static PyObject *
npyiter_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds)
{
@@ -506,30 +538,7 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
}
/* Cache some values for the member functions to use */
- self->iternext = NpyIter_GetIterNext(self->iter);
- if (NpyIter_HasCoords(self->iter)) {
- self->getcoords = NpyIter_GetGetCoords(self->iter);
- }
- else {
- self->getcoords = NULL;
- }
-
- self->dataptrs = NpyIter_GetDataPtrArray(self->iter);
- self->dtypes = NpyIter_GetDescrArray(self->iter);
- self->objects = NpyIter_GetObjectArray(self->iter);
-
- if (NpyIter_HasInnerLoop(self->iter)) {
- self->innerstrides = NULL;
- self->innerloopsizeptr = NULL;
- }
- else {
- self->innerstrides = NpyIter_GetInnerStrideArray(self->iter);
- self->innerloopsizeptr = NpyIter_GetInnerLoopSizePtr(self->iter);
- }
-
- /* Get the read/write settings */
- NpyIter_GetReadFlags(self->iter, self->readflags);
- NpyIter_GetWriteFlags(self->iter, self->writeflags);
+ npyiter_cache_values(self);
/* Release the references we got to the ops and dtypes */
for (iiter = 0; iiter < niter; ++iiter) {
@@ -560,11 +569,15 @@ npyiter_dealloc(NewNpyArrayIterObject *self)
static PyObject *
npyiter_reset(NewNpyArrayIterObject *self)
{
- if (self->iter) {
- NpyIter_Reset(self->iter);
- self->finished = 0;
+ if (self->iter == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "Iterator was not constructed correctly");
+ return NULL;
}
+ NpyIter_Reset(self->iter);
+ self->finished = 0;
+
Py_RETURN_NONE;
}
@@ -581,6 +594,42 @@ npyiter_iternext(NewNpyArrayIterObject *self)
}
static PyObject *
+npyiter_remove_coords(NewNpyArrayIterObject *self)
+{
+ if (self->iter == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "Iterator was not constructed correctly");
+ return NULL;
+ }
+
+ NpyIter_RemoveCoords(self->iter);
+ /* RemoveCoords invalidates cached values */
+ npyiter_cache_values(self);
+ /* RemoveCoords also resets the iterator */
+ self->finished = 0;
+
+ Py_RETURN_NONE;
+}
+
+static PyObject *
+npyiter_remove_inner_loop(NewNpyArrayIterObject *self)
+{
+ if (self->iter == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "Iterator was not constructed correctly");
+ return NULL;
+ }
+
+ NpyIter_RemoveInnerLoop(self->iter);
+ /* RemoveInnerLoop invalidates cached values */
+ npyiter_cache_values(self);
+ /* RemoveInnerLoop also resets the iterator */
+ self->finished = 0;
+
+ Py_RETURN_NONE;
+}
+
+static PyObject *
npyiter_debug_print(NewNpyArrayIterObject *self)
{
if (self->iter != NULL) {
@@ -1081,10 +1130,6 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v)
return -1;
}
niter = NpyIter_GetNIter(self->iter);
- /* Python negative indexing */
- if (i < 0) {
- i += niter;
- }
if (i < 0 || i >= niter) {
PyErr_Format(PyExc_IndexError,
"Iterator operand index %d is out of bounds", (int)i);
@@ -1100,16 +1145,17 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v)
dtype = self->dtypes[i];
object = self->objects[i];
- /*
- * TODO: When buffering is enabled for an operand, the object won't
- * correspond to the data, so that will have to be accounted for
- */
if (NpyIter_HasInnerLoop(self->iter)) {
+ /*
+ * TODO: When buffering is enabled for an operand, the object won't
+ * correspond to the data, so that will have to be accounted for
+ */
return dtype->f->setitem(v, dataptr, object);
} else {
PyArrayObject *tmp;
int ret;
Py_INCREF(dtype);
+ /* TODO - there should be a better way than this... */
tmp = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type, dtype,
1, self->innerloopsizeptr,
&self->innerstrides[i], dataptr,
@@ -1127,6 +1173,9 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v)
static PyMethodDef npyiter_methods[] = {
{"reset", (PyCFunction)npyiter_reset, METH_NOARGS, NULL},
{"iternext", (PyCFunction)npyiter_iternext, METH_NOARGS, NULL},
+ {"remove_coords", (PyCFunction)npyiter_remove_coords, METH_NOARGS, NULL},
+ {"remove_inner_loop", (PyCFunction)npyiter_remove_inner_loop,
+ METH_NOARGS, NULL},
{"debug_print", (PyCFunction)npyiter_debug_print, METH_NOARGS, NULL},
{NULL, NULL, 0, NULL},
};
diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py
index f82127112..9002d231e 100644
--- a/numpy/core/tests/test_new_iterator.py
+++ b/numpy/core/tests/test_new_iterator.py
@@ -810,8 +810,11 @@ def test_iter_common_data_type():
[['readonly','copy','same_kind_casts']]*2)
assert_equal(i.dtypes[0], np.dtype('f4'));
assert_equal(i.dtypes[1], np.dtype('f4'));
+ # TODO
# This case is weird - the scalar/array combination produces a cast
# classified as unsafe. I think this NumPy rule needs to be revisited.
+ # For example, when the scalar is writeable, a negative value could
+ # be written during iteration, invalidating the scalar kind assumed!
i = newiter([array([3],dtype='u4'),array(0,dtype='i4')],
['common_data_type'],
[['readonly','copy','unsafe_casts']]*2)
@@ -1062,5 +1065,31 @@ def test_iter_allocate_output_errors():
op_dtypes=[None,np.dtype('f4')],
op_axes=[None,[0,2,1,0]])
+def test_iter_remove_coords_inner_loop():
+ # Check that removing coords support works
+
+ a = arange(24).reshape(2,3,4)
+
+ i = newiter(a,['coords'])
+ assert_equal(i.ndim, 3)
+ assert_equal(i.shape, (2,3,4))
+ assert_equal(i.itviews[0].shape, (2,3,4))
+
+ # Removing coords causes all dimensions to coalesce
+ before = [x for x in i]
+ i.remove_coords()
+ after = [x for x in i]
+
+ assert_equal(before, after)
+ assert_equal(i.ndim, 1)
+ assert_raises(ValueError, lambda i:i.shape, i)
+ assert_equal(i.itviews[0].shape, (24,))
+
+ # Removing the inner loop means there's just one iteration
+ assert_equal(i.itersize, 24)
+ i.remove_inner_loop()
+ assert_equal(i.itersize, 1)
+ assert_equal(i.value, arange(24))
+
if __name__ == "__main__":
run_module_suite()