diff options
-rw-r--r-- | numpy/core/src/multiarray/lowlevel_strided_loops.h | 98 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 1 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 71 |
3 files changed, 146 insertions, 24 deletions
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.h b/numpy/core/src/multiarray/lowlevel_strided_loops.h
index 34cdf3c39..e0745c4ea 100644
--- a/numpy/core/src/multiarray/lowlevel_strided_loops.h
+++ b/numpy/core/src/multiarray/lowlevel_strided_loops.h
@@ -286,6 +286,20 @@ PyArray_TransferStridedToNDim(npy_intp ndim,
  *              // Create iterator, etc...
  *          }
  */
+
+/*
+ * Compares ndim, dims and the C/Fortran contiguity flags of two arrays.
+ * Only valid when one of arr1 or arr2 is known to be trivially iterable.
+ */
+#define PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2) ( \
+                        PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \
+                        PyArray_CompareLists(PyArray_DIMS(arr1), \
+                                             PyArray_DIMS(arr2), \
+                                             PyArray_NDIM(arr1)) && \
+                        ((arr1)->flags&(NPY_CONTIGUOUS|NPY_FORTRAN)) == \
+                                ((arr2)->flags&(NPY_CONTIGUOUS|NPY_FORTRAN)) \
+                        )
+
 #define PyArray_TRIVIALLY_ITERABLE(arr) ( \
                     PyArray_NDIM(arr) <= 1 || \
                     PyArray_CHKFLAGS(arr, NPY_CONTIGUOUS) || \
@@ -302,30 +316,82 @@ PyArray_TransferStridedToNDim(npy_intp ndim,
 
 #define PyArray_TRIVIALLY_ITERABLE_PAIR(arr1, arr2) (\
                     PyArray_TRIVIALLY_ITERABLE(arr1) && \
-                    PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \
-                    PyArray_CompareLists(PyArray_DIMS(arr1), \
-                                         PyArray_DIMS(arr2), \
-                                         PyArray_NDIM(arr1)) && \
-                    PyArray_CHKFLAGS(arr1, NPY_CONTIGUOUS) == \
-                            PyArray_CHKFLAGS(arr2, NPY_CONTIGUOUS) && \
-                    PyArray_CHKFLAGS(arr1, NPY_FORTRAN) == \
-                            PyArray_CHKFLAGS(arr2, NPY_FORTRAN) \
+                    (PyArray_NDIM(arr2) == 0 || \
+                     PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2) || \
+                     (PyArray_NDIM(arr1) == 0 && \
+                      PyArray_TRIVIALLY_ITERABLE(arr2) \
+                     ) \
+                    ) \
                     )
 #define PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(arr1, arr2, \
                                         count, \
                                         data1, data2, \
-                                        stride1, stride2) \
-                    count = PyArray_SIZE(arr1), \
-                    data1 = PyArray_BYTES(arr1), \
-                    data2 = PyArray_BYTES(arr2), \
-                    stride1 = ((PyArray_NDIM(arr1) == 0) ? 0 : \
+                                        stride1, stride2) { \
+                    npy_intp size1 = PyArray_SIZE(arr1); \
+                    npy_intp size2 = PyArray_SIZE(arr2); \
+                    count = (size1 && size2) ? (size1 > size2 ? size1 : size2) : 0; \
+                    data1 = PyArray_BYTES(arr1); \
+                    data2 = PyArray_BYTES(arr2); \
+                    stride1 = (size1 == 1 ? 0 : /* repeat the lone element */ \
+                                (PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \
+                                    PyArray_STRIDE(arr1, 0) : \
+                                    PyArray_STRIDE(arr1, \
+                                        PyArray_NDIM(arr1)-1))); \
+                    stride2 = (size2 == 1 ? 0 : \
+                                (PyArray_CHKFLAGS(arr2, NPY_FORTRAN) ? \
+                                    PyArray_STRIDE(arr2, 0) : \
+                                    PyArray_STRIDE(arr2, \
+                                        PyArray_NDIM(arr2)-1))); \
+                }
+/* As the PAIR check, for three operands; 0-d operands are accepted too. */
+#define PyArray_TRIVIALLY_ITERABLE_TRIPLE(arr1, arr2, arr3) (\
+                PyArray_TRIVIALLY_ITERABLE(arr1) && \
+                    ((PyArray_NDIM(arr2) == 0 && \
+                        (PyArray_NDIM(arr3) == 0 || \
+                            PyArray_EQUIVALENTLY_ITERABLE(arr1, arr3) \
+                        ) \
+                    ) || \
+                    (PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2) && \
+                        (PyArray_NDIM(arr3) == 0 || \
+                            PyArray_EQUIVALENTLY_ITERABLE(arr1, arr3) \
+                        ) \
+                    ) || \
+                    (PyArray_NDIM(arr1) == 0 && \
+                        PyArray_TRIVIALLY_ITERABLE(arr2) && \
+                            (PyArray_NDIM(arr3) == 0 || \
+                                PyArray_EQUIVALENTLY_ITERABLE(arr2, arr3) \
+                            ) \
+                    ) \
+                ) \
+            )
+/* Flat-loop setup for a TRIPLE; count is 0 if any operand is empty. */
+#define PyArray_PREPARE_TRIVIAL_TRIPLE_ITERATION(arr1, arr2, arr3, \
+                                        count, \
+                                        data1, data2, data3, \
+                                        stride1, stride2, stride3) { \
+                    npy_intp size1 = PyArray_SIZE(arr1); \
+                    npy_intp size2 = PyArray_SIZE(arr2); \
+                    npy_intp size3 = PyArray_SIZE(arr3); \
+                    count = size1 > size2 ? size1 : size2; \
+                    count = (size1 && size2 && size3) ? (size3 > count ? size3 : count) : 0; \
+                    data1 = PyArray_BYTES(arr1); \
+                    data2 = PyArray_BYTES(arr2); \
+                    data3 = PyArray_BYTES(arr3); \
+                    stride1 = (size1 == 1 ? 0 : \
                                 (PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \
                                     PyArray_STRIDE(arr1, 0) : \
                                     PyArray_STRIDE(arr1, \
-                                        PyArray_NDIM(arr1)-1))), \
-                    stride2 = ((PyArray_NDIM(arr2) == 0) ? 0 : \
+                                        PyArray_NDIM(arr1)-1))); \
+                    stride2 = (size2 == 1 ? 0 : \
                                 (PyArray_CHKFLAGS(arr2, NPY_FORTRAN) ? \
                                     PyArray_STRIDE(arr2, 0) : \
                                     PyArray_STRIDE(arr2, \
-                                        PyArray_NDIM(arr2)-1)))
+                                        PyArray_NDIM(arr2)-1))); \
+                    stride3 = (size3 == 1 ? 0 : /* repeat the lone element */ \
+                                (PyArray_CHKFLAGS(arr3, NPY_FORTRAN) ? \
+                                    PyArray_STRIDE(arr3, 0) : \
+                                    PyArray_STRIDE(arr3, \
+                                        PyArray_NDIM(arr3)-1))); \
+                }
+
 #endif
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 28b57c8ce..9173bb0a6 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -773,7 +773,6 @@ static PyObject *
 array_astype(PyArrayObject *self, PyObject *args)
 {
     PyArray_Descr *descr = NULL;
-    PyObject *obj;
 
     if (!PyArg_ParseTuple(args, "O&",
                             PyArray_DescrConverter, &descr)) {
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index f3415f723..1f78fb6f6 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -37,6 +37,7 @@
 
 #include "numpy/noprefix.h"
 #include "numpy/ufuncobject.h"
+#include "lowlevel_strided_loops.h"
 
 #include "ufunc_object.h"
 
@@ -3509,7 +3510,9 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
     /* This contains the all the inputs and outputs */
     PyArrayObject *op[NPY_MAXARGS];
     PyArray_Descr *dtype[NPY_MAXARGS];
-    PyArray_Descr *result_type;
+
+    PyUFuncGenericFunction innerloop = NULL;
+    void *innerloopdata = NULL;
 
     int no_castable_output, all_inputs_scalar;
 
@@ -3538,7 +3541,6 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
         op[i] = NULL;
         dtype[i] = NULL;
     }
-    result_type = NULL;
 
     /* Get input arguments */
     all_inputs_scalar = 1;
@@ -3692,7 +3694,8 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
      * Determine the UFunc loop.  This could in general be *much* faster,
      * and a better way to implement it might be for the ufunc to
      * provide a function which gives back the result type and inner
-     * loop function.
+     * loop function. A default (fast) mechanism could be provided for
+     * functions which follow the same pattern as add, subtract, etc.
      *
      * The method for finding the loop in the previous code did not
      * appear consistent (as noted by some asymmetry in the generated
@@ -3750,6 +3753,9 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
             for (j = 0; j < niter; ++j) {
                 dtype[j] = PyArray_DescrFromType(types[j]);
             }
+            /* Save the inner loop and its data */
+            innerloop = self->functions[i];
+            innerloopdata = self->data[i];
             break;
         }
     }
@@ -3769,9 +3775,6 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
         }
         return NULL;
     }
-    else {
-
-    }
 
     printf("input types:\n");
 
@@ -3786,13 +3789,67 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
     }
     printf("\n");
 
+    /* First check for the trivial cases that don't need an iterator */
+    if (niter == 2 && op[1] == NULL &&
+                (order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
+                PyArray_TRIVIALLY_ITERABLE(op[0])) {
+        Py_INCREF(dtype[1]);
+        op[1] = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
+                             dtype[1],
+                             PyArray_NDIM(op[0]),
+                             PyArray_DIMS(op[0]),
+                             NULL, NULL,
+                             PyArray_ISFORTRAN(op[0]) ? NPY_F_CONTIGUOUS : 0,
+                             NULL);
+        printf("trivial 1 input with allocated output\n");
+    }
+    else if (niter == 2 && op[1] != NULL &&
+                PyArray_NDIM(op[1]) >= PyArray_NDIM(op[0]) &&
+                PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
+        printf("trivial 1 input\n");
+    }
+    else if (nin == 2 && nout == 1 && op[2] == NULL &&
+                (order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
+                PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
+        PyArrayObject *tmp;
+        /*
+         * Have to choose the input with more dimensions to clone, as
+         * one of them could be a scalar.
+         */
+        if (PyArray_NDIM(op[0]) >= PyArray_NDIM(op[1])) {
+            tmp = op[0];
+        }
+        else {
+            tmp = op[1];
+        }
+        Py_INCREF(dtype[2]);
+        op[2] = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
+                             dtype[2],
+                             PyArray_NDIM(tmp),
+                             PyArray_DIMS(tmp),
+                             NULL, NULL,
+                             PyArray_ISFORTRAN(tmp) ? NPY_F_CONTIGUOUS : 0,
+                             NULL);
+        printf("trivial 2 input with allocated output\n");
+    }
+    else if (niter == 3 && op[1] != NULL && op[2] != NULL &&
+                PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) &&
+                PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) &&
+                PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) {
+        printf("trivial 2 input\n");
+    }
+    /* Otherwise an iterator is required to resolve broadcasting, etc */
+    else {
+        printf("iterator case\n");
+    }
+
+    PyErr_SetString(PyExc_RuntimeError, "implementation is not finished!");
 
 fail:
     for (i = 0; i < niter; ++i) {
         Py_XDECREF(op[i]);
         Py_XDECREF(dtype[i]);
     }
-    Py_XDECREF(result_type);
 
     return NULL;
 }