summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-01-13 12:51:38 -0800
committerMark Wiebe <mwwiebe@gmail.com>2011-01-13 12:51:38 -0800
commit2b04a70392fb8a26cb0f2e3f465771be30edf2ad (patch)
tree8d23f7e4a587e3946e578dac8ac3f73a2633f670
parent7c3b6b8d471778eb9fda7636bae33a1f387ee6c1 (diff)
downloadnumpy-2b04a70392fb8a26cb0f2e3f465771be30edf2ad.tar.gz
ENH: iter: Add support for iterating object arrays
-rw-r--r--numpy/core/code_generators/numpy_api.py3
-rw-r--r--numpy/core/include/numpy/ndarraytypes.h18
-rw-r--r--numpy/core/src/multiarray/lowlevel_strided_loops.c.src62
-rw-r--r--numpy/core/src/multiarray/lowlevel_strided_loops.h5
-rw-r--r--numpy/core/src/multiarray/new_iterator.c.src98
-rw-r--r--numpy/core/src/multiarray/new_iterator_pywrap.c36
-rw-r--r--numpy/core/tests/test_new_iterator.py34
7 files changed, 220 insertions, 36 deletions
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py
index 5b51761f2..907167862 100644
--- a/numpy/core/code_generators/numpy_api.py
+++ b/numpy/core/code_generators/numpy_api.py
@@ -291,7 +291,8 @@ multiarray_funcs_api = {
'NpyIter_GetReadFlags': 257,
'NpyIter_GetWriteFlags': 258,
'NpyIter_DebugPrint': 259,
- 'PyArray_CastingConverter': 260,
+ 'NpyIter_IterationNeedsAPI': 260,
+ 'PyArray_CastingConverter': 261,
}
ufunc_types_api = {
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h
index 46dbbefc7..db7ffb183 100644
--- a/numpy/core/include/numpy/ndarraytypes.h
+++ b/numpy/core/include/numpy/ndarraytypes.h
@@ -882,14 +882,16 @@ typedef void (*NpyIter_GetCoords_Fn )(NpyIter *iter,
#define NPY_ITER_NO_INNER_ITERATION 0x00000008
/* Convert all the operands to a common data type */
#define NPY_ITER_COMMON_DTYPE 0x00000010
+/* Operands may hold references, requiring API access during iteration */
+#define NPY_ITER_REFS_OK 0x00000020
/* Enables sub-range iteration */
-#define NPY_ITER_RANGED 0x00000020
+#define NPY_ITER_RANGED 0x00000040
/* Enables buffering */
-#define NPY_ITER_BUFFERED 0x00000040
+#define NPY_ITER_BUFFERED 0x00000080
/* When buffering is enabled, grows the inner loop if possible */
-#define NPY_ITER_GROWINNER 0x00000080
+#define NPY_ITER_GROWINNER 0x00000100
/* Delay allocation of buffers until first Reset* call */
-#define NPY_ITER_DELAY_BUFALLOC 0x00000100
+#define NPY_ITER_DELAY_BUFALLOC 0x00000200
/*** Per-operand flags that may be passed to the iterator constructors ***/
@@ -909,14 +911,12 @@ typedef void (*NpyIter_GetCoords_Fn )(NpyIter *iter,
#define NPY_ITER_COPY 0x00400000
/* The operand may be copied with UPDATEIFCOPY to satisfy requirements */
#define NPY_ITER_UPDATEIFCOPY 0x00800000
-/* Allow writeable operands to have references or pointers */
-#define NPY_ITER_WRITEABLE_REFERENCES 0x01000000
/* Allocate the operand if it is NULL */
-#define NPY_ITER_ALLOCATE 0x02000000
+#define NPY_ITER_ALLOCATE 0x01000000
/* If an operand is allocated, don't use any subtype */
-#define NPY_ITER_NO_SUBTYPE 0x04000000
+#define NPY_ITER_NO_SUBTYPE 0x02000000
/* Require that the dimension match the iterator dimensions exactly */
-#define NPY_ITER_NO_BROADCAST 0x08000000
+#define NPY_ITER_NO_BROADCAST 0x04000000
#define NPY_ITER_GLOBAL_FLAGS 0x0000ffff
#define NPY_ITER_PER_OP_FLAGS 0xffff0000
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src
index f02021c86..6b59eca60 100644
--- a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src
+++ b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src
@@ -656,6 +656,57 @@ NPY_NO_EXPORT PyArray_StridedTransferFn
/**end repeat**/
+/* Moves references from src to dst */
+static void
+_strided_to_strided_move_references(char *dst, npy_intp dst_stride,
+ char *src, npy_intp src_stride,
+ npy_intp N, npy_intp src_itemsize,
+ void *data)
+{
+ PyObject *src_ref = NULL, *dst_ref = NULL;
+ while (N > 0) {
+ NPY_COPY_PYOBJECT_PTR(&src_ref, src);
+ NPY_COPY_PYOBJECT_PTR(&dst_ref, dst);
+
+ /* Release the reference in dst */
+ Py_XDECREF(dst_ref);
+ /* Move the reference */
+ NPY_COPY_PYOBJECT_PTR(dst, &src_ref);
+ /* Set the source reference to NULL */
+ src_ref = NULL;
+ NPY_COPY_PYOBJECT_PTR(src, &src_ref);
+
+ src += src_stride;
+ dst += dst_stride;
+ --N;
+ }
+}
+
+/* Copies references from src to dst */
+static void
+_strided_to_strided_copy_references(char *dst, npy_intp dst_stride,
+ char *src, npy_intp src_stride,
+ npy_intp N, npy_intp src_itemsize,
+ void *data)
+{
+ PyObject *src_ref = NULL, *dst_ref = NULL;
+ while (N > 0) {
+ NPY_COPY_PYOBJECT_PTR(&src_ref, src);
+ NPY_COPY_PYOBJECT_PTR(&dst_ref, dst);
+
+ /* Release the reference in dst */
+ Py_XDECREF(dst_ref);
+ /* Copy the reference */
+ NPY_COPY_PYOBJECT_PTR(dst, &src_ref);
+ /* Claim the reference */
+ Py_XINCREF(src_ref);
+
+ src += src_stride;
+ dst += dst_stride;
+ --N;
+ }
+}
+
/* Does a zero-padded copy */
typedef struct {
void *freefunc, *copyfunc;
@@ -919,6 +970,7 @@ NPY_NO_EXPORT int
PyArray_GetDTypeTransferFunction(int aligned,
npy_intp src_stride, npy_intp dst_stride,
PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype,
+ int move_references,
PyArray_StridedTransferFn *outstransfer,
void **outtransferdata)
{
@@ -976,6 +1028,16 @@ PyArray_GetDTypeTransferFunction(int aligned,
src_stride, dst_stride,
src_dtype->elsize, dst_dtype->elsize,
outstransfer, outtransferdata);
+ case NPY_OBJECT:
+ if (move_references) {
+ *outstransfer = &_strided_to_strided_move_references;
+ *outtransferdata = NULL;
+ }
+ else {
+ *outstransfer = &_strided_to_strided_copy_references;
+ *outtransferdata = NULL;
+ }
+ return NPY_SUCCEED;
}
}
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.h b/numpy/core/src/multiarray/lowlevel_strided_loops.h
index 5c6374493..ffa3767d4 100644
--- a/numpy/core/src/multiarray/lowlevel_strided_loops.h
+++ b/numpy/core/src/multiarray/lowlevel_strided_loops.h
@@ -100,12 +100,17 @@ PyArray_GetStridedZeroPadCopyFn(npy_intp aligned,
* must be deallocated with the ``PyArray_FreeStridedTransferData``
* function when the transfer function is no longer required.
*
+ * If move_references is 1, and the 'from' type has references,
+ * the source references will get a DECREF after the reference value is
+ * cast to the dest type, then be set to NULL.
+ *
* Returns NPY_SUCCEED or NPY_FAIL.
*/
NPY_NO_EXPORT int
PyArray_GetDTypeTransferFunction(int aligned,
npy_intp src_stride, npy_intp dst_stride,
PyArray_Descr *from, PyArray_Descr *to,
+ int move_references,
PyArray_StridedTransferFn *outstransfer,
void **outtransferdata);
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src
index 32ba8f4d3..6f249955a 100644
--- a/numpy/core/src/multiarray/new_iterator.c.src
+++ b/numpy/core/src/multiarray/new_iterator.c.src
@@ -38,6 +38,8 @@
#define NPY_ITFLAG_ONEITERATION 0x200
/* Delay buffer allocation until first Reset* call */
#define NPY_ITFLAG_DELAYBUF 0x400
+/* Iteration needs API access during iternext */
+#define NPY_ITFLAG_NEEDSAPI 0x800
/* Internal iterator per-operand iterator flags */
@@ -570,6 +572,27 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
}
}
+ /*
+ * If REFS_OK was specified, check whether there are any
+ * reference arrays and flag it if so.
+ */
+ if (flags&NPY_ITER_REFS_OK) {
+ for (iiter = 0; iiter < niter; ++iiter) {
+ PyArray_Descr *odt = PyArray_DESCR(op[iiter]),
+ *rdt = op_dtype[iiter];
+ if (((odt->flags&(NPY_ITEM_REFCOUNT|
+ NPY_ITEM_IS_POINTER|
+ NPY_NEEDS_PYAPI)) != 0) ||
+ (odt != rdt &&
+ ((rdt->flags&(NPY_ITEM_REFCOUNT|
+ NPY_ITEM_IS_POINTER|
+ NPY_NEEDS_PYAPI))) != 0)) {
+ /* Iteration needs API access */
+ NIT_ITFLAGS(iter) |= NPY_ITFLAG_NEEDSAPI;
+ }
+ }
+ }
+
/* If buffering is set without delayed allocation */
if (itflags&NPY_ITFLAG_BUFFER) {
if (!npyiter_allocate_transfer_functions(iter)) {
@@ -1797,6 +1820,17 @@ NpyIter_HasIndex(NpyIter *iter)
}
/*NUMPY_API
+ * Whether the iteration loop, and in particular the iternext()
+ * function, needs API access. If this is true, the GIL must
+ * be retained while iterating.
+ */
+NPY_NO_EXPORT int
+NpyIter_IterationNeedsAPI(NpyIter *iter)
+{
+ return (NIT_ITFLAGS(iter)&NPY_ITFLAG_NEEDSAPI) != 0;
+}
+
+/*NUMPY_API
* Gets the number of dimensions being iterated
*/
NPY_NO_EXPORT npy_intp
@@ -2421,17 +2455,21 @@ npyiter_prepare_one_operand(PyArrayObject **op,
}
Py_INCREF(*op_dtype);
/*
- * Make sure that if the data type has a Python reference or
- * other pointer, WRITEABLE_REFERENCES was specified.
+ * If references weren't specifically allowed, make sure there
+ * are no references in the inputs or requested dtypes.
*/
- if (((*op_itflags)&NPY_OP_ITFLAG_WRITE) &&
- !(op_flags&NPY_ITER_WRITEABLE_REFERENCES)) {
- if (PyDataType_FLAGCHK(*op_dtype, NPY_ITEM_HASOBJECT) ||
- PyDataType_FLAGCHK(*op_dtype, NPY_ITEM_IS_POINTER)) {
- PyErr_SetString(PyExc_ValueError,
- "Tried to construct an iterator for a writeable "
- "array of references/pointers without specifying the "
- "WRITEABLE_REFERENCES flag.");
+ if (!(flags&NPY_ITER_REFS_OK)) {
+ PyArray_Descr *dt = PyArray_DESCR(*op);
+ if (((dt->flags&(NPY_ITEM_REFCOUNT|
+ NPY_ITEM_IS_POINTER|
+ NPY_NEEDS_PYAPI)) != 0) ||
+ (dt != *op_dtype &&
+ (((*op_dtype)->flags&(NPY_ITEM_REFCOUNT|
+ NPY_ITEM_IS_POINTER|
+ NPY_NEEDS_PYAPI))) != 0)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Iterator operand or requested dtype holds "
+ "references, but the REFS_OK flag was not enabled");
return 0;
}
}
@@ -4178,12 +4216,14 @@ npyiter_allocate_transfer_functions(NpyIter *iter)
*/
if (!(flags&NPY_OP_ITFLAG_BUFNEVER)) {
if (flags&NPY_OP_ITFLAG_READ) {
+ int move_references = 0;
if (PyArray_GetDTypeTransferFunction(
(flags&NPY_OP_ITFLAG_ALIGNED) != 0,
op_stride,
op_dtype[iiter]->elsize,
PyArray_DESCR(op[iiter]),
op_dtype[iiter],
+ move_references,
&stransfer,
&transferdata) != NPY_SUCCEED) {
goto fail;
@@ -4195,12 +4235,14 @@ npyiter_allocate_transfer_functions(NpyIter *iter)
readtransferfn[iiter] = NULL;
}
if (flags&NPY_OP_ITFLAG_WRITE) {
+ int move_references = 1;
if (PyArray_GetDTypeTransferFunction(
(flags&NPY_OP_ITFLAG_ALIGNED) != 0,
op_dtype[iiter]->elsize,
op_stride,
op_dtype[iiter],
PyArray_DESCR(op[iiter]),
+ move_references,
&stransfer,
&transferdata) != NPY_SUCCEED) {
goto fail;
@@ -4395,6 +4437,7 @@ npyiter_copy_from_buffers(NpyIter *iter)
*ad_strides = NAD_STRIDES(axisdata);
char **ptrs = NBF_PTRS(bufferdata), **ad_ptrs = NAD_PTRS(axisdata);
char **buffers = NBF_BUFFERS(bufferdata);
+ char *buffer;
PyArray_StridedTransferFn stransfer = NULL;
void *transferdata = NULL;
@@ -4410,13 +4453,18 @@ npyiter_copy_from_buffers(NpyIter *iter)
for (iiter = 0; iiter < niter; ++iiter) {
stransfer = NBF_WRITETRANSFERFN(bufferdata)[iiter];
transferdata = NBF_WRITETRANSFERDATA(bufferdata)[iiter];
+ buffer = buffers[iiter];
+ /*
+ * Copy the data back to the arrays. If the type has refs,
+ * this function moves them so the buffer's refs are released.
+ */
if ((stransfer != NULL) && (op_itflags[iiter]&NPY_OP_ITFLAG_WRITE)) {
/* Copy back only if the pointer was pointing to the buffer */
- npy_intp delta = (ptrs[iiter] - buffers[iiter]);
+ npy_intp delta = (ptrs[iiter] - buffer);
if (0 <= delta && delta <= transfersize*dtypes[iiter]->elsize) {
PyArray_TransferStridedToNDim(ndim,
ad_ptrs[iiter], &ad_strides[iiter], axisdata_incr,
- buffers[iiter], strides[iiter],
+ buffer, strides[iiter],
&NAD_COORD(axisdata), axisdata_incr,
&NAD_SHAPE(axisdata), axisdata_incr,
transfersize, dtypes[iiter]->elsize,
@@ -4424,6 +4472,19 @@ npyiter_copy_from_buffers(NpyIter *iter)
transferdata);
}
}
+ /* If there's no copy back, we may have to decrement refs */
+ else if (PyDataType_REFCHK(dtypes[iiter])) {
+ /* Decrement refs only if the pointer was pointing to the buffer */
+ npy_intp delta = (ptrs[iiter] - buffer);
+ if (0 <= delta && delta <= transfersize*dtypes[iiter]->elsize) {
+ npy_intp i, size = NBF_SIZE(bufferdata);
+ PyObject **data = (PyObject **)buffer;
+
+ for (i = 0; i < size; ++i, ++data) {
+ PyArray_Item_XDECREF(data, dtypes[iiter]);
+ }
+ }
+ }
}
}
@@ -4523,14 +4584,21 @@ npyiter_copy_to_buffers(NpyIter *iter)
}
if (stransfer != NULL) {
- /*printf("transfer %p -> %p\n", ad_ptrs[iiter], ptrs[iiter]);*/
+ npy_intp itemsize = PyArray_DESCR(operands[iiter])->elsize;
+
any_buffered = 1;
+
+ /* If the data type requires zero-inititialization */
+ if (PyDataType_FLAGCHK(dtypes[iiter], NPY_NEEDS_INIT)) {
+ memset(ptrs[iiter], 0, itemsize*transfersize);
+ }
+
PyArray_TransferNDimToStrided(ndim,
ptrs[iiter], strides[iiter],
ad_ptrs[iiter], &ad_strides[iiter], axisdata_incr,
&NAD_COORD(axisdata), axisdata_incr,
&NAD_SHAPE(axisdata), axisdata_incr,
- transfersize, PyArray_DESCR(operands[iiter])->elsize,
+ transfersize, itemsize,
stransfer,
transferdata);
}
@@ -4587,6 +4655,8 @@ NpyIter_DebugPrint(NpyIter *iter)
printf("ONEITERATION ");
if (itflags&NPY_ITFLAG_DELAYBUF)
printf("DELAYBUF ");
+ if (itflags&NPY_ITFLAG_NEEDSAPI)
+ printf("NEEDSAPI ");
printf("\n");
printf("NDim: %d\n", (int)ndim);
printf("NIter: %d\n", (int)niter);
diff --git a/numpy/core/src/multiarray/new_iterator_pywrap.c b/numpy/core/src/multiarray/new_iterator_pywrap.c
index 3891bb806..2db4801a5 100644
--- a/numpy/core/src/multiarray/new_iterator_pywrap.c
+++ b/numpy/core/src/multiarray/new_iterator_pywrap.c
@@ -160,6 +160,9 @@ NpyIter_GlobalFlagsConverter(PyObject *flags_in, npy_uint32 *flags)
if (strcmp(str, "ranged") == 0) {
flag = NPY_ITER_RANGED;
}
+ else if (strcmp(str, "refs_ok") == 0) {
+ flag = NPY_ITER_REFS_OK;
+ }
break;
}
if (flag == 0) {
@@ -351,18 +354,8 @@ NpyIter_OpFlagsConverter(PyObject *op_flags_in,
}
break;
case 'w':
- if (length > 5) switch (str[5]) {
- case 'a':
- if (strcmp(str,
- "writeable_references") == 0) {
- flag = NPY_ITER_WRITEABLE_REFERENCES;
- }
- break;
- case 'o':
- if (strcmp(str, "writeonly") == 0) {
- flag = NPY_ITER_WRITEONLY;
- }
- break;
+ if (strcmp(str, "writeonly") == 0) {
+ flag = NPY_ITER_WRITEONLY;
}
break;
}
@@ -1689,6 +1682,22 @@ static PyObject *npyiter_hasdelayedbufalloc_get(NewNpyArrayIterObject *self)
}
}
+static PyObject *npyiter_iterationneedsapi_get(NewNpyArrayIterObject *self)
+{
+ if (self->iter == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "Iterator is invalid");
+ return NULL;
+ }
+
+ if (NpyIter_IterationNeedsAPI(self->iter)) {
+ Py_RETURN_TRUE;
+ }
+ else {
+ Py_RETURN_FALSE;
+ }
+}
+
static PyObject *npyiter_hascoords_get(NewNpyArrayIterObject *self)
{
if (self->iter == NULL) {
@@ -2044,6 +2053,9 @@ static PyGetSetDef npyiter_getsets[] = {
{"hasdelayedbufalloc",
(getter)npyiter_hasdelayedbufalloc_get,
NULL, NULL, NULL},
+ {"iterationneedsapi",
+ (getter)npyiter_iterationneedsapi_get,
+ NULL, NULL, NULL},
{"hascoords",
(getter)npyiter_hascoords_get,
NULL, NULL, NULL},
diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py
index 0f2879346..d83515948 100644
--- a/numpy/core/tests/test_new_iterator.py
+++ b/numpy/core/tests/test_new_iterator.py
@@ -38,6 +38,7 @@ def test_iter_refcount():
[['readwrite','updateifcopy']],
casting='unsafe',
op_dtypes=[dt])
+ assert_(not it.iterationneedsapi)
assert_(sys.getrefcount(a) > rc_a)
assert_(sys.getrefcount(dt) > rc_dt)
it = None
@@ -885,6 +886,39 @@ def test_iter_scalar_cast_errors():
casting='same_kind',
op_dtypes=[np.dtype('i4')])
+def test_iter_object_arrays():
+ # Check that object arrays work
+
+ obj = {'a':3,'b':'d'}
+ a = np.array([[1,2,3], None, obj, None], dtype='O')
+ rc = sys.getrefcount(obj)
+
+ # Need to allow references for object arrays
+ assert_raises(TypeError, newiter, a)
+ assert_equal(sys.getrefcount(obj), rc)
+
+ i = newiter(a, ['refs_ok'], ['readonly'])
+ vals = [x[()] for x in i]
+ assert_equal(np.array(vals, dtype='O'), a)
+ vals, i, x = [None]*3
+ assert_equal(sys.getrefcount(obj), rc)
+
+ i = newiter(a.reshape(2,2).T, ['refs_ok','buffered'],
+ ['readonly'], order='C')
+ assert_(i.iterationneedsapi)
+ vals = [x[()] for x in i]
+ assert_equal(np.array(vals, dtype='O'), a.reshape(2,2).ravel(order='F'))
+ vals, i, x = [None]*3
+ assert_equal(sys.getrefcount(obj), rc)
+
+ i = newiter(a.reshape(2,2).T, ['refs_ok','buffered'],
+ ['readwrite'], order='C')
+ for x in i:
+ x[()] = None
+ vals, i, x = [None]*3
+ assert_equal(sys.getrefcount(obj), rc-1)
+ assert_equal(a, np.array([None]*4, dtype='O'))
+
def test_iter_common_dtype():
# Check that the iterator finds a common data type correctly