summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2011-08-22 11:44:47 -0700
committerCharles Harris <charlesr.harris@gmail.com>2011-08-27 07:26:59 -0600
commit447d55d17136b8516a2ce49edae9ec82f0b00046 (patch)
treed851c121af898075b499b69d5c4bd871c758bbe9 /numpy
parentda2c9e4fa05b2df1062af519c7880286ab8d20c9 (diff)
downloadnumpy-447d55d17136b8516a2ce49edae9ec82f0b00046.tar.gz
ENH: ufunc: Add a mask dtype parameter to the masked ufunc loop selector
This is to allow for future expansion to multi-NA and struct-NA.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/include/numpy/ufuncobject.h1
-rw-r--r--numpy/core/src/umath/ufunc_object.c6
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.c8
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.h1
4 files changed, 15 insertions, 1 deletions
diff --git a/numpy/core/include/numpy/ufuncobject.h b/numpy/core/include/numpy/ufuncobject.h
index 88198a449..dab5fd6e4 100644
--- a/numpy/core/include/numpy/ufuncobject.h
+++ b/numpy/core/include/numpy/ufuncobject.h
@@ -119,6 +119,7 @@ typedef int (PyUFunc_InnerLoopSelectionFunc)(
typedef int (PyUFunc_MaskedInnerLoopSelectionFunc)(
struct _tagPyUFuncObject *ufunc,
PyArray_Descr **dtypes,
+ PyArray_Descr *mask_dtype,
npy_intp *fixed_strides,
npy_intp fixed_mask_stride,
PyUFunc_MaskedStridedInnerLoopFunc **out_innerloop,
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index b0ebbf9b0..a9a3bc8b8 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1640,6 +1640,7 @@ execute_ufunc_masked_loop(PyUFuncObject *ufunc,
PyUFunc_MaskedStridedInnerLoopFunc *innerloop;
NpyAuxData *innerloopdata;
npy_intp fixed_strides[2*NPY_MAXARGS];
+ PyArray_Descr **iter_dtypes;
/* Validate that the prepare_ufunc_output didn't mess with pointers */
for (i = nin; i < nop; ++i) {
@@ -1657,7 +1658,10 @@ execute_ufunc_masked_loop(PyUFuncObject *ufunc,
* based on the fixed strides.
*/
NpyIter_GetInnerFixedStrideArray(iter, fixed_strides);
+ iter_dtypes = NpyIter_GetDescrArray(iter);
if (ufunc->masked_inner_loop_selector(ufunc, dtypes,
+ wheremask != NULL ? iter_dtypes[nop]
+ : iter_dtypes[nop + nin],
fixed_strides,
wheremask != NULL ? fixed_strides[nop]
: fixed_strides[nop + nin],
@@ -2686,7 +2690,7 @@ masked_reduce_loop(NpyIter *iter, char **dataptrs, npy_intp *strides,
dtypes[0] = iter_dtypes[0];
dtypes[1] = iter_dtypes[1];
dtypes[2] = iter_dtypes[0];
- if (ufunc->masked_inner_loop_selector(ufunc, dtypes,
+ if (ufunc->masked_inner_loop_selector(ufunc, dtypes, iter_dtypes[2],
fixed_strides, fixed_mask_stride,
&innerloop, &innerloopdata, &needs_api) < 0) {
return -1;
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index cfb12fb00..77059a19a 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -1392,6 +1392,7 @@ unmasked_ufunc_loop_as_masked(
NPY_NO_EXPORT int
PyUFunc_DefaultMaskedInnerLoopSelector(PyUFuncObject *ufunc,
PyArray_Descr **dtypes,
+ PyArray_Descr *mask_dtype,
npy_intp *NPY_UNUSED(fixed_strides),
npy_intp NPY_UNUSED(fixed_mask_stride),
PyUFunc_MaskedStridedInnerLoopFunc **out_innerloop,
@@ -1409,6 +1410,13 @@ PyUFunc_DefaultMaskedInnerLoopSelector(PyUFuncObject *ufunc,
return -1;
}
+ if (mask_dtype->type_num != NPY_BOOL) {
+ PyErr_SetString(PyExc_ValueError,
+ "only boolean masks are supported in ufunc inner loops "
+ "presently");
+ return -1;
+ }
+
/* Create a new NpyAuxData object for the masker data */
data = (_ufunc_masker_data *)PyArray_malloc(sizeof(_ufunc_masker_data));
if (data == NULL) {
diff --git a/numpy/core/src/umath/ufunc_type_resolution.h b/numpy/core/src/umath/ufunc_type_resolution.h
index 8effa33a4..a1241827e 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.h
+++ b/numpy/core/src/umath/ufunc_type_resolution.h
@@ -102,6 +102,7 @@ PyUFunc_DefaultLegacyInnerLoopSelector(PyUFuncObject *ufunc,
NPY_NO_EXPORT int
PyUFunc_DefaultMaskedInnerLoopSelector(PyUFuncObject *ufunc,
PyArray_Descr **dtypes,
+ PyArray_Descr *mask_dtypes,
npy_intp *NPY_UNUSED(fixed_strides),
npy_intp NPY_UNUSED(fixed_mask_stride),
PyUFunc_MaskedStridedInnerLoopFunc **out_innerloop,