diff options
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 48 |
1 files changed, 32 insertions, 16 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 0385dfba1..8718e1a26 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -1756,39 +1756,51 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc, /* * Validate the core dimensions of all the operands, * and collect all of the labeled core dimension sizes - * into the array 'core_dim_sizes'. + * into the array 'core_dim_sizes'. Initialize them to + * 1, for example in the case where the operand broadcasts + * to a core dimension, it won't be visited. */ for (i = 0; i < ufunc->core_num_dim_ix; ++i) { - core_dim_sizes[i] = -1; + core_dim_sizes[i] = 1; } for (i = 0; i < nop; ++i) { if (op[i] != NULL) { int dim_offset = ufunc->core_offsets[i]; int num_dims = ufunc->core_num_dims[i]; int core_start_dim = PyArray_NDIM(op[i]) - num_dims; - /* Make sure the operand has enough dimensions */ - if (core_start_dim < 0) { + /* Make sure any output operand has enough dimensions */ + if (i >= nin && core_start_dim < 0) { PyErr_Format(PyExc_ValueError, - "%s: Operand %d does not have enough dimensions " + "%s: Output operand %d does not have enough dimensions " "(has %d, gufunc core with signature %s " "requires %d)", - ufunc_name, i, PyArray_NDIM(op[i]), + ufunc_name, i - nin, PyArray_NDIM(op[i]), ufunc->core_signature, num_dims); retval = -1; goto fail; } + /* * Make sure each core dimension matches all other core * dimensions with the same label + * + * NOTE: For input operands, core_start_dim may be negative. + * In that case, the operand is being broadcast onto + * core dimensions. For example, a scalar will broadcast + * to fit any core signature. */ - for (idim = 0; idim < num_dims; ++idim) { + if (core_start_dim >= 0) { + idim = 0; + } else { + idim = -core_start_dim; + } + for (; idim < num_dims; ++idim) { int core_dim_index = ufunc->core_dim_ixs[dim_offset + idim]; npy_intp op_dim_size = PyArray_SHAPE(op[i])[core_start_dim + idim]; - if (core_dim_sizes[core_dim_index] == -1 || - core_dim_sizes[core_dim_index] == 1) { + if (core_dim_sizes[core_dim_index] == 1) { core_dim_sizes[core_dim_index] = op_dim_size; - } else if (op_dim_size != 1 && + } else if ((i >= nin || op_dim_size != 1) && core_dim_sizes[core_dim_index] != op_dim_size) { PyErr_Format(PyExc_ValueError, "%s: Operand %d has a mismatch in its core " @@ -1997,12 +2009,16 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc, npy_intp *shape = PyArray_SHAPE(arr); npy_intp *strides = PyArray_STRIDES(arr); for (j = 0; j < num_dims; ++j) { - /* - * Force the stride to zero when the shape is 1, sot - * that the broadcasting works right. - */ - if (shape[core_start_dim + j] != 1) { - inner_strides[idim++] = strides[core_start_dim + j]; + if (core_start_dim + j >= 0) { + /* + * Force the stride to zero when the shape is 1, sot + * that the broadcasting works right. + */ + if (shape[core_start_dim + j] != 1) { + inner_strides[idim++] = strides[core_start_dim + j]; + } else { + inner_strides[idim++] = 0; + } } else { inner_strides[idim++] = 0; } |