diff options
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 164 |
1 files changed, 65 insertions, 99 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index e701f43d1..2b5b74c37 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -1895,136 +1895,102 @@ parse_operand_subscripts(char *subscripts, int length, /* - * Parses the subscripts for the output operand into an output - * that requires 'ndim_broadcast' unlabeled dimensions, returning - * the number of output dimensions. Returns -1 if there is an error. + * Parses the subscripts for the output operand into an output that + * includes 'ndim_broadcast' unlabeled dimensions, and returns the total + * number of output dimensions, or -1 if there is an error. Similarly + * to parse_operand_subscripts, the 'out_labels' array will have, for + * each dimension: + * - the ASCII code of the corresponding label; + * - zero for broadcast dimensions, if subscripts has an ellipsis. */ static int parse_output_subscripts(char *subscripts, int length, int ndim_broadcast, - const char *label_counts, - char *out_labels) + const char *label_counts, char *out_labels) { - int i, nlabels, label, idim, ndim, ndim_left; + int i; + int ndim = 0; int ellipsis = 0; - /* Count the labels, making sure they're all unique and valid */ - nlabels = 0; + /* Process all the output labels. */ for (i = 0; i < length; ++i) { - label = subscripts[i]; + int label = subscripts[i]; + + /* A proper label for an axis. */ if (label > 0 && isalpha(label)) { - /* Check if it occurs again */ - if (memchr(subscripts+i+1, label, length-i-1) == NULL) { - /* Check that it was used in the inputs */ + /* Check that it doesn't occur again. */ + if (memchr(subscripts + i + 1, label, length - i - 1) == NULL) { + /* Check that it was used in the inputs. */ if (label_counts[label] == 0) { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string included " - "output subscript '%c' which never appeared " - "in an input", (char)label); + "einstein sum subscripts string included " + "output subscript '%c' which never appeared " + "in an input", (char)label); + return -1; + } + /* Check that there is room in out_labels for this label. */ + if (ndim < NPY_MAXDIMS) { + out_labels[ndim++] = label; + } + else { + PyErr_Format(PyExc_ValueError, + "einstein sum subscripts string contains " + "too many subscripts in the output"); return -1; } - - nlabels++; - } - else { - PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string includes " - "output subscript '%c' multiple times", - (char)label); - return -1; - } - } - else if (label != '.' && label != ' ') { - PyErr_Format(PyExc_ValueError, - "invalid subscript '%c' in einstein sum " - "subscripts string, subscripts must " - "be letters", (char)label); - return -1; - } - } - - /* The number of output dimensions */ - ndim = ndim_broadcast + nlabels; - - /* Process the labels from the end until the ellipsis */ - idim = ndim-1; - for (i = length-1; i >= 0; --i) { - label = subscripts[i]; - /* A label for an axis */ - if (label != '.' && label != ' ') { - if (idim >= 0) { - out_labels[idim--] = label; } else { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many output subscripts"); + "einstein sum subscripts string includes " + "output subscript '%c' multiple times", + (char)label); return -1; } } - /* The end of the ellipsis */ + /* The beginning of the ellipsis. */ else if (label == '.') { - /* A valid ellipsis */ - if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') { + if (!ellipsis && i + 2 < length + && subscripts[++i] == '.' && subscripts[++i] == '.') { ellipsis = 1; - length = i-2; - break; - } - else { - PyErr_SetString(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') " - "in the output"); - return -1; - - } - } - } - - if (!ellipsis && idim != -1) { - PyErr_SetString(PyExc_ValueError, - "output has more dimensions than subscripts " - "given in einstein sum, but no '...' ellipsis " - "provided to broadcast the extra dimensions."); - return 0; - } - - /* Reduce ndim to just the dimensions left to fill at the beginning */ - ndim_left = idim+1; - idim = 0; + /* Check there is room in out_labels for broadcast dims. */ + if (ndim + ndim_broadcast <= NPY_MAXDIMS) { + int bdim; - /* - * If we stopped because of an ellipsis, start again from the beginning. - * The length was truncated to end at the ellipsis in this case. - */ - if (i > 0) { - for (i = 0; i < length; ++i) { - label = subscripts[i]; - if (label == '.') { - PyErr_SetString(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') " - "in the output"); - return -1; - } - /* A label for an axis */ - else if (label != ' ') { - if (idim < ndim_left) { - out_labels[idim++] = label; + for (bdim = 0; bdim < ndim_broadcast; ++bdim) { + out_labels[ndim++] = 0; + } } else { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many subscripts for the output"); + "einstein sum subscripts string contains " + "too many subscripts in the output"); return -1; } } + else { + PyErr_SetString(PyExc_ValueError, + "einstein sum subscripts string " + "contains a '.' that is not part of " + "an ellipsis ('...') in the output"); + return -1; + } + } + else if (label != ' ') { + PyErr_Format(PyExc_ValueError, + "invalid subscript '%c' in einstein sum " + "subscripts string, subscripts must " + "be letters", (char)label); + return -1; } } - /* Set the remaining output labels to 0 */ - while (idim < ndim_left) { - out_labels[idim++] = 0; + /* If no ellipsis was found there should be no broadcast dimensions. */ + if (!ellipsis && ndim_broadcast > 0) { + PyErr_SetString(PyExc_ValueError, + "output has more dimensions than subscripts " + "given in einstein sum, but no '...' ellipsis " + "provided to broadcast the extra dimensions."); + return -1; } return ndim; |