diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2010-12-21 00:38:33 -0800 |
---|---|---|
committer | Mark Wiebe <mwwiebe@gmail.com> | 2011-01-09 01:55:00 -0800 |
commit | ed53e18a37736cfd7dd51c893d94471c4a37e92e (patch) | |
tree | 760d200f648f0207a7fa14b1b087f87eb814ba74 | |
parent | 82643de85d2a8d25ebba351d707a2f13122d6c52 (diff) | |
download | numpy-ed53e18a37736cfd7dd51c893d94471c4a37e92e.tar.gz |
ENH: iter: Added a no_broadcast flag to ensure an operand matches the iteration dimensions
-rw-r--r-- | numpy/core/src/multiarray/lowlevel_strided_loops.c.src | 4 | ||||
-rw-r--r-- | numpy/core/src/multiarray/new_iterator.c.src | 41 | ||||
-rw-r--r-- | numpy/core/src/multiarray/new_iterator.h | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/new_iterator_pywrap.c | 3 | ||||
-rw-r--r-- | numpy/core/tests/test_new_iterator.py | 15 |
5 files changed, 60 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src index f0f22aa78..d317eeb2e 100644 --- a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src +++ b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src @@ -874,7 +874,7 @@ PyArray_GetTransferFunction(int aligned, PyArray_StridedTransferFn tobuffer, frombuffer, casttransfer; /* Get the copy/swap operation from src */ - if (PyArray_ISNBO(src_dtype->byteorder)) { + if (src_itemsize == 1 || PyArray_ISNBO(src_dtype->byteorder)) { tobuffer = PyArray_GetStridedCopyFn(aligned, src_stride, src_itemsize, src_itemsize); @@ -891,7 +891,7 @@ PyArray_GetTransferFunction(int aligned, } /* Get the copy/swap operation to dst */ - if (PyArray_ISNBO(dst_dtype->byteorder)) { + if (dst_itemsize == 1 || PyArray_ISNBO(dst_dtype->byteorder)) { frombuffer = PyArray_GetStridedCopyFn(aligned, dst_itemsize, dst_stride, dst_itemsize); diff --git a/numpy/core/src/multiarray/new_iterator.c.src b/numpy/core/src/multiarray/new_iterator.c.src index c3e4b0aa0..a29d0604d 100644 --- a/numpy/core/src/multiarray/new_iterator.c.src +++ b/numpy/core/src/multiarray/new_iterator.c.src @@ -196,7 +196,7 @@ npyiter_check_casting(npy_intp niter, PyArrayObject **op, static int npyiter_fill_axisdata(NpyIter *iter, PyArrayObject **op, npy_intp *op_ndim, char **op_dataptr, - npy_intp **op_axes); + npy_uint32 *op_flags, npy_intp **op_axes); static void npyiter_replace_axisdata(NpyIter *iter, npy_intp iiter, PyArrayObject *op, @@ -500,7 +500,8 @@ NpyIter_MultiNew(npy_intp niter, PyArrayObject **op_in, npy_uint32 flags, } /* Fill in the AXISDATA arrays and set the ITERSIZE field */ - if (!npyiter_fill_axisdata(iter, op, op_ndim, op_dataptr, op_axes)) { + if (!npyiter_fill_axisdata(iter, op, op_ndim, op_dataptr, + op_flags, op_axes)) { NpyIter_Deallocate(iter); return NULL; } @@ -2145,7 +2146,7 @@ npyiter_check_casting(npy_intp niter, PyArrayObject **op, static int npyiter_fill_axisdata(NpyIter *iter, PyArrayObject **op, npy_intp *op_ndim, char **op_dataptr, - npy_intp **op_axes) + npy_uint32 *op_flags, npy_intp **op_axes) { npy_uint32 itflags = NIT_ITFLAGS(iter); npy_intp idim, ndim = NIT_NDIM(iter); @@ -2357,6 +2358,40 @@ npyiter_fill_axisdata(NpyIter *iter, PyArrayObject **op, } } + /* Go through and check for operands with NPY_ITER_NO_BROADCAST set */ + for (iiter = 0; iiter < niter; ++iiter) { + if (op_flags[iiter]&NPY_ITER_NO_BROADCAST) { + npy_intp *axes; + + if (op_ndim[iiter] != ndim) { + PyErr_SetString(PyExc_ValueError, + "Iterator input has the NO_BROADCAST flag set, but " + "has a different number of dimensions than the " + "iterator"); + return 0; + } + + axes = op_axes ? op_axes[iiter] : NULL; + axisdata = axisdata0; + for (idim = 0; idim < ndim; ++idim, axisdata += sizeof_axisdata) { + npy_intp i; + if (axes) { + i = axes[ndim-idim-1]; + } + else { + i = ndim-idim-1; + } + if (PyArray_DIM(op[iiter], i) != NAD_SHAPE(axisdata)) { + PyErr_SetString(PyExc_ValueError, + "Iterator input has the NO_BROADCAST flag set, " + "but has different dimensions than the final " + "broadcast shape of the iterator"); + return 0; + } + } + } + } + /* Now fill in the ITERSIZE member */ NIT_ITERSIZE(iter) = 1; axisdata = axisdata0; diff --git a/numpy/core/src/multiarray/new_iterator.h b/numpy/core/src/multiarray/new_iterator.h index 905642d79..1c79d57f4 100644 --- a/numpy/core/src/multiarray/new_iterator.h +++ b/numpy/core/src/multiarray/new_iterator.h @@ -138,6 +138,8 @@ NPY_NO_EXPORT void NpyIter_DebugPrint(NpyIter *iter); #define NPY_ITER_ALLOCATE 0x04000000 /* If an operand is allocated, don't use the priority subtype */ #define NPY_ITER_NO_SUBTYPE 0x08000000 +/* Disallows broadcasting of the dimensions, they must match exactly */ +#define NPY_ITER_NO_BROADCAST 0x10000000 #define NPY_ITER_GLOBAL_FLAGS 0x0000ffff #define NPY_ITER_PER_OP_FLAGS 0xffff0000 diff --git a/numpy/core/src/multiarray/new_iterator_pywrap.c b/numpy/core/src/multiarray/new_iterator_pywrap.c index 1f9cd0b53..cf2253754 100644 --- a/numpy/core/src/multiarray/new_iterator_pywrap.c +++ b/numpy/core/src/multiarray/new_iterator_pywrap.c @@ -317,6 +317,9 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds) if (strcmp(str, "no_subtype") == 0) { flag = NPY_ITER_NO_SUBTYPE; } + else if (strcmp(str, "no_broadcast") == 0) { + flag = NPY_ITER_NO_BROADCAST; + } break; } break; diff --git a/numpy/core/tests/test_new_iterator.py b/numpy/core/tests/test_new_iterator.py index c93b13539..909c2f0d2 100644 --- a/numpy/core/tests/test_new_iterator.py +++ b/numpy/core/tests/test_new_iterator.py @@ -1104,6 +1104,8 @@ def test_iter_buffering(): a.dtype = 'i4' a[:] = np.arange(16,dtype='i4') arrays.append(a) + # 4-D F-order array + arrays.append(np.arange(120,dtype='i4').reshape(5,3,2,4).T) for a in arrays: for buffersize in (1,2,3,5,8,11,16,1024): vals = [] @@ -1216,5 +1218,18 @@ def test_iter_buffering_growinner(): # Should end up with just one inner loop here assert_equal(i[0].size, a.size) +def test_iter_no_broadcast(): + # Test that the no_broadcast flag works + a = np.arange(24).reshape(2,3,4) + b = np.arange(6).reshape(2,3,1) + c = np.arange(12).reshape(3,4) + + i = np.newiter([a,b,c], [], + [['readonly','no_broadcast'],['readonly'],['readonly']]) + assert_raises(ValueError, np.newiter, [a,b,c], [], + [['readonly'],['readonly','no_broadcast'],['readonly']]) + assert_raises(ValueError, np.newiter, [a,b,c], [], + [['readonly'],['readonly'],['readonly','no_broadcast']]) + if __name__ == "__main__": run_module_suite() |