diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 49 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 15 |
2 files changed, 54 insertions, 10 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index bea15c65c..7ec5b977f 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -753,7 +753,7 @@ static int get_ufunc_arguments(PyUFuncObject *ufunc, /* Get input arguments */ for (i = 0; i < nin; ++i) { obj = PyTuple_GET_ITEM(args, i); - + if (PyArray_Check(obj)) { out_op[i] = (PyArrayObject *)PyArray_FromArray(obj,NULL,0); } @@ -1236,10 +1236,10 @@ iterator_loop(PyUFuncObject *ufunc, for (i = 0; i < nin; ++i) { op_flags[i] = NPY_ITER_READONLY | NPY_ITER_ALIGNED; - /* + /* * If READWRITE flag has been set for this operand, * then clear default READONLY flag - */ + */ op_flags[i] |= ufunc->op_flags[i]; if (op_flags[i] & (NPY_ITER_READWRITE | NPY_ITER_WRITEONLY)) { op_flags[i] &= ~NPY_ITER_READONLY; @@ -1541,10 +1541,10 @@ execute_fancy_ufunc_loop(PyUFuncObject *ufunc, op_flags[i] = default_op_in_flags | NPY_ITER_READONLY | NPY_ITER_ALIGNED; - /* + /* * If READWRITE flag has been set for this operand, * then clear default READONLY flag - */ + */ op_flags[i] |= ufunc->op_flags[i]; if (op_flags[i] & (NPY_ITER_READWRITE | NPY_ITER_WRITEONLY)) { op_flags[i] &= ~NPY_ITER_READONLY; @@ -2037,10 +2037,10 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc, op_flags[i] = NPY_ITER_READONLY | NPY_ITER_COPY | NPY_ITER_ALIGNED; - /* + /* * If READWRITE flag has been set for this operand, * then clear default READONLY flag - */ + */ op_flags[i] |= ufunc->op_flags[i]; if (op_flags[i] & (NPY_ITER_READWRITE | NPY_ITER_WRITEONLY)) { op_flags[i] &= ~NPY_ITER_READONLY; @@ -3337,12 +3337,40 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind, op[1] = arr; op[2] = ind; - /* Likewise with accumulate, must do UPDATEIFCOPY */ if (out != NULL || ndim > 1 || !PyArray_ISALIGNED(arr) || !PyArray_EquivTypes(op_dtypes[0], PyArray_DESCR(arr))) { need_outer_iterator = 1; } + /* Special case when the index array's size is zero */ + if (ind_size == 0) { + if (out == NULL) { + npy_intp out_shape[NPY_MAXDIMS]; + memcpy(out_shape, PyArray_SHAPE(arr), + PyArray_NDIM(arr) * NPY_SIZEOF_INTP); + out_shape[axis] = 0; + Py_INCREF(op_dtypes[0]); + op[0] = out = (PyArrayObject *)PyArray_NewFromDescr( + &PyArray_Type, op_dtypes[0], + PyArray_NDIM(arr), out_shape, NULL, NULL, + 0, NULL); + if (out == NULL) { + goto fail; + } + } + else { + /* Allow any zero-sized output array in this case */ + if (PyArray_SIZE(out) != 0) { + PyErr_SetString(PyExc_ValueError, + "output operand shape for reduceat is " + "incompatible with index array of shape (0,)"); + goto fail; + } + } + + goto finish; + } + if (need_outer_iterator) { npy_uint32 flags = NPY_ITER_ZEROSIZE_OK| NPY_ITER_REFS_OK| @@ -3350,7 +3378,8 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind, /* * The way reduceat is set up, we can't do buffering, - * so make a copy instead when necessary. + * so make a copy instead when necessary using + * the UPDATEIFCOPY flag */ /* The per-operand flags for the outer loop */ @@ -4488,7 +4517,7 @@ PyUFunc_RegisterLoopForDescr(PyUFuncObject *ufunc, arg_typenums[i] = user_dtype->type_num; } } - + result = PyUFunc_RegisterLoopForType(ufunc, user_dtype->type_num, function, arg_typenums, data); diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index c2304a748..b2d47d052 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -1263,6 +1263,21 @@ def test_reduceat(): np.setbufsize(np.UFUNC_BUFSIZE_DEFAULT) assert_array_almost_equal(h1, h2) +def test_reduceat_empty(): + """Reduceat should work with empty arrays""" + indices = np.array([], 'i4') + x = np.array([], 'f8') + result = np.add.reduceat(x, indices) + assert_equal(result.dtype, x.dtype) + assert_equal(result.shape, (0,)) + # Another case with a slightly different zero-sized shape + x = np.ones((5,2)) + result = np.add.reduceat(x, [], axis=0) + assert_equal(result.dtype, x.dtype) + assert_equal(result.shape, (0, 2)) + result = np.add.reduceat(x, [], axis=1) + assert_equal(result.dtype, x.dtype) + assert_equal(result.shape, (5, 0)) def test_complex_nan_comparisons(): nans = [complex(np.nan, 0), complex(0, np.nan), complex(np.nan, np.nan)] |