diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2018-05-29 00:57:18 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-05-29 00:57:18 -0700 |
commit | 5075933290a59153ad899f382b93e6af2154a7cf (patch) | |
tree | b2db18967b926681dd68dcbd43eba549d5f257f3 | |
parent | 6ccb03d5e0bfe3cc9fd2c69e0d9606447c358a2d (diff) | |
parent | c2d59257b0d5e24d9f3811df5c314d525acfc0cb (diff) | |
download | numpy-5075933290a59153ad899f382b93e6af2154a7cf.tar.gz |
Merge pull request #11095 from jaimefrio/einsum_cleanup
MAINT: Einsum argument parsing cleanup
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 396 |
1 files changed, 151 insertions, 245 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 470a5fff9..167e3fbdc 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -1776,138 +1776,94 @@ get_sum_of_products_function(int nop, int type_num, return _unspecialized_table[type_num][nop <= 3 ? nop : 0]; } + /* - * Parses the subscripts for one operand into an output - * of 'ndim' labels + * Parses the subscripts for one operand into an output of 'ndim' + * labels. The resulting 'op_labels' array will have: + * - the ASCII code of the label for the first occurrence of a label; + * - the (negative) offset to the first occurrence of the label for + * repeated labels; + * - zero for broadcast dimensions, if subscripts has an ellipsis. + * For example: + * - subscripts="abbcbc", ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2] + * - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99] */ + static int parse_operand_subscripts(char *subscripts, int length, - int ndim, - int iop, char *out_labels, - char *out_label_counts, - int *out_min_label, - int *out_max_label, - int *out_num_labels) + int ndim, int iop, char *op_labels, + char *label_counts, int *min_label, int *max_label) { - int i, idim, ndim_left, label; - int ellipsis = 0; + int i; + int idim = 0; + int ellipsis = -1; - /* Process the labels from the end until the ellipsis */ - idim = ndim-1; - for (i = length-1; i >= 0; --i) { - label = subscripts[i]; - /* A label for an axis */ + /* Process all labels for this operand */ + for (i = 0; i < length; ++i) { + int label = subscripts[i]; + + /* A proper label for an axis. */ if (label > 0 && isalpha(label)) { - if (idim >= 0) { - out_labels[idim--] = label; - /* Calculate the min and max labels */ - if (label < *out_min_label) { - *out_min_label = label; - } - if (label > *out_max_label) { - *out_max_label = label; - } - /* If it's the first time we see this label, count it */ - if (out_label_counts[label] == 0) { - (*out_num_labels)++; - } - out_label_counts[label]++; - } - else { + /* Check we don't exceed the operator dimensions. */ + if (idim >= ndim) { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many subscripts for operand %d", iop); + "einstein sum subscripts string contains " + "too many subscripts for operand %d", iop); return 0; } + + op_labels[idim++] = label; + if (label < *min_label) { + *min_label = label; + } + if (label > *max_label) { + *max_label = label; + } + label_counts[label]++; } - /* The end of the ellipsis */ + /* The beginning of the ellipsis. */ else if (label == '.') { - /* A valid ellipsis */ - if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') { - ellipsis = 1; - length = i-2; - break; - } - else { + /* Check it's a proper ellipsis. */ + if (ellipsis != -1 || i + 2 >= length + || subscripts[++i] != '.' || subscripts[++i] != '.') { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') in " - "operand %d", iop); + "einstein sum subscripts string contains a " + "'.' that is not part of an ellipsis ('...') " + "in operand %d", iop); return 0; - } + + ellipsis = idim; } else if (label != ' ') { PyErr_Format(PyExc_ValueError, - "invalid subscript '%c' in einstein sum " - "subscripts string, subscripts must " - "be letters", (char)label); + "invalid subscript '%c' in einstein sum " + "subscripts string, subscripts must " + "be letters", (char)label); return 0; } } - if (!ellipsis && idim != -1) { - PyErr_Format(PyExc_ValueError, - "operand has more dimensions than subscripts " - "given in einstein sum, but no '...' ellipsis " - "provided to broadcast the extra dimensions."); - return 0; - } - - /* Reduce ndim to just the dimensions left to fill at the beginning */ - ndim_left = idim+1; - idim = 0; - - /* - * If we stopped because of an ellipsis, start again from the beginning. - * The length was truncated to end at the ellipsis in this case. - */ - if (i > 0) { - for (i = 0; i < length; ++i) { - label = subscripts[i]; - /* A label for an axis */ - if (label > 0 && isalnum(label)) { - if (idim < ndim_left) { - out_labels[idim++] = label; - /* Calculate the min and max labels */ - if (label < *out_min_label) { - *out_min_label = label; - } - if (label > *out_max_label) { - *out_max_label = label; - } - /* If it's the first time we see this label, count it */ - if (out_label_counts[label] == 0) { - (*out_num_labels)++; - } - out_label_counts[label]++; - } - else { - PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many subscripts for operand %d", iop); - return 0; - } - } - else if (label == '.') { - PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') in " - "operand %d", iop); - } - else if (label != ' ') { - PyErr_Format(PyExc_ValueError, - "invalid subscript '%c' in einstein sum " - "subscripts string, subscripts must " - "be letters", (char)label); - return 0; - } + /* No ellipsis found, labels must match dimensions exactly. */ + if (ellipsis == -1) { + if (idim != ndim) { + PyErr_Format(PyExc_ValueError, + "operand has more dimensions than subscripts " + "given in einstein sum, but no '...' ellipsis " + "provided to broadcast the extra dimensions."); + return 0; } } - - /* Set the remaining labels to 0 */ - while (idim < ndim_left) { - out_labels[idim++] = 0; + /* Ellipsis found, may have to add broadcast dimensions. */ + else if (idim < ndim) { + /* Move labels after ellipsis to the end. */ + for (i = 0; i < idim - ellipsis; ++i) { + op_labels[ndim - i - 1] = op_labels[idim - i - 1]; + } + /* Set all broadcast dimensions to zero. */ + for (i = 0; i < ndim - idim; ++i) { + op_labels[ellipsis + i] = 0; + } } /* @@ -1918,20 +1874,18 @@ parse_operand_subscripts(char *subscripts, int length, * twos complement arithmetic the char is ok either way here, and * later where it matters the char is cast to a signed char. */ - for (idim = 0; idim < ndim-1; ++idim) { - char *next; - /* If this is a proper label, find any duplicates of it */ - label = out_labels[idim]; + for (idim = 0; idim < ndim - 1; ++idim) { + int label = op_labels[idim]; + /* If it is a proper label, find any duplicates of it. */ if (label > 0) { - /* Search for the next matching label */ - next = (char *)memchr(out_labels+idim+1, label, - ndim-idim-1); + /* Search for the next matching label. */ + char *next = memchr(op_labels + idim + 1, label, ndim - idim - 1); + while (next != NULL) { - /* The offset from next to out_labels[idim] (negative) */ - *next = (char)((out_labels+idim)-next); - /* Search for the next matching label */ - next = (char *)memchr(next+1, label, - out_labels+ndim-1-next); + /* The offset from next to op_labels[idim] (negative). */ + *next = (char)((op_labels + idim) - next); + /* Search for the next matching label. */ + next = memchr(next + 1, label, op_labels + ndim - 1 - next); } } } @@ -1939,137 +1893,97 @@ parse_operand_subscripts(char *subscripts, int length, return 1; } + /* - * Parses the subscripts for the output operand into an output - * that requires 'ndim_broadcast' unlabeled dimensions, returning - * the number of output dimensions. Returns -1 if there is an error. + * Parses the subscripts for the output operand into an output that + * includes 'ndim_broadcast' unlabeled dimensions, and returns the total + * number of output dimensions, or -1 if there is an error. Similarly + * to parse_operand_subscripts, the 'out_labels' array will have, for + * each dimension: + * - the ASCII code of the corresponding label; + * - zero for broadcast dimensions, if subscripts has an ellipsis. */ static int parse_output_subscripts(char *subscripts, int length, int ndim_broadcast, - const char *label_counts, - char *out_labels) + const char *label_counts, char *out_labels) { - int i, nlabels, label, idim, ndim, ndim_left; + int i, bdim; + int ndim = 0; int ellipsis = 0; - /* Count the labels, making sure they're all unique and valid */ - nlabels = 0; + /* Process all the output labels. */ for (i = 0; i < length; ++i) { - label = subscripts[i]; - if (label > 0 && isalpha(label)) { - /* Check if it occurs again */ - if (memchr(subscripts+i+1, label, length-i-1) == NULL) { - /* Check that it was used in the inputs */ - if (label_counts[label] == 0) { - PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string included " - "output subscript '%c' which never appeared " - "in an input", (char)label); - return -1; - } + int label = subscripts[i]; - nlabels++; - } - else { + /* A proper label for an axis. */ + if (label > 0 && isalpha(label)) { + /* Check that it doesn't occur again. */ + if (memchr(subscripts + i + 1, label, length - i - 1) != NULL) { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string includes " - "output subscript '%c' multiple times", - (char)label); + "einstein sum subscripts string includes " + "output subscript '%c' multiple times", + (char)label); return -1; } - } - else if (label != '.' && label != ' ') { - PyErr_Format(PyExc_ValueError, - "invalid subscript '%c' in einstein sum " - "subscripts string, subscripts must " - "be letters", (char)label); - return -1; - } - } - - /* The number of output dimensions */ - ndim = ndim_broadcast + nlabels; - - /* Process the labels from the end until the ellipsis */ - idim = ndim-1; - for (i = length-1; i >= 0; --i) { - label = subscripts[i]; - /* A label for an axis */ - if (label != '.' && label != ' ') { - if (idim >= 0) { - out_labels[idim--] = label; + /* Check that it was used in the inputs. */ + if (label_counts[label] == 0) { + PyErr_Format(PyExc_ValueError, + "einstein sum subscripts string included " + "output subscript '%c' which never appeared " + "in an input", (char)label); + return -1; } - else { + /* Check that there is room in out_labels for this label. */ + if (ndim >= NPY_MAXDIMS) { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many output subscripts"); + "einstein sum subscripts string contains " + "too many subscripts in the output"); return -1; } + + out_labels[ndim++] = label; } - /* The end of the ellipsis */ + /* The beginning of the ellipsis. */ else if (label == '.') { - /* A valid ellipsis */ - if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') { - ellipsis = 1; - length = i-2; - break; - } - else { + /* Check it is a proper ellipsis. */ + if (ellipsis || i + 2 >= length + || subscripts[++i] != '.' || subscripts[++i] != '.') { PyErr_SetString(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') " - "in the output"); + "einstein sum subscripts string " + "contains a '.' that is not part of " + "an ellipsis ('...') in the output"); return -1; - } - } - } - - if (!ellipsis && idim != -1) { - PyErr_SetString(PyExc_ValueError, - "output has more dimensions than subscripts " - "given in einstein sum, but no '...' ellipsis " - "provided to broadcast the extra dimensions."); - return 0; - } - - /* Reduce ndim to just the dimensions left to fill at the beginning */ - ndim_left = idim+1; - idim = 0; - - /* - * If we stopped because of an ellipsis, start again from the beginning. - * The length was truncated to end at the ellipsis in this case. - */ - if (i > 0) { - for (i = 0; i < length; ++i) { - label = subscripts[i]; - if (label == '.') { - PyErr_SetString(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') " - "in the output"); + /* Check there is room in out_labels for broadcast dims. */ + if (ndim + ndim_broadcast > NPY_MAXDIMS) { + PyErr_Format(PyExc_ValueError, + "einstein sum subscripts string contains " + "too many subscripts in the output"); return -1; } - /* A label for an axis */ - else if (label != ' ') { - if (idim < ndim_left) { - out_labels[idim++] = label; - } - else { - PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many subscripts for the output"); - return -1; - } + + ellipsis = 1; + for (bdim = 0; bdim < ndim_broadcast; ++bdim) { + out_labels[ndim++] = 0; } } + else if (label != ' ') { + PyErr_Format(PyExc_ValueError, + "invalid subscript '%c' in einstein sum " + "subscripts string, subscripts must " + "be letters", (char)label); + return -1; + } } - /* Set the remaining output labels to 0 */ - while (idim < ndim_left) { - out_labels[idim++] = 0; + /* If no ellipsis was found there should be no broadcast dimensions. */ + if (!ellipsis && ndim_broadcast > 0) { + PyErr_SetString(PyExc_ValueError, + "output has more dimensions than subscripts " + "given in einstein sum, but no '...' ellipsis " + "provided to broadcast the extra dimensions."); + return -1; } return ndim; @@ -2613,7 +2527,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, NPY_ORDER order, NPY_CASTING casting, PyArrayObject *out) { - int iop, label, min_label = 127, max_label = 0, num_labels; + int iop, label, min_label = 127, max_label = 0; char label_counts[128]; char op_labels[NPY_MAXARGS][NPY_MAXDIMS]; char output_labels[NPY_MAXDIMS], *iter_labels; @@ -2644,7 +2558,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, /* Parse the subscripts string into label_counts and op_labels */ memset(label_counts, 0, sizeof(label_counts)); - num_labels = 0; for (iop = 0; iop < nop; ++iop) { int length = (int)strcspn(subscripts, ",-"); @@ -2664,7 +2577,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, if (!parse_operand_subscripts(subscripts, length, PyArray_NDIM(op_in[iop]), iop, op_labels[iop], label_counts, - &min_label, &max_label, &num_labels)) { + &min_label, &max_label)) { return NULL; } @@ -2698,21 +2611,18 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, } /* - * If there is no output signature, create one using each label - * that appeared once, in alphabetical order + * If there is no output signature, fill output_labels and ndim_output + * using each label that appeared once, in alphabetical order. */ if (subscripts[0] == '\0') { - char outsubscripts[NPY_MAXDIMS + 3]; - int length; - /* If no output was specified, always broadcast left (like normal) */ - outsubscripts[0] = '.'; - outsubscripts[1] = '.'; - outsubscripts[2] = '.'; - length = 3; + /* If no output was specified, always broadcast left, as usual. */ + for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) { + output_labels[ndim_output] = 0; + } for (label = min_label; label <= max_label; ++label) { if (label_counts[label] == 1) { - if (length < NPY_MAXDIMS-1) { - outsubscripts[length++] = label; + if (ndim_output < NPY_MAXDIMS) { + output_labels[ndim_output++] = label; } else { PyErr_SetString(PyExc_ValueError, @@ -2722,10 +2632,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, } } } - /* Parse the output subscript string */ - ndim_output = parse_output_subscripts(outsubscripts, length, - ndim_broadcast, label_counts, - output_labels); } else { if (subscripts[0] != '-' || subscripts[1] != '>') { @@ -2736,13 +2642,13 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, } subscripts += 2; - /* Parse the output subscript string */ + /* Parse the output subscript string. */ ndim_output = parse_output_subscripts(subscripts, strlen(subscripts), ndim_broadcast, label_counts, output_labels); - } - if (ndim_output < 0) { - return NULL; + if (ndim_output < 0) { + return NULL; + } } if (out != NULL && PyArray_NDIM(out) != ndim_output) { |