diff options
author | Mark Wiebe <mwiebe@enthought.com> | 2011-06-29 11:13:41 -0500 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-07-06 16:24:12 -0600 |
commit | f308fcbbe2e6a0840804716822a4315ddefa45a2 (patch) | |
tree | 898ca2004d2f5e3a444dc0f2790ce0915f5042ca /numpy | |
parent | d9aaf6f66bf1777d9bd82304a123d90cfd653f4d (diff) | |
download | numpy-f308fcbbe2e6a0840804716822a4315ddefa45a2.tar.gz |
ENH: umath: Add tests, work out kinks in ufunc 'where=' parameter
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 17 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 14 |
2 files changed, 29 insertions, 2 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index cca8c9525..a89c0f235 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -1467,6 +1467,7 @@ execute_ufunc_masked_loop(PyUFuncObject *self, return -1; } op[nop] = wheremask; + dtype[nop] = NULL; } /* Set up the flags */ @@ -1483,6 +1484,8 @@ execute_ufunc_masked_loop(PyUFuncObject *self, } op_flags[nop] = NPY_ITER_READONLY; + NPY_UF_DBG_PRINT("Making iterator\n"); + /* * Allocate the iterator. Because the types of the inputs * were already checked, we use the casting rule 'unsafe' which @@ -1502,6 +1505,8 @@ execute_ufunc_masked_loop(PyUFuncObject *self, return -1; } + NPY_UF_DBG_PRINT("Made iterator\n"); + needs_api = NpyIter_IterationNeedsAPI(iter); /* Copy any allocated outputs */ @@ -1532,10 +1537,15 @@ execute_ufunc_masked_loop(PyUFuncObject *self, for (i = nin; i < nop; ++i) { baseptrs[i] = PyArray_BYTES(op[i]); } + if (wheremask != NULL) { + baseptrs[nop] = PyArray_BYTES(op[nop]); + } + NPY_UF_DBG_PRINT("reset base pointers call:\n"); if (NpyIter_ResetBasePointers(iter, baseptrs, NULL) != NPY_SUCCEED) { NpyIter_Deallocate(iter); return -1; } + NPY_UF_DBG_PRINT("finished reset base pointers call\n"); /* Get the variables needed for the loop */ iternext = NpyIter_GetIterNext(iter, NULL); @@ -1551,6 +1561,7 @@ execute_ufunc_masked_loop(PyUFuncObject *self, NPY_BEGIN_THREADS; } + NPY_UF_DBG_PRINT("Actual inner loop:\n"); /* Execute the loop */ do { NPY_UF_DBG_PRINT1("iterator loop count %d\n", (int)*count_ptr); @@ -2205,16 +2216,18 @@ PyUFunc_GenericFunction(PyUFuncObject *self, /* Start with the floating-point exception flags cleared */ PyUFunc_clearfperr(); - NPY_UF_DBG_PRINT("Executing inner loop\n"); - /* Do the ufunc loop */ if (usemaskedloop) { + NPY_UF_DBG_PRINT("Executing masked inner loop\n"); + retval = execute_ufunc_masked_loop(self, wheremask, op, dtype, order, buffersize, arr_prep, arr_prep_args, masked_innerloop, masked_innerloopdata); } else { + NPY_UF_DBG_PRINT("Executing unmasked inner loop\n"); + retval = execute_ufunc_loop(self, trivial_loop_ok, op, dtype, order, buffersize, arr_prep, arr_prep_args, diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index 3f6c8e0d2..cc5a79772 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -484,6 +484,20 @@ class TestUfunc(TestCase): np.subtract(a, 0, out=b) assert_equal(b, 0) + def test_where_param(self): + # Test that the where= ufunc parameter works with regular arrays + a = np.arange(7) + b = np.ones(7) + c = np.zeros(7) + np.add(a, b, out=c, where=(a % 2 == 1)) + assert_equal(c, [0,2,0,4,0,6,0]) + + a = np.arange(4).reshape(2,2) + 2 + np.power(a, [2,3], out=a, where=[[0,1],[1,0]]) + assert_equal(a, [[2, 27], [16, 5]]) + # Broadcasting the where= parameter + np.subtract(a, 2, out=a, where=[True,False]) + assert_equal(a, [[0, 27], [14, 5]]) if __name__ == "__main__": run_module_suite() |