summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/einsum.c.src164
1 files changed, 65 insertions, 99 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index e701f43d1..2b5b74c37 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -1895,136 +1895,102 @@ parse_operand_subscripts(char *subscripts, int length,
/*
- * Parses the subscripts for the output operand into an output
- * that requires 'ndim_broadcast' unlabeled dimensions, returning
- * the number of output dimensions. Returns -1 if there is an error.
+ * Parses the subscripts for the output operand into an output that
+ * includes 'ndim_broadcast' unlabeled dimensions, and returns the total
+ * number of output dimensions, or -1 if there is an error. Similarly
+ * to parse_operand_subscripts, the 'out_labels' array will have, for
+ * each dimension:
+ * - the ASCII code of the corresponding label;
+ * - zero for broadcast dimensions, if subscripts has an ellipsis.
*/
static int
parse_output_subscripts(char *subscripts, int length,
int ndim_broadcast,
- const char *label_counts,
- char *out_labels)
+ const char *label_counts, char *out_labels)
{
- int i, nlabels, label, idim, ndim, ndim_left;
+ int i;
+ int ndim = 0;
int ellipsis = 0;
- /* Count the labels, making sure they're all unique and valid */
- nlabels = 0;
+ /* Process all the output labels. */
for (i = 0; i < length; ++i) {
- label = subscripts[i];
+ int label = subscripts[i];
+
+ /* A proper label for an axis. */
if (label > 0 && isalpha(label)) {
- /* Check if it occurs again */
- if (memchr(subscripts+i+1, label, length-i-1) == NULL) {
- /* Check that it was used in the inputs */
+ /* Check that it doesn't occur again. */
+ if (memchr(subscripts + i + 1, label, length - i - 1) == NULL) {
+ /* Check that it was used in the inputs. */
if (label_counts[label] == 0) {
PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string included "
- "output subscript '%c' which never appeared "
- "in an input", (char)label);
+ "einstein sum subscripts string included "
+ "output subscript '%c' which never appeared "
+ "in an input", (char)label);
+ return -1;
+ }
+ /* Check that there is room in out_labels for this label. */
+ if (ndim < NPY_MAXDIMS) {
+ out_labels[ndim++] = label;
+ }
+ else {
+ PyErr_Format(PyExc_ValueError,
+ "einstein sum subscripts string contains "
+ "too many subscripts in the output");
return -1;
}
-
- nlabels++;
- }
- else {
- PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string includes "
- "output subscript '%c' multiple times",
- (char)label);
- return -1;
- }
- }
- else if (label != '.' && label != ' ') {
- PyErr_Format(PyExc_ValueError,
- "invalid subscript '%c' in einstein sum "
- "subscripts string, subscripts must "
- "be letters", (char)label);
- return -1;
- }
- }
-
- /* The number of output dimensions */
- ndim = ndim_broadcast + nlabels;
-
- /* Process the labels from the end until the ellipsis */
- idim = ndim-1;
- for (i = length-1; i >= 0; --i) {
- label = subscripts[i];
- /* A label for an axis */
- if (label != '.' && label != ' ') {
- if (idim >= 0) {
- out_labels[idim--] = label;
}
else {
PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains "
- "too many output subscripts");
+ "einstein sum subscripts string includes "
+ "output subscript '%c' multiple times",
+ (char)label);
return -1;
}
}
- /* The end of the ellipsis */
+ /* The beginning of the ellipsis. */
else if (label == '.') {
- /* A valid ellipsis */
- if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') {
+ if (!ellipsis && i + 2 < length
+ && subscripts[++i] == '.' && subscripts[++i] == '.') {
ellipsis = 1;
- length = i-2;
- break;
- }
- else {
- PyErr_SetString(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...') "
- "in the output");
- return -1;
-
- }
- }
- }
-
- if (!ellipsis && idim != -1) {
- PyErr_SetString(PyExc_ValueError,
- "output has more dimensions than subscripts "
- "given in einstein sum, but no '...' ellipsis "
- "provided to broadcast the extra dimensions.");
- return 0;
- }
-
- /* Reduce ndim to just the dimensions left to fill at the beginning */
- ndim_left = idim+1;
- idim = 0;
+ /* Check there is room in out_labels for broadcast dims. */
+ if (ndim + ndim_broadcast <= NPY_MAXDIMS) {
+ int bdim;
- /*
- * If we stopped because of an ellipsis, start again from the beginning.
- * The length was truncated to end at the ellipsis in this case.
- */
- if (i > 0) {
- for (i = 0; i < length; ++i) {
- label = subscripts[i];
- if (label == '.') {
- PyErr_SetString(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...') "
- "in the output");
- return -1;
- }
- /* A label for an axis */
- else if (label != ' ') {
- if (idim < ndim_left) {
- out_labels[idim++] = label;
+ for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
+ out_labels[ndim++] = 0;
+ }
}
else {
PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains "
- "too many subscripts for the output");
+ "einstein sum subscripts string contains "
+ "too many subscripts in the output");
return -1;
}
}
+ else {
+ PyErr_SetString(PyExc_ValueError,
+ "einstein sum subscripts string "
+ "contains a '.' that is not part of "
+ "an ellipsis ('...') in the output");
+ return -1;
+ }
+ }
+ else if (label != ' ') {
+ PyErr_Format(PyExc_ValueError,
+ "invalid subscript '%c' in einstein sum "
+ "subscripts string, subscripts must "
+ "be letters", (char)label);
+ return -1;
}
}
- /* Set the remaining output labels to 0 */
- while (idim < ndim_left) {
- out_labels[idim++] = 0;
+ /* If no ellipsis was found there should be no broadcast dimensions. */
+ if (!ellipsis && ndim_broadcast > 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "output has more dimensions than subscripts "
+ "given in einstein sum, but no '...' ellipsis "
+ "provided to broadcast the extra dimensions.");
+ return -1;
}
return ndim;