summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/lowlevel_strided_loops.h98
-rw-r--r--numpy/core/src/multiarray/methods.c1
-rw-r--r--numpy/core/src/umath/ufunc_object.c71
3 files changed, 146 insertions, 24 deletions
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.h b/numpy/core/src/multiarray/lowlevel_strided_loops.h
index 34cdf3c39..e0745c4ea 100644
--- a/numpy/core/src/multiarray/lowlevel_strided_loops.h
+++ b/numpy/core/src/multiarray/lowlevel_strided_loops.h
@@ -286,6 +286,20 @@ PyArray_TransferStridedToNDim(npy_intp ndim,
* // Create iterator, etc...
* }
*/
+
+/*
+ * Note: Equivalently iterable macro requires one of arr1 or arr2 be
+ * trivially iterable to be valid.
+ */
+#define PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2) ( \
+ PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \
+ PyArray_CompareLists(PyArray_DIMS(arr1), \
+ PyArray_DIMS(arr2), \
+ PyArray_NDIM(arr1)) && \
+ (arr1->flags&(NPY_CONTIGUOUS|NPY_FORTRAN)) == \
+ (arr2->flags&(NPY_CONTIGUOUS|NPY_FORTRAN)) \
+ )
+
#define PyArray_TRIVIALLY_ITERABLE(arr) ( \
PyArray_NDIM(arr) <= 1 || \
PyArray_CHKFLAGS(arr, NPY_CONTIGUOUS) || \
@@ -302,30 +316,82 @@ PyArray_TransferStridedToNDim(npy_intp ndim,
#define PyArray_TRIVIALLY_ITERABLE_PAIR(arr1, arr2) (\
PyArray_TRIVIALLY_ITERABLE(arr1) && \
- PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \
- PyArray_CompareLists(PyArray_DIMS(arr1), \
- PyArray_DIMS(arr2), \
- PyArray_NDIM(arr1)) && \
- PyArray_CHKFLAGS(arr1, NPY_CONTIGUOUS) == \
- PyArray_CHKFLAGS(arr2, NPY_CONTIGUOUS) && \
- PyArray_CHKFLAGS(arr1, NPY_FORTRAN) == \
- PyArray_CHKFLAGS(arr2, NPY_FORTRAN) \
+ (PyArray_NDIM(arr2) == 0 || \
+ PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2) || \
+ (PyArray_NDIM(arr1) == 0 && \
+ PyArray_TRIVIALLY_ITERABLE(arr2) \
+ ) \
+ ) \
)
#define PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(arr1, arr2, \
count, \
data1, data2, \
- stride1, stride2) \
- count = PyArray_SIZE(arr1), \
- data1 = PyArray_BYTES(arr1), \
- data2 = PyArray_BYTES(arr2), \
- stride1 = ((PyArray_NDIM(arr1) == 0) ? 0 : \
+ stride1, stride2) { \
+ npy_intp size1 = PyArray_SIZE(arr1); \
+ npy_intp size2 = PyArray_SIZE(arr2); \
+ count = size1 > size2 ? size1 : size2; \
+ data1 = PyArray_BYTES(arr1); \
+ data2 = PyArray_BYTES(arr2); \
+ stride1 = (size1 == 1 ? 0 : \
+ (PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \
+ PyArray_STRIDE(arr1, 0) : \
+ PyArray_STRIDE(arr1, \
+ PyArray_NDIM(arr1)-1))); \
+ stride2 = (size2 == 1 ? 0 : \
+ (PyArray_CHKFLAGS(arr2, NPY_FORTRAN) ? \
+ PyArray_STRIDE(arr2, 0) : \
+ PyArray_STRIDE(arr2, \
+ PyArray_NDIM(arr2)-1))); \
+ }
+
+#define PyArray_TRIVIALLY_ITERABLE_TRIPLE(arr1, arr2, arr3) (\
+ PyArray_TRIVIALLY_ITERABLE(arr1) && \
+ ((PyArray_NDIM(arr2) == 0 && \
+ (PyArray_NDIM(arr3) == 0 || \
+ PyArray_EQUIVALENTLY_ITERABLE(arr1, arr3) \
+ ) \
+ ) || \
+ (PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2) && \
+ (PyArray_NDIM(arr3) == 0 || \
+ PyArray_EQUIVALENTLY_ITERABLE(arr1, arr3) \
+ ) \
+ ) || \
+ (PyArray_NDIM(arr1) == 0 && \
+ PyArray_TRIVIALLY_ITERABLE(arr2) && \
+ (PyArray_NDIM(arr3) == 0 || \
+ PyArray_EQUIVALENTLY_ITERABLE(arr2, arr3) \
+ ) \
+ ) \
+ ) \
+ )
+
+#define PyArray_PREPARE_TRIVIAL_TRIPLE_ITERATION(arr1, arr2, arr3, \
+ count, \
+ data1, data2, data3, \
+ stride1, stride2, stride3) { \
+ npy_intp size1 = PyArray_SIZE(arr1); \
+ npy_intp size2 = PyArray_SIZE(arr2); \
+ npy_intp size3 = PyArray_SIZE(arr3); \
+ count = size1 > size2 ? size1 : size2; \
+ count = size3 > count ? size3 : count; \
+ data1 = PyArray_BYTES(arr1); \
+ data2 = PyArray_BYTES(arr2); \
+ data2 = PyArray_BYTES(arr3); \
+ stride1 = (size1 == 1 ? 0 : \
(PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \
PyArray_STRIDE(arr1, 0) : \
PyArray_STRIDE(arr1, \
- PyArray_NDIM(arr1)-1))), \
- stride2 = ((PyArray_NDIM(arr2) == 0) ? 0 : \
+ PyArray_NDIM(arr1)-1))); \
+ stride2 = (size2 == 1 ? 0 : \
(PyArray_CHKFLAGS(arr2, NPY_FORTRAN) ? \
PyArray_STRIDE(arr2, 0) : \
PyArray_STRIDE(arr2, \
- PyArray_NDIM(arr2)-1)))
+ PyArray_NDIM(arr2)-1))); \
+ stride3 = (size3 == 1 ? 0 : \
+ (PyArray_CHKFLAGS(arr3, NPY_FORTRAN) ? \
+ PyArray_STRIDE(arr3, 0) : \
+ PyArray_STRIDE(arr3, \
+ PyArray_NDIM(arr3)-1))); \
+ }
+
#endif
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 28b57c8ce..9173bb0a6 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -773,7 +773,6 @@ static PyObject *
array_astype(PyArrayObject *self, PyObject *args)
{
PyArray_Descr *descr = NULL;
- PyObject *obj;
if (!PyArg_ParseTuple(args, "O&", PyArray_DescrConverter,
&descr)) {
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index f3415f723..1f78fb6f6 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -37,6 +37,7 @@
#include "numpy/noprefix.h"
#include "numpy/ufuncobject.h"
+#include "lowlevel_strided_loops.h"
#include "ufunc_object.h"
@@ -3509,7 +3510,9 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
/* This contains the all the inputs and outputs */
PyArrayObject *op[NPY_MAXARGS];
PyArray_Descr *dtype[NPY_MAXARGS];
- PyArray_Descr *result_type;
+
+ PyUFuncGenericFunction innerloop = NULL;
+ void *innerloopdata = NULL;
int no_castable_output, all_inputs_scalar;
@@ -3538,7 +3541,6 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
op[i] = NULL;
dtype[i] = NULL;
}
- result_type = NULL;
/* Get input arguments */
all_inputs_scalar = 1;
@@ -3692,7 +3694,8 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
* Determine the UFunc loop. This could in general be *much* faster,
* and a better way to implement it might be for the ufunc to
* provide a function which gives back the result type and inner
- * loop function.
+ * loop function. A default (fast) mechanism could be provided for
+ * functions which follow the same pattern as add, subtract, etc.
*
* The method for finding the loop in the previous code did not
* appear consistent (as noted by some asymmetry in the generated
@@ -3750,6 +3753,9 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
for (j = 0; j < niter; ++j) {
dtype[j] = PyArray_DescrFromType(types[j]);
}
+ /* Save the inner loop and its data */
+ innerloop = self->functions[i];
+ innerloopdata = self->data[i];
break;
}
}
@@ -3769,9 +3775,6 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
}
return NULL;
}
- else {
-
- }
printf("input types:\n");
@@ -3786,13 +3789,67 @@ ufunc_generic_call_iter(PyUFuncObject *self, PyObject *args, PyObject *kwds)
}
printf("\n");
+ /* First check for the trivial cases that don't need an iterator */
+ if (niter == 2 && op[1] == NULL &&
+ (order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
+ PyArray_TRIVIALLY_ITERABLE(op[0])) {
+ Py_INCREF(dtype[1]);
+ op[1] = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
+ dtype[1],
+ PyArray_NDIM(op[0]),
+ PyArray_DIMS(op[0]),
+ NULL, NULL,
+ PyArray_ISFORTRAN(op[0]) ? NPY_F_CONTIGUOUS : 0,
+ NULL);
+ printf("trivial 1 input with allocated output\n");
+ }
+ else if (niter == 2 && op[1] != NULL &&
+ PyArray_NDIM(op[1]) >= PyArray_NDIM(op[0]) &&
+ PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
+ printf("trivial 1 input\n");
+ }
+ else if (nin == 2 && nout == 1 && op[2] == NULL &&
+ (order == NPY_ANYORDER || order == NPY_KEEPORDER) &&
+ PyArray_TRIVIALLY_ITERABLE_PAIR(op[0], op[1])) {
+ PyArrayObject *tmp;
+ /*
+ * Have to choose the input with more dimensions to clone, as
+ * one of them could be a scalar.
+ */
+ if (PyArray_NDIM(op[0]) >= PyArray_NDIM(op[1])) {
+ tmp = op[0];
+ }
+ else {
+ tmp = op[1];
+ }
+ Py_INCREF(dtype[2]);
+ op[2] = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
+ dtype[2],
+ PyArray_NDIM(tmp),
+ PyArray_DIMS(tmp),
+ NULL, NULL,
+ PyArray_ISFORTRAN(tmp) ? NPY_F_CONTIGUOUS : 0,
+ NULL);
+ printf("trivial 2 input with allocated output\n");
+ }
+ else if (niter == 3 && op[1] != NULL && op[2] != NULL &&
+ PyArray_NDIM(op[2]) >= PyArray_NDIM(op[0]) &&
+ PyArray_NDIM(op[2]) >= PyArray_NDIM(op[1]) &&
+ PyArray_TRIVIALLY_ITERABLE_TRIPLE(op[0], op[1], op[2])) {
+ printf("trivial 2 input\n");
+ }
+ /* Otherwise an iterator is required to resolve broadcasting, etc */
+ else {
+ printf("iterator case\n");
+ }
+
+
PyErr_SetString(PyExc_RuntimeError, "implementation is not finished!");
fail:
for (i = 0; i < niter; ++i) {
Py_XDECREF(op[i]);
Py_XDECREF(dtype[i]);
}
- Py_XDECREF(result_type);
return NULL;
}