summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-07-10 10:54:25 -0500
committerCharles Harris <charlesr.harris@gmail.com>2011-07-11 09:15:31 -0600
commitd3172da58be7f3dc886c0a4c627b0b9b56efcf4d (patch)
tree82b1465d3ef151b459de743808072981e7c5b07c
parent62d5414c1b796e0e2e09b3b7446a46716f10a9ea (diff)
downloadnumpy-d3172da58be7f3dc886c0a4c627b0b9b56efcf4d.tar.gz
ENH: nditer: Add tests for writemasked iteration, also some small fixes
-rw-r--r--numpy/core/src/multiarray/dtype_transfer.c2
-rw-r--r--numpy/core/src/multiarray/nditer_api.c17
-rw-r--r--numpy/core/src/multiarray/nditer_constr.c2
-rw-r--r--numpy/core/tests/test_nditer.py104
4 files changed, 122 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/dtype_transfer.c b/numpy/core/src/multiarray/dtype_transfer.c
index 2d6bd189e..ce688efd5 100644
--- a/numpy/core/src/multiarray/dtype_transfer.c
+++ b/numpy/core/src/multiarray/dtype_transfer.c
@@ -3765,7 +3765,7 @@ PyArray_GetMaskedDTypeTransferFunction(int aligned,
/* TODO: Add struct-based mask_dtype support later */
if (mask_dtype->type_num != NPY_BOOL &&
mask_dtype->type_num != NPY_MASK) {
- PyErr_SetString(PyExc_RuntimeError,
+ PyErr_SetString(PyExc_TypeError,
"Only bool and uint8 masks are supported at the moment, "
"structs of bool/uint8 is planned for the future");
return NPY_FAIL;
diff --git a/numpy/core/src/multiarray/nditer_api.c b/numpy/core/src/multiarray/nditer_api.c
index 3d95ee859..dfc891925 100644
--- a/numpy/core/src/multiarray/nditer_api.c
+++ b/numpy/core/src/multiarray/nditer_api.c
@@ -1864,10 +1864,25 @@ npyiter_copy_from_buffers(NpyIter *iter)
(int)iop, (int)op_transfersize);
if (op_itflags[iop]&NPY_OP_ITFLAG_WRITEMASKED) {
+ npy_uint8 *maskptr;
+
+ /*
+ * The mask pointer may be in the buffer or in
+ * the array, detect which one.
+ */
+ delta = (ptrs[maskop] - buffers[maskop]);
+ if (0 <= delta &&
+ delta <= buffersize*dtypes[maskop]->elsize) {
+ maskptr = buffers[maskop];
+ }
+ else {
+ maskptr = ad_ptrs[maskop];
+ }
+
PyArray_TransferMaskedStridedToNDim(ndim_transfer,
ad_ptrs[iop], dst_strides, axisdata_incr,
buffer, src_stride,
- (npy_uint8 *)buffers[maskop], strides[maskop],
+ maskptr, strides[maskop],
dst_coords, axisdata_incr,
dst_shape, axisdata_incr,
op_transfersize, dtypes[iop]->elsize,
diff --git a/numpy/core/src/multiarray/nditer_constr.c b/numpy/core/src/multiarray/nditer_constr.c
index 1c8e3e9b6..b9f79f1d3 100644
--- a/numpy/core/src/multiarray/nditer_constr.c
+++ b/numpy/core/src/multiarray/nditer_constr.c
@@ -903,7 +903,7 @@ npyiter_check_per_op_flags(npy_uint32 op_flags, char *op_itflags)
/* Check the flag for a write masked operands */
if (op_flags & NPY_ITER_WRITEMASKED) {
- if (!(*op_itflags) & NPY_OP_ITFLAG_WRITE) {
+ if (!((*op_itflags) & NPY_OP_ITFLAG_WRITE)) {
PyErr_SetString(PyExc_ValueError,
"The iterator flag WRITEMASKED may only "
"be used with READWRITE or WRITEONLY");
diff --git a/numpy/core/tests/test_nditer.py b/numpy/core/tests/test_nditer.py
index 0a0dc2c28..cf0a44c63 100644
--- a/numpy/core/tests/test_nditer.py
+++ b/numpy/core/tests/test_nditer.py
@@ -2295,5 +2295,109 @@ def test_iter_buffering_reduction():
it.reset()
assert_equal(it[0], [1,2,1,2])
+def test_iter_writemasked_badinput():
+ a = np.zeros((2,3))
+ b = np.zeros((3,))
+ m = np.array([[True,True,False],[False,True,False]])
+ m2 = np.array([True,True,False])
+ m3 = np.array([0,1,1], dtype='u1')
+ mbad1 = np.array([0,1,1], dtype='i1')
+ mbad2 = np.array([0,1,1], dtype='f4')
+
+ # Need an 'arraymask' if any operand is 'writemasked'
+ assert_raises(ValueError, nditer, [a,m], [],
+ [['readwrite','writemasked'],['readonly']])
+
+ # A 'writemasked' operand must not be readonly
+ assert_raises(ValueError, nditer, [a,m], [],
+ [['readonly','writemasked'],['readonly','arraymask']])
+
+ # 'writemasked' and 'arraymask' may not be used together
+ assert_raises(ValueError, nditer, [a,m], [],
+ [['readonly'],['readwrite','arraymask','writemasked']])
+
+ # 'arraymask' may only be specified once
+ assert_raises(ValueError, nditer, [a,m, m2], [],
+ [['readwrite','writemasked'],
+ ['readonly','arraymask'],
+ ['readonly','arraymask']])
+
+ # An 'arraymask' with nothing 'writemasked' also doesn't make sense
+ assert_raises(ValueError, nditer, [a,m], [],
+ [['readwrite'],['readonly','arraymask']])
+
+ # A writemasked reduction requires a similarly smaller mask
+ assert_raises(ValueError, nditer, [a,b,m], ['reduce_ok'],
+ [['readonly'],
+ ['readwrite','writemasked'],
+ ['readonly','arraymask']])
+ # But this should work with a smaller/equal mask to the reduction operand
+ np.nditer([a,b,m2], ['reduce_ok'],
+ [['readonly'],
+ ['readwrite','writemasked'],
+ ['readonly','arraymask']])
+ # The arraymask itself cannot be a reduction
+ assert_raises(ValueError, nditer, [a,b,m2], ['reduce_ok'],
+ [['readonly'],
+ ['readwrite','writemasked'],
+ ['readwrite','arraymask']])
+
+ # A uint8 mask is ok too
+ np.nditer([a,m3], ['buffered'],
+ [['readwrite','writemasked'],
+ ['readonly','arraymask']],
+ op_dtypes=['f4',None],
+ casting='same_kind')
+ # An int8 mask isn't ok
+ assert_raises(TypeError, np.nditer, [a,mbad1], ['buffered'],
+ [['readwrite','writemasked'],
+ ['readonly','arraymask']],
+ op_dtypes=['f4',None],
+ casting='same_kind')
+ # A float32 mask isn't ok
+ assert_raises(TypeError, np.nditer, [a,mbad2], ['buffered'],
+ [['readwrite','writemasked'],
+ ['readonly','arraymask']],
+ op_dtypes=['f4',None],
+ casting='same_kind')
+
+def test_iter_writemasked():
+ a = np.zeros((3,), dtype='f8')
+ msk = np.array([True,True,False])
+
+ # When buffering is unused, 'writemasked' effectively does nothing.
+ # It's up to the user of the iterator to obey the requested semantics.
+ it = np.nditer([a,msk], [],
+ [['readwrite','writemasked'],
+ ['readonly','arraymask']])
+ for x, m in it:
+ x[...] = 1
+ # Because we violated the semantics, all the values became 1
+ assert_equal(a, [1,1,1])
+
+ # Even if buffering is enabled, we still may be accessing the array
+ # directly.
+ it = np.nditer([a,msk], ['buffered'],
+ [['readwrite','writemasked'],
+ ['readonly','arraymask']])
+ for x, m in it:
+ x[...] = 2.5
+ # Because we violated the semantics, all the values became 2.5
+ assert_equal(a, [2.5,2.5,2.5])
+
+ # If buffering will definitely happening, for instance because of
+ # a cast, only the items selected by the mask will be copied back from
+ # the buffer.
+ it = np.nditer([a,msk], ['buffered'],
+ [['readwrite','writemasked'],
+ ['readonly','arraymask']],
+ op_dtypes=['i8',None],
+ casting='unsafe')
+ for x, m in it:
+ x[...] = 3
+ # Even though we violated the semantics, only the selected values
+ # were copied back
+ assert_equal(a, [3,3,2.5])
+
if __name__ == "__main__":
run_module_suite()