summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Cournapeau <cournape@gmail.com>2009-07-21 05:38:34 +0000
committerDavid Cournapeau <cournape@gmail.com>2009-07-21 05:38:34 +0000
commit62b73d811900ef457d6e3eb55def0ab5f6592a47 (patch)
tree39716261a306714b020c6780f96d7c346e2c9db5
parent956ddcff78a3273c45cb4afe6fa1a2bfb5e5fed9 (diff)
downloadnumpy-62b73d811900ef457d6e3eb55def0ab5f6592a47.tar.gz
Simplify neighborhood iterator API.
-rw-r--r--numpy/core/include/numpy/ndarrayobject.h5
-rw-r--r--numpy/core/src/multiarray/iterators.c67
-rw-r--r--numpy/core/src/multiarray/multiarray_tests.c.src12
3 files changed, 37 insertions, 47 deletions
diff --git a/numpy/core/include/numpy/ndarrayobject.h b/numpy/core/include/numpy/ndarrayobject.h
index 5d50ad87f..aa6c471f3 100644
--- a/numpy/core/include/numpy/ndarrayobject.h
+++ b/numpy/core/include/numpy/ndarrayobject.h
@@ -918,11 +918,6 @@ enum {
};
typedef struct {
- int mode;
- PyObject* constant;
-} PyArrayNeighborhoodIterMode;
-
-typedef struct {
PyObject_HEAD
/*
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 729dd3963..e351fd4cc 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1741,7 +1741,7 @@ NPY_NO_EXPORT PyTypeObject PyArrayMultiIter_Type = {
static void neighiter_dealloc(PyArrayNeighborhoodIterObject* iter);
static char* _set_constant(PyArrayNeighborhoodIterObject* iter,
- PyArrayNeighborhoodIterMode* mode)
+ PyArrayObject *fill)
{
char *ret;
PyArrayIterObject *ar = iter->_internal_iter;
@@ -1754,14 +1754,14 @@ static char* _set_constant(PyArrayNeighborhoodIterObject* iter,
}
if (PyArray_ISOBJECT(ar->ao)) {
- memcpy(ret, &mode->constant, sizeof(PyObject*));
+ memcpy(ret, fill->data, sizeof(PyObject*));
Py_INCREF(*(PyObject**)ret);
} else {
/* Non-object types */
storeflags = ar->ao->flags;
ar->ao->flags |= BEHAVED;
- st = ar->ao->descr->f->setitem(mode->constant, ret, ar->ao);
+ st = ar->ao->descr->f->setitem((PyObject*)fill, ret, ar->ao);
ar->ao->flags = storeflags;
if (st < 0) {
@@ -1774,9 +1774,12 @@ static char* _set_constant(PyArrayNeighborhoodIterObject* iter,
}
/*NUMPY_API*/
+/*
+ * fill and x->ao should have equivalent types
+ */
NPY_NO_EXPORT PyObject*
PyArray_NeighborhoodIterNew(PyArrayIterObject *x, intp *bounds,
- PyArrayNeighborhoodIterMode* mode)
+ int mode, PyArrayObject* fill)
{
int i;
PyArrayNeighborhoodIterObject *ret;
@@ -1805,37 +1808,33 @@ PyArray_NeighborhoodIterNew(PyArrayIterObject *x, intp *bounds,
ret->dimensions[i] = x->ao->dimensions[i];
}
- if (mode == NULL) {
- ret->constant = PyArray_Zero(x->ao);
- ret->mode = NPY_NEIGHBORHOOD_ITER_ZERO_PADDING;
- } else {
- switch (mode->mode) {
- case NPY_NEIGHBORHOOD_ITER_ZERO_PADDING:
- ret->constant = PyArray_Zero(x->ao);
- ret->mode = mode->mode;
- break;
- case NPY_NEIGHBORHOOD_ITER_ONE_PADDING:
- ret->constant = PyArray_One(x->ao);
- ret->mode = mode->mode;
- break;
- case NPY_NEIGHBORHOOD_ITER_CONSTANT_PADDING:
- /* New reference in returned value of _set_constant if array
- * object */
- ret->constant = _set_constant(ret, mode);
- if (ret->constant == NULL) {
- goto clean_x;
- }
- ret->mode = mode->mode;
- break;
- case NPY_NEIGHBORHOOD_ITER_MIRROR_PADDING:
- case NPY_NEIGHBORHOOD_ITER_CIRCULAR_PADDING:
- ret->mode = mode->mode;
- ret->constant = NULL;
- break;
- default:
- PyErr_SetString(PyExc_ValueError, "Unsupported padding mode");
+ switch (mode) {
+ case NPY_NEIGHBORHOOD_ITER_ZERO_PADDING:
+ ret->constant = PyArray_Zero(x->ao);
+ ret->mode = mode;
+ break;
+ case NPY_NEIGHBORHOOD_ITER_ONE_PADDING:
+ ret->constant = PyArray_One(x->ao);
+ ret->mode = mode;
+ break;
+ case NPY_NEIGHBORHOOD_ITER_CONSTANT_PADDING:
+ /* New reference in returned value of _set_constant if array
+ * object */
+ assert(PyArray_EquivArrTypes(x->ao, fill) == NPY_TRUE);
+ ret->constant = _set_constant(ret, fill);
+ if (ret->constant == NULL) {
goto clean_x;
- }
+ }
+ ret->mode = mode;
+ break;
+ case NPY_NEIGHBORHOOD_ITER_MIRROR_PADDING:
+ case NPY_NEIGHBORHOOD_ITER_CIRCULAR_PADDING:
+ ret->mode = mode;
+ ret->constant = NULL;
+ break;
+ default:
+ PyErr_SetString(PyExc_ValueError, "Unsupported padding mode");
+ goto clean_x;
}
/*
diff --git a/numpy/core/src/multiarray/multiarray_tests.c.src b/numpy/core/src/multiarray/multiarray_tests.c.src
index 040026f2c..859fe2277 100644
--- a/numpy/core/src/multiarray/multiarray_tests.c.src
+++ b/numpy/core/src/multiarray/multiarray_tests.c.src
@@ -95,12 +95,11 @@ test_neighborhood_iterator(PyObject* NPY_UNUSED(self), PyObject* args)
PyObject *x, *fill, *out, *b;
PyArrayObject *ax, *afill;
PyArrayIterObject *itx;
- int i, typenum, imode, st;
+ int i, typenum, mode, st;
npy_intp bounds[NPY_MAXDIMS*2];
PyArrayNeighborhoodIterObject *niterx;
- PyArrayNeighborhoodIterMode mode;
- if (!PyArg_ParseTuple(args, "OOOi", &x, &b, &fill, &imode)) {
+ if (!PyArg_ParseTuple(args, "OOOi", &x, &b, &fill, &mode)) {
return NULL;
}
@@ -148,19 +147,16 @@ test_neighborhood_iterator(PyObject* NPY_UNUSED(self), PyObject* args)
}
/* Create the neighborhood iterator */
- mode.mode = imode;
- mode.constant = NULL;
afill = NULL;
- if (imode == NPY_NEIGHBORHOOD_ITER_CONSTANT_PADDING) {
+ if (mode == NPY_NEIGHBORHOOD_ITER_CONSTANT_PADDING) {
afill = (PyArrayObject *)PyArray_FromObject(fill, typenum, 0, 0);
if (afill == NULL) {
goto clean_itx;
}
- mode.constant = (PyObject*)afill;
}
niterx = (PyArrayNeighborhoodIterObject*)PyArray_NeighborhoodIterNew(
- (PyArrayIterObject*)itx, bounds, &mode);
+ (PyArrayIterObject*)itx, bounds, mode, afill);
if (niterx == NULL) {
goto clean_afill;
}