summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/umath/ufunc_object.c62
-rw-r--r--numpy/core/tests/test_maskna.py47
2 files changed, 102 insertions, 7 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 96bf925ca..3586e2ac2 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -712,12 +712,13 @@ static int get_ufunc_arguments(PyUFuncObject *self,
PyArrayObject **out_wheremask,
int *out_use_maskna)
{
- int i, nargs, nin = self->nin;
+ int i, nargs, nin = self->nin, nout = self->nout;
PyObject *obj, *context;
PyObject *str_key_obj = NULL;
char *ufunc_name;
int any_flexible = 0, any_object = 0;
+ int any_non_maskna_out = 0, any_maskna_out = 0;
ufunc_name = self->name ? self->name : "<unnamed ufunc>";
@@ -800,6 +801,13 @@ static int get_ufunc_arguments(PyUFuncObject *self,
}
Py_INCREF(obj);
out_op[i] = (PyArrayObject *)obj;
+
+ if (PyArray_HASMASKNA((PyArrayObject *)obj)) {
+ any_maskna_out = 1;
+ }
+ else {
+ any_non_maskna_out = 1;
+ }
}
else {
PyErr_SetString(PyExc_TypeError,
@@ -892,6 +900,13 @@ static int get_ufunc_arguments(PyUFuncObject *self,
}
Py_INCREF(value);
out_op[nin] = (PyArrayObject *)value;
+
+ if (PyArray_HASMASKNA((PyArrayObject *)value)) {
+ any_maskna_out = 1;
+ }
+ else {
+ any_non_maskna_out = 1;
+ }
}
else {
PyErr_SetString(PyExc_TypeError,
@@ -961,8 +976,42 @@ static int get_ufunc_arguments(PyUFuncObject *self,
}
}
}
-
Py_XDECREF(str_key_obj);
+
+ /*
+ * If NA mask support is enabled and there are non-maskNA outputs,
+ * only proceed if all the inputs contain no NA values.
+ */
+ if (*out_use_maskna && any_non_maskna_out) {
+ /* Check all the inputs for NA */
+ for(i = 0; i < nin; ++i) {
+ if (PyArray_HASMASKNA(out_op[i])) {
+ if (PyArray_ContainsNA(out_op[i])) {
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot assign NA value to an array which "
+ "does not support NAs");
+ return -1;
+ }
+ }
+ }
+
+ /* Disable MASKNA - the inner loop uses NPY_ITER_IGNORE_MASKNA */
+ *out_use_maskna = 0;
+ }
+ /*
+ * If we're not using a masked loop, but an output has an NA mask,
+ * set it to all exposed.
+ */
+ else if (!(*out_use_maskna) && any_maskna_out) {
+ for (i = nin; i < nin+nout; ++i) {
+ if (PyArray_HASMASKNA(out_op[i])) {
+ if (PyArray_AssignMaskNA(out_op[i], 1) < 0) {
+ return -1;
+ }
+ }
+ }
+ }
+
return 0;
fail:
@@ -1507,6 +1556,15 @@ execute_ufunc_masked_loop(PyUFuncObject *self,
default_op_in_flags |= NPY_ITER_USE_MASKNA;
default_op_out_flags |= NPY_ITER_USE_MASKNA;
}
+ /*
+ * Some operands may still have NA masks, but they will
+ * have been checked to ensure they have no NAs using
+ * PyArray_ContainsNA. Thus we flag to ignore MASKNA here.
+ */
+ else {
+ default_op_in_flags |= NPY_ITER_IGNORE_MASKNA;
+ default_op_out_flags |= NPY_ITER_IGNORE_MASKNA;
+ }
/* Set up the flags */
for (i = 0; i < nin; ++i) {
diff --git a/numpy/core/tests/test_maskna.py b/numpy/core/tests/test_maskna.py
index 6c95de4db..419758242 100644
--- a/numpy/core/tests/test_maskna.py
+++ b/numpy/core/tests/test_maskna.py
@@ -576,13 +576,50 @@ def test_maskna_take_1D():
assert_equal(np.isna(c), [0,1,0])
def test_maskna_ufunc_1D():
- a = np.arange(3, maskna=True)
- b = np.arange(3)
+ a_orig = np.arange(3)
+ a = a_orig.view(maskna=True)
+ b_orig = np.array([5,4,3])
+ b = b_orig.view(maskna=True)
+ c_orig = np.array([0,0,0])
+ c = c_orig.view(maskna=True)
# An NA mask is produced if an operand has one
- c = a + b
- assert_(c.flags.maskna)
- #assert_equal(c, [0,2,4])
+ res = a + b_orig
+ assert_(res.flags.maskna)
+ assert_equal(res, [5,5,5])
+
+ res = b_orig + a
+ assert_(res.flags.maskna)
+ assert_equal(res, [5,5,5])
+
+ # Can still output to a non-NA array if there are no NAs
+ np.add(a, b, out=c_orig)
+ assert_equal(c_orig, [5,5,5])
+
+ # Should unmask everything if the output has NA support but
+ # the inputs don't
+ c_orig[...] = 0
+ c[...] = np.NA
+ np.add(a_orig, b_orig, out=c)
+ assert_equal(c, [5,5,5])
+
+ # If the input has NA support but an output parameter doesn't,
+ # should work as long as the inputs contain no NAs
+ c_orig[...] = 0
+ np.add(a, b, out=c_orig)
+ assert_equal(c_orig, [5,5,5])
+
+ # An NA is produced if either operand has one
+ a[0] = np.NA
+ b[1] = np.NA
+ res = a + b
+ assert_equal(np.isna(res), [1,1,0])
+ assert_equal(res[2], 5)
+
+ # If the output contains NA, can't have out= parameter without
+ # NA support
+ assert_raises(ValueError, np.add, a, b, out=c_orig)
+
def test_maskna_ufunc_sum_1D():
check_maskna_ufunc_sum_1D(np.sum)