summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwiebe@enthought.com>2011-07-08 11:51:41 -0500
committerCharles Harris <charlesr.harris@gmail.com>2011-07-08 19:38:24 -0600
commit5f03b1504bcbe31b611376b6651e0297db165bad (patch)
tree420bd74507dc9b5625c391a4bf8a6b0410a5a137
parent9910b0dbdf999e5a0d1a94d1134071612e126f06 (diff)
downloadnumpy-5f03b1504bcbe31b611376b6651e0297db165bad.tar.gz
ENH: core: Add support for masked strided transfer functions
This implementation has no optimization whatsoever in it yet, it just wraps the unmasked strided transfer functions. It also does not handle struct masks yet.
-rw-r--r--numpy/core/src/multiarray/dtype_transfer.c221
-rw-r--r--numpy/core/src/multiarray/nditer_constr.c4
-rw-r--r--numpy/core/src/private/lowlevel_strided_loops.h42
-rw-r--r--numpy/core/src/umath/ufunc_type_resolution.c12
4 files changed, 271 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/dtype_transfer.c b/numpy/core/src/multiarray/dtype_transfer.c
index 7609f0016..47182b82a 100644
--- a/numpy/core/src/multiarray/dtype_transfer.c
+++ b/numpy/core/src/multiarray/dtype_transfer.c
@@ -2999,6 +2999,149 @@ get_setdestzero_fields_transfer_function(int aligned,
return NPY_SUCCEED;
}
+/************************* MASKED TRANSFER WRAPPER *************************/
+
+typedef struct {
+ NpyAuxData base;
+ /* The transfer function being wrapped */
+ PyArray_StridedTransferFn *stransfer;
+ NpyAuxData *transferdata;
+
+ /* The src decref function if necessary */
+ PyArray_StridedTransferFn *decsrcref_stransfer;
+ NpyAuxData *decsrcref_transferdata;
+} _masked_wrapper_transfer_data;
+
+/* transfer data free function */
+void _masked_wrapper_transfer_data_free(NpyAuxData *data)
+{
+ _masked_wrapper_transfer_data *d = (_masked_wrapper_transfer_data *)data;
+ NPY_AUXDATA_FREE(d->transferdata);
+ NPY_AUXDATA_FREE(d->decsrcref_transferdata);
+ PyArray_free(data);
+}
+
+/* transfer data copy function */
+NpyAuxData *_masked_wrapper_transfer_data_clone(NpyAuxData *data)
+{
+ _masked_wrapper_transfer_data *d = (_masked_wrapper_transfer_data *)data;
+ _masked_wrapper_transfer_data *newdata;
+
+ /* Allocate the data and populate it */
+ newdata = (_masked_wrapper_transfer_data *)PyArray_malloc(
+ sizeof(_masked_wrapper_transfer_data));
+ if (newdata == NULL) {
+ return NULL;
+ }
+ memcpy(newdata, d, sizeof(_masked_wrapper_transfer_data));
+
+ /* Clone all the owned auxdata as well */
+ if (newdata->transferdata != NULL) {
+ newdata->transferdata = NPY_AUXDATA_CLONE(newdata->transferdata);
+ if (newdata->transferdata == NULL) {
+ PyArray_free(newdata);
+ return NULL;
+ }
+ }
+ if (newdata->decsrcref_transferdata != NULL) {
+ newdata->decsrcref_transferdata =
+ NPY_AUXDATA_CLONE(newdata->decsrcref_transferdata);
+ if (newdata->decsrcref_transferdata == NULL) {
+ NPY_AUXDATA_FREE(newdata->transferdata);
+ PyArray_free(newdata);
+ return NULL;
+ }
+ }
+
+ return (NpyAuxData *)newdata;
+}
+
+void _strided_masked_wrapper_decsrcref_transfer_function(
+ char *dst, npy_intp dst_stride,
+ char *src, npy_intp src_stride,
+ npy_uint8 *mask, npy_intp mask_stride,
+ npy_intp N, npy_intp src_itemsize,
+ NpyAuxData *transferdata)
+{
+ _masked_wrapper_transfer_data *d =
+ (_masked_wrapper_transfer_data *)transferdata;
+ npy_intp subloopsize;
+ PyArray_StridedTransferFn *unmasked_stransfer, *decsrcref_stransfer;
+ NpyAuxData *unmasked_transferdata, *decsrcref_transferdata;
+
+ unmasked_stransfer = d->stransfer;
+ unmasked_transferdata = d->transferdata;
+ decsrcref_stransfer = d->decsrcref_stransfer;
+ decsrcref_transferdata = d->decsrcref_transferdata;
+
+ while (N > 0) {
+ /* Skip masked values, still calling decsrcref for move_references */
+ subloopsize = 0;
+ while (subloopsize < N && ((*mask)&0x01) == 0) {
+ ++subloopsize;
+ mask += mask_stride;
+ }
+ decsrcref_stransfer(NULL, 0, src, src_stride,
+ subloopsize, src_itemsize, decsrcref_transferdata);
+ dst += subloopsize * dst_stride;
+ src += subloopsize * src_stride;
+ N -= subloopsize;
+ /* Process unmasked values */
+ subloopsize = 0;
+ while (subloopsize < N && ((*mask)&0x01) != 0) {
+ ++subloopsize;
+ mask += mask_stride;
+ }
+ unmasked_stransfer(dst, dst_stride, src, src_stride,
+ subloopsize, src_itemsize, unmasked_transferdata);
+ dst += subloopsize * dst_stride;
+ src += subloopsize * src_stride;
+ N -= subloopsize;
+ }
+}
+
+void _strided_masked_wrapper_transfer_function(
+ char *dst, npy_intp dst_stride,
+ char *src, npy_intp src_stride,
+ npy_uint8 *mask, npy_intp mask_stride,
+ npy_intp N, npy_intp src_itemsize,
+ NpyAuxData *transferdata)
+{
+
+ _masked_wrapper_transfer_data *d =
+ (_masked_wrapper_transfer_data *)transferdata;
+ npy_intp subloopsize;
+ PyArray_StridedTransferFn *unmasked_stransfer;
+ NpyAuxData *unmasked_transferdata;
+
+ unmasked_stransfer = d->stransfer;
+ unmasked_transferdata = d->transferdata;
+
+ while (N > 0) {
+ /* Skip masked values */
+ subloopsize = 0;
+ while (subloopsize < N && ((*mask)&0x01) == 0) {
+ ++subloopsize;
+ mask += mask_stride;
+ }
+ dst += subloopsize * dst_stride;
+ src += subloopsize * src_stride;
+ N -= subloopsize;
+ /* Process unmasked values */
+ subloopsize = 0;
+ while (subloopsize < N && ((*mask)&0x01) != 0) {
+ ++subloopsize;
+ mask += mask_stride;
+ }
+ unmasked_stransfer(dst, dst_stride, src, src_stride,
+ subloopsize, src_itemsize, unmasked_transferdata);
+ dst += subloopsize * dst_stride;
+ src += subloopsize * src_stride;
+ N -= subloopsize;
+ }
+}
+
+
/************************* DEST BOOL SETONE *******************************/
static void
@@ -3603,6 +3746,84 @@ PyArray_GetDTypeTransferFunction(int aligned,
}
NPY_NO_EXPORT int
+PyArray_GetMaskedDTypeTransferFunction(int aligned,
+ npy_intp src_stride,
+ npy_intp dst_stride,
+ npy_intp mask_stride,
+ PyArray_Descr *src_dtype,
+ PyArray_Descr *dst_dtype,
+ PyArray_Descr *mask_dtype,
+ int move_references,
+ PyArray_MaskedStridedTransferFn **out_stransfer,
+ NpyAuxData **out_transferdata,
+ int *out_needs_api)
+{
+ PyArray_StridedTransferFn *stransfer = NULL;
+ NpyAuxData *transferdata = NULL;
+ _masked_wrapper_transfer_data *data;
+
+ /* TODO: Add struct-based mask_dtype support later */
+ if (mask_dtype->type_num != NPY_BOOL &&
+ mask_dtype->type_num != NPY_UINT8) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Only bool and uint8 masks are supported at the moment, "
+ "structs of bool/uint8 is planned for the future");
+ return NPY_FAIL;
+ }
+
+ /* TODO: Special case some important cases so they're fast */
+
+ /* Fall back to wrapping a non-masked transfer function */
+ if (PyArray_GetDTypeTransferFunction(aligned,
+ src_stride, dst_stride,
+ src_dtype, dst_dtype,
+ move_references,
+ &stransfer, &transferdata,
+ out_needs_api) != NPY_SUCCEED) {
+ return NPY_FAIL;
+ }
+
+ /* Create the wrapper function's auxdata */
+ data = (_masked_wrapper_transfer_data *)PyArray_malloc(
+ sizeof(_masked_wrapper_transfer_data));
+ if (data == NULL) {
+ PyErr_NoMemory();
+ NPY_AUXDATA_FREE(transferdata);
+ return NPY_FAIL;
+ }
+
+ /* Fill in the auxdata object */
+ memset(data, 0, sizeof(_masked_wrapper_transfer_data));
+ data->base.free = &_masked_wrapper_transfer_data_free;
+ data->base.clone = &_masked_wrapper_transfer_data_clone;
+
+ data->stransfer = stransfer;
+ data->transferdata = transferdata;
+
+ /* If the src object will need a DECREF, get a function to handle that */
+ if (move_references && PyDataType_REFCHK(src_dtype)) {
+ if (get_decsrcref_transfer_function(aligned,
+ src_stride,
+ src_dtype,
+ &data->decsrcref_stransfer,
+ &data->decsrcref_transferdata,
+ out_needs_api) != NPY_SUCCEED) {
+ NPY_AUXDATA_FREE((NpyAuxData *)data);
+ return NPY_FAIL;
+ }
+
+ *out_stransfer = &_strided_masked_wrapper_decsrcref_transfer_function;
+ }
+ else {
+ *out_stransfer = &_strided_masked_wrapper_transfer_function;
+ }
+
+ *out_transferdata = (NpyAuxData *)data;
+
+ return NPY_SUCCEED;
+}
+
+NPY_NO_EXPORT int
PyArray_CastRawArrays(npy_intp count,
char *src, char *dst,
npy_intp src_stride, npy_intp dst_stride,
diff --git a/numpy/core/src/multiarray/nditer_constr.c b/numpy/core/src/multiarray/nditer_constr.c
index 94baafbdb..42720adba 100644
--- a/numpy/core/src/multiarray/nditer_constr.c
+++ b/numpy/core/src/multiarray/nditer_constr.c
@@ -987,10 +987,10 @@ npyiter_prepare_one_operand(PyArrayObject **op,
*op_dtype = NULL;
}
- /* Specify uint8 if no dtype was requested for the mask */
+ /* Specify bool if no dtype was requested for the mask */
if (op_flags&NPY_ITER_ARRAYMASK) {
if (*op_dtype == NULL) {
- *op_dtype = PyArray_DescrFromType(NPY_UINT8);
+ *op_dtype = PyArray_DescrFromType(NPY_BOOL);
if (*op_dtype == NULL) {
return 0;
}
diff --git a/numpy/core/src/private/lowlevel_strided_loops.h b/numpy/core/src/private/lowlevel_strided_loops.h
index a1f183e50..b4cd79f9a 100644
--- a/numpy/core/src/private/lowlevel_strided_loops.h
+++ b/numpy/core/src/private/lowlevel_strided_loops.h
@@ -27,6 +27,20 @@ typedef void (PyArray_StridedTransferFn)(char *dst, npy_intp dst_stride,
NpyAuxData *transferdata);
/*
+ * This is for pointers to functions which behave exactly as
+ * for PyArray_StridedTransferFn, but with an additional mask controlling
+ * which values are transferred.
+ *
+ * In particular, the 'i'-th element is transfered if and only if
+ * (((mask[i*mask_stride])&0x01) == 0x01).
+ */
+typedef void (PyArray_MaskedStridedTransferFn)(char *dst, npy_intp dst_stride,
+ char *src, npy_intp src_stride,
+ npy_uint8 *mask, npy_intp mask_stride,
+ npy_intp N, npy_intp src_itemsize,
+ NpyAuxData *transferdata);
+
+/*
* Gives back a function pointer to a specialized function for copying
* strided memory. Returns NULL if there is a problem with the inputs.
*
@@ -174,6 +188,34 @@ PyArray_GetDTypeTransferFunction(int aligned,
int *out_needs_api);
/*
+ * This is identical to PyArray_GetDTypeTransferFunction, but
+ * returns a transfer function which also takes a mask as a parameter.
+ * Bit zero of the mask is used to determine which values to copy,
+ * data is transfered exactly when ((mask[i])&0x01) == 0x01.
+ *
+ * If move_references is true, values which are not copied to the
+ * destination will still have their source reference decremented.
+ *
+ * If mask_dtype is NPY_BOOL or NPY_UINT8, each full element is either
+ * transferred or not according to the mask as described above. If
+ * dst_dtype and mask_dtype are both struct dtypes, their names must
+ * match exactly, and the dtype of each leaf field in mask_dtype must
+ * be either NPY_BOOL or NPY_UINT8.
+ */
+NPY_NO_EXPORT int
+PyArray_GetMaskedDTypeTransferFunction(int aligned,
+ npy_intp src_stride,
+ npy_intp dst_stride,
+ npy_intp mask_stride,
+ PyArray_Descr *src_dtype,
+ PyArray_Descr *dst_dtype,
+ PyArray_Descr *mask_dtype,
+ int move_references,
+ PyArray_MaskedStridedTransferFn **out_stransfer,
+ NpyAuxData **out_transferdata,
+ int *out_needs_api);
+
+/*
* Casts the specified number of elements from 'src' with data type
* 'src_dtype' to 'dst' with 'dst_dtype'. See
* PyArray_GetDTypeTransferFunction for more details.
diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c
index 8eb1f8ddf..5fb055b28 100644
--- a/numpy/core/src/umath/ufunc_type_resolution.c
+++ b/numpy/core/src/umath/ufunc_type_resolution.c
@@ -1391,7 +1391,7 @@ unmasked_ufunc_loop_as_masked(
void *unmasked_innerloopdata;
npy_intp loopsize, subloopsize;
char *mask;
- npy_intp maskstep;
+ npy_intp mask_stride;
/* Put the aux data into local variables */
data = (_ufunc_masker_data *)innerloopdata;
@@ -1400,16 +1400,16 @@ unmasked_ufunc_loop_as_masked(
nargs = data->nargs;
loopsize = *dimensions;
mask = args[nargs];
- maskstep = steps[nargs];
+ mask_stride = steps[nargs];
/* Process the data as runs of unmasked values */
do {
/* Skip masked values */
subloopsize = 0;
- while (subloopsize < loopsize && *(npy_bool *)mask == 0) {
+ while (subloopsize < loopsize && (*(npy_uint8 *)mask)&0x01 == 0) {
++subloopsize;
- mask += maskstep;
+ mask += mask_stride;
}
for (iargs = 0; iargs < nargs; ++iargs) {
args[iargs] += subloopsize * steps[iargs];
@@ -1420,9 +1420,9 @@ unmasked_ufunc_loop_as_masked(
* mess with the 'args' pointer values)
*/
subloopsize = 0;
- while (subloopsize < loopsize && *(npy_bool *)mask != 0) {
+ while (subloopsize < loopsize && (*(npy_uint8 *)mask)&0x01 != 0) {
++subloopsize;
- mask += maskstep;
+ mask += mask_stride;
}
unmasked_innerloop(args, &subloopsize, steps, unmasked_innerloopdata);
for (iargs = 0; iargs < nargs; ++iargs) {