diff options
author | Jaime Fernandez <jaimefrio@google.com> | 2018-05-14 05:38:28 -0700 |
---|---|---|
committer | Jaime Fernandez <jaimefrio@google.com> | 2018-05-14 05:38:28 -0700 |
commit | a12174dd702b75d07b1fe03944df11595e6c4e28 (patch) | |
tree | a16a1cc8ac934ab9c04ea36be83ace15ab56247c | |
parent | ed0815f395df557784695e2781330769007f2ef0 (diff) | |
download | numpy-a12174dd702b75d07b1fe03944df11595e6c4e28.tar.gz |
MAINT: Refactor parse_operand_subscripts to avoid repetition.
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 174 |
1 files changed, 69 insertions, 105 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 4340b38d4..9b3a52636 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -1776,129 +1776,94 @@ get_sum_of_products_function(int nop, int type_num, return _unspecialized_table[type_num][nop <= 3 ? nop : 0]; } + /* - * Parses the subscripts for one operand into an output - * of 'ndim' labels + * Parses the subscripts for one operand into an output of 'ndim' + * labels. The resulting 'op_labels' array will have: + * - the ASCII code of the label for the first occurrence of a label; + * - the (negative) offset to the first occurrence of the label for + * repeated labels; + * - zero for broadcast dimensions, if subscripts has an ellipsis. + * For example: + * - subscripts="abbcbc", ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2] + * - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99] */ + static int parse_operand_subscripts(char *subscripts, int length, - int ndim, - int iop, char *out_labels, - char *out_label_counts, - int *out_min_label, - int *out_max_label) + int ndim, int iop, char *op_labels, + char *label_counts, int *min_label, int *max_label) { - int i, idim, ndim_left, label; - int ellipsis = 0; + int i; + int idim = 0; + int ellipsis = -1; - /* Process the labels from the end until the ellipsis */ - idim = ndim-1; - for (i = length-1; i >= 0; --i) { - label = subscripts[i]; - /* A label for an axis */ + /* Process all labels for this operand */ + for (i = 0; i < length; ++i) { + int label = subscripts[i]; + + /* A proper label for an axis. */ if (label > 0 && isalpha(label)) { - if (idim >= 0) { - out_labels[idim--] = label; - /* Calculate the min and max labels */ - if (label < *out_min_label) { - *out_min_label = label; + if (idim < ndim) { + op_labels[idim++] = label; + if (label < *min_label) { + *min_label = label; } - if (label > *out_max_label) { - *out_max_label = label; + if (label > *max_label) { + *max_label = label; } - out_label_counts[label]++; + label_counts[label]++; } else { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many subscripts for operand %d", iop); + "einstein sum subscripts string contains " + "too many subscripts for operand %d", iop); return 0; } } - /* The end of the ellipsis */ + /* The beginning of the ellipsis. */ else if (label == '.') { - /* A valid ellipsis */ - if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') { - ellipsis = 1; - length = i-2; - break; + if (ellipsis == -1 && i + 2 < length + && subscripts[++i] == '.' && subscripts[++i] == '.') { + ellipsis = idim; } else { PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') in " - "operand %d", iop); + "einstein sum subscripts string contains a " + "'.' that is not part of an ellipsis ('...') " + "in operand %d", iop); return 0; - } } else if (label != ' ') { PyErr_Format(PyExc_ValueError, - "invalid subscript '%c' in einstein sum " - "subscripts string, subscripts must " - "be letters", (char)label); + "invalid subscript '%c' in einstein sum " + "subscripts string, subscripts must " + "be letters", (char)label); return 0; } } - if (!ellipsis && idim != -1) { - PyErr_Format(PyExc_ValueError, - "operand has more dimensions than subscripts " - "given in einstein sum, but no '...' ellipsis " - "provided to broadcast the extra dimensions."); - return 0; - } - - /* Reduce ndim to just the dimensions left to fill at the beginning */ - ndim_left = idim+1; - idim = 0; - - /* - * If we stopped because of an ellipsis, start again from the beginning. - * The length was truncated to end at the ellipsis in this case. - */ - if (i > 0) { - for (i = 0; i < length; ++i) { - label = subscripts[i]; - /* A label for an axis */ - if (label > 0 && isalnum(label)) { - if (idim < ndim_left) { - out_labels[idim++] = label; - /* Calculate the min and max labels */ - if (label < *out_min_label) { - *out_min_label = label; - } - if (label > *out_max_label) { - *out_max_label = label; - } - out_label_counts[label]++; - } - else { - PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains " - "too many subscripts for operand %d", iop); - return 0; - } - } - else if (label == '.') { - PyErr_Format(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...') in " - "operand %d", iop); - } - else if (label != ' ') { - PyErr_Format(PyExc_ValueError, - "invalid subscript '%c' in einstein sum " - "subscripts string, subscripts must " - "be letters", (char)label); - return 0; - } + /* No ellipsis found, labels must match dimensions exactly. */ + if (ellipsis == -1) { + if (idim != ndim) { + PyErr_Format(PyExc_ValueError, + "operand has more dimensions than subscripts " + "given in einstein sum, but no '...' ellipsis " + "provided to broadcast the extra dimensions."); + return 0; } } - - /* Set the remaining labels to 0 */ - while (idim < ndim_left) { - out_labels[idim++] = 0; + /* Ellipsis found, may have to add broadcast dimensions. */ + else if (idim < ndim) { + /* Move labels after ellipsis to the end. */ + for (i = 0; i < idim - ellipsis; ++i) { + op_labels[ndim - i - 1] = op_labels[idim - i - 1]; + } + /* Set all broadcast dimensions to zero. */ + for (i = 0; i < ndim - idim; ++i) { + op_labels[ellipsis + i] = 0; + } } /* @@ -1909,20 +1874,18 @@ parse_operand_subscripts(char *subscripts, int length, * twos complement arithmetic the char is ok either way here, and * later where it matters the char is cast to a signed char. */ - for (idim = 0; idim < ndim-1; ++idim) { - char *next; - /* If this is a proper label, find any duplicates of it */ - label = out_labels[idim]; + for (idim = 0; idim < ndim - 1; ++idim) { + int label = op_labels[idim]; + /* If it is a proper label, find any duplicates of it. */ if (label > 0) { - /* Search for the next matching label */ - next = (char *)memchr(out_labels+idim+1, label, - ndim-idim-1); + /* Search for the next ,atching label. */ + char *next = memchr(op_labels + idim + 1, label, ndim - idim - 1); + while (next != NULL) { - /* The offset from next to out_labels[idim] (negative) */ - *next = (char)((out_labels+idim)-next); - /* Search for the next matching label */ - next = (char *)memchr(next+1, label, - out_labels+ndim-1-next); + /* The offset from next to op_labels[idim] (negative). */ + *next = (char)((op_labels + idim) - next); + /* Search for the next matching label. */ + next = memchr(next + 1, label, op_labels + ndim - 1 - next); } } } @@ -1930,6 +1893,7 @@ parse_operand_subscripts(char *subscripts, int length, return 1; } + /* * Parses the subscripts for the output operand into an output * that requires 'ndim_broadcast' unlabeled dimensions, returning |