summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2010-12-21 00:38:33 -0800
committerMark Wiebe <mwwiebe@gmail.com>2011-01-09 01:55:00 -0800
commited53e18a37736cfd7dd51c893d94471c4a37e92e (patch)
tree760d200f648f0207a7fa14b1b087f87eb814ba74
parent82643de85d2a8d25ebba351d707a2f13122d6c52 (diff)
downloadnumpy-ed53e18a37736cfd7dd51c893d94471c4a37e92e.tar.gz
ENH: iter: Added a no_broadcast flag to ensure an operand matches the iteration dimensions
-rw-r--r--numpy/core/src/multiarray/lowlevel_strided_loops.c.src4
-rw-r--r--numpy/core/src/multiarray/new_iterator.c.src41
-rw-r--r--numpy/core/src/multiarray/new_iterator.h2
-rw-r--r--numpy/core/src/multiarray/new_iterator_pywrap.c3
-rw-r--r--numpy/core/tests/test_new_iterator.py15
5 files changed, 60 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src
index f0f22aa78..d317eeb2e 100644
--- a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src
+++ b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src
@@ -874,7 +874,7 @@ PyArray_GetTransferFunction(int aligned,
PyArray_StridedTransferFn tobuffer, frombuffer, casttransfer;
/* Get the copy/swap operation from src */
- if (PyArray_ISNBO(src_dtype->byteorder)) {
+ if (src_itemsize == 1 || PyArray_ISNBO(src_dtype->byteorder)) {
tobuffer = PyArray_GetStridedCopyFn(aligned,
src_stride, src_itemsize,
src_itemsize);
@@ -891,7 +891,7 @@ PyArray_GetTransferFunction(int aligned,
}
/* Get the copy/swap operation to dst */
- if (PyArray_ISNBO(dst_dtype->byteorder)) {
+ if (dst_itemsize == 1 || PyArray_ISNBO(dst_dtype->byteorder)) {
frombuffer = PyArray_GetStridedCopyFn(aligned,
dst_itemsize, dst_stride,
dst_itemsize);
diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src
index c3e4b0aa0..a29d0604d 100644
--- a/numpy/core/src/multiarray/new_iterator.c.src
+++ b/numpy/core/src/multiarray/new_iterator.c.src
@@ -196,7 +196,7 @@ npyiter_check_casting(npy_intp niter, PyArrayObject **op,
static int
npyiter_fill_axisdata(NpyIter *iter, PyArrayObject **op,
npy_intp *op_ndim, char **op_dataptr,
- npy_intp **op_axes);
+ npy_uint32 *op_flags, npy_intp **op_axes);
static void
npyiter_replace_axisdata(NpyIter *iter, npy_intp iiter,
PyArrayObject *op,
@@ -500,7 +500,8 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags,
}
/* Fill in the AXISDATA arrays and set the ITERSIZE field */
- if (!npyiter_fill_axisdata(iter, op, op_ndim, op_dataptr, op_axes)) {
+ if (!npyiter_fill_axisdata(iter, op, op_ndim, op_dataptr,
+ op_flags, op_axes)) {
NpyIter_Deallocate(iter);
return NULL;
}
@@ -2145,7 +2146,7 @@ npyiter_check_casting(npy_intp niter, PyArrayObject **op,
static int
npyiter_fill_axisdata(NpyIter *iter, PyArrayObject **op,
npy_intp *op_ndim, char **op_dataptr,
- npy_intp **op_axes)
+ npy_uint32 *op_flags, npy_intp **op_axes)
{
npy_uint32 itflags = NIT_ITFLAGS(iter);
npy_intp idim, ndim = NIT_NDIM(iter);
@@ -2357,6 +2358,40 @@ npyiter_fill_axisdata(NpyIter *iter, PyArrayObject **op,
}
}
+ /* Go through and check for operands with NPY_ITER_NO_BROADCAST set */
+ for (iiter = 0; iiter < niter; ++iiter) {
+ if (op_flags[iiter]&NPY_ITER_NO_BROADCAST) {
+ npy_intp *axes;
+
+ if (op_ndim[iiter] != ndim) {
+ PyErr_SetString(PyExc_ValueError,
+ "Iterator input has the NO_BROADCAST flag set, but "
+ "has a different number of dimensions than the "
+ "iterator");
+ return 0;
+ }
+
+ axes = op_axes ? op_axes[iiter] : NULL;
+ axisdata = axisdata0;
+ for (idim = 0; idim < ndim; ++idim, axisdata += sizeof_axisdata) {
+ npy_intp i;
+ if (axes) {
+ i = axes[ndim-idim-1];
+ }
+ else {
+ i = ndim-idim-1;
+ }
+ if (PyArray_DIM(op[iiter], i) != NAD_SHAPE(axisdata)) {
+ PyErr_SetString(PyExc_ValueError,
+ "Iterator input has the NO_BROADCAST flag set, "
+ "but has different dimensions than the final "
+ "broadcast shape of the iterator");
+ return 0;
+ }
+ }
+ }
+ }
+
/* Now fill in the ITERSIZE member */
NIT_ITERSIZE(iter) = 1;
axisdata = axisdata0;
diff --git a/numpy/core/src/multiarray/new_iterator.h b/numpy/core/src/multiarray/new_iterator.h
index 905642d79..1c79d57f4 100644
--- a/numpy/core/src/multiarray/new_iterator.h
+++ b/numpy/core/src/multiarray/new_iterator.h
@@ -138,6 +138,8 @@ NPY_NO_EXPORT void NpyIter_DebugPrint(NpyIter *iter);
#define NPY_ITER_ALLOCATE 0x04000000
/* If an operand is allocated, don't use the priority subtype */
#define NPY_ITER_NO_SUBTYPE 0x08000000
+/* Disallows broadcasting of the dimensions, they must match exactly */
+#define NPY_ITER_NO_BROADCAST 0x10000000
#define NPY_ITER_GLOBAL_FLAGS 0x0000ffff
#define NPY_ITER_PER_OP_FLAGS 0xffff0000
diff --git a/numpy/core/src/multiarray/new_iterator_pywrap.c b/numpy/core/src/multiarray/new_iterator_pywrap.c
index 1f9cd0b53..cf2253754 100644
--- a/numpy/core/src/multiarray/new_iterator_pywrap.c
+++ b/numpy/core/src/multiarray/new_iterator_pywrap.c
@@ -317,6 +317,9 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
if (strcmp(str, "no_subtype") == 0) {
flag = NPY_ITER_NO_SUBTYPE;
}
+ else if (strcmp(str, "no_broadcast") == 0) {
+ flag = NPY_ITER_NO_BROADCAST;
+ }
break;
}
break;
diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py
index c93b13539..909c2f0d2 100644
--- a/numpy/core/tests/test_new_iterator.py
+++ b/numpy/core/tests/test_new_iterator.py
@@ -1104,6 +1104,8 @@ def test_iter_buffering():
a.dtype = 'i4'
a[:] = np.arange(16,dtype='i4')
arrays.append(a)
+ # 4-D F-order array
+ arrays.append(np.arange(120,dtype='i4').reshape(5,3,2,4).T)
for a in arrays:
for buffersize in (1,2,3,5,8,11,16,1024):
vals = []
@@ -1216,5 +1218,18 @@ def test_iter_buffering_growinner():
# Should end up with just one inner loop here
assert_equal(i[0].size, a.size)
+def test_iter_no_broadcast():
+ # Test that the no_broadcast flag works
+ a = np.arange(24).reshape(2,3,4)
+ b = np.arange(6).reshape(2,3,1)
+ c = np.arange(12).reshape(3,4)
+
+ i = np.newiter([a,b,c], [],
+ [['readonly','no_broadcast'],['readonly'],['readonly']])
+ assert_raises(ValueError, np.newiter, [a,b,c], [],
+ [['readonly'],['readonly','no_broadcast'],['readonly']])
+ assert_raises(ValueError, np.newiter, [a,b,c], [],
+ [['readonly'],['readonly'],['readonly','no_broadcast']])
+
if __name__ == "__main__":
run_module_suite()