summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/umath/ufunc_object.c48
1 files changed, 32 insertions, 16 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 0385dfba1..8718e1a26 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1756,39 +1756,51 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
/*
* Validate the core dimensions of all the operands,
* and collect all of the labeled core dimension sizes
- * into the array 'core_dim_sizes'.
+ * into the array 'core_dim_sizes'. Initialize them to
+ * 1, for example in the case where the operand broadcasts
+ * to a core dimension, it won't be visited.
*/
for (i = 0; i < ufunc->core_num_dim_ix; ++i) {
- core_dim_sizes[i] = -1;
+ core_dim_sizes[i] = 1;
}
for (i = 0; i < nop; ++i) {
if (op[i] != NULL) {
int dim_offset = ufunc->core_offsets[i];
int num_dims = ufunc->core_num_dims[i];
int core_start_dim = PyArray_NDIM(op[i]) - num_dims;
- /* Make sure the operand has enough dimensions */
- if (core_start_dim < 0) {
+ /* Make sure any output operand has enough dimensions */
+ if (i >= nin && core_start_dim < 0) {
PyErr_Format(PyExc_ValueError,
- "%s: Operand %d does not have enough dimensions "
+ "%s: Output operand %d does not have enough dimensions "
"(has %d, gufunc core with signature %s "
"requires %d)",
- ufunc_name, i, PyArray_NDIM(op[i]),
+ ufunc_name, i - nin, PyArray_NDIM(op[i]),
ufunc->core_signature, num_dims);
retval = -1;
goto fail;
}
+
/*
* Make sure each core dimension matches all other core
* dimensions with the same label
+ *
+ * NOTE: For input operands, core_start_dim may be negative.
+ * In that case, the operand is being broadcast onto
+ * core dimensions. For example, a scalar will broadcast
+ * to fit any core signature.
*/
- for (idim = 0; idim < num_dims; ++idim) {
+ if (core_start_dim >= 0) {
+ idim = 0;
+ } else {
+ idim = -core_start_dim;
+ }
+ for (; idim < num_dims; ++idim) {
int core_dim_index = ufunc->core_dim_ixs[dim_offset + idim];
npy_intp op_dim_size =
PyArray_SHAPE(op[i])[core_start_dim + idim];
- if (core_dim_sizes[core_dim_index] == -1 ||
- core_dim_sizes[core_dim_index] == 1) {
+ if (core_dim_sizes[core_dim_index] == 1) {
core_dim_sizes[core_dim_index] = op_dim_size;
- } else if (op_dim_size != 1 &&
+ } else if ((i >= nin || op_dim_size != 1) &&
core_dim_sizes[core_dim_index] != op_dim_size) {
PyErr_Format(PyExc_ValueError,
"%s: Operand %d has a mismatch in its core "
@@ -1997,12 +2009,16 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
npy_intp *shape = PyArray_SHAPE(arr);
npy_intp *strides = PyArray_STRIDES(arr);
for (j = 0; j < num_dims; ++j) {
- /*
- * Force the stride to zero when the shape is 1, sot
- * that the broadcasting works right.
- */
- if (shape[core_start_dim + j] != 1) {
- inner_strides[idim++] = strides[core_start_dim + j];
+ if (core_start_dim + j >= 0) {
+ /*
+ * Force the stride to zero when the shape is 1, sot
+ * that the broadcasting works right.
+ */
+ if (shape[core_start_dim + j] != 1) {
+ inner_strides[idim++] = strides[core_start_dim + j];
+ } else {
+ inner_strides[idim++] = 0;
+ }
} else {
inner_strides[idim++] = 0;
}