summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-05-29 00:57:18 -0700
committerGitHub <noreply@github.com>2018-05-29 00:57:18 -0700
commit5075933290a59153ad899f382b93e6af2154a7cf (patch)
treeb2db18967b926681dd68dcbd43eba549d5f257f3
parent6ccb03d5e0bfe3cc9fd2c69e0d9606447c358a2d (diff)
parentc2d59257b0d5e24d9f3811df5c314d525acfc0cb (diff)
downloadnumpy-5075933290a59153ad899f382b93e6af2154a7cf.tar.gz
Merge pull request #11095 from jaimefrio/einsum_cleanup
MAINT: Einsum argument parsing cleanup
-rw-r--r--numpy/core/src/multiarray/einsum.c.src396
1 files changed, 151 insertions, 245 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 470a5fff9..167e3fbdc 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -1776,138 +1776,94 @@ get_sum_of_products_function(int nop, int type_num,
return _unspecialized_table[type_num][nop <= 3 ? nop : 0];
}
+
/*
- * Parses the subscripts for one operand into an output
- * of 'ndim' labels
+ * Parses the subscripts for one operand into an output of 'ndim'
+ * labels. The resulting 'op_labels' array will have:
+ * - the ASCII code of the label for the first occurrence of a label;
+ * - the (negative) offset to the first occurrence of the label for
+ * repeated labels;
+ * - zero for broadcast dimensions, if subscripts has an ellipsis.
+ * For example:
+ * - subscripts="abbcbc", ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2]
+ * - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99]
*/
+
static int
parse_operand_subscripts(char *subscripts, int length,
- int ndim,
- int iop, char *out_labels,
- char *out_label_counts,
- int *out_min_label,
- int *out_max_label,
- int *out_num_labels)
+ int ndim, int iop, char *op_labels,
+ char *label_counts, int *min_label, int *max_label)
{
- int i, idim, ndim_left, label;
- int ellipsis = 0;
+ int i;
+ int idim = 0;
+ int ellipsis = -1;
- /* Process the labels from the end until the ellipsis */
- idim = ndim-1;
- for (i = length-1; i >= 0; --i) {
- label = subscripts[i];
- /* A label for an axis */
+ /* Process all labels for this operand */
+ for (i = 0; i < length; ++i) {
+ int label = subscripts[i];
+
+ /* A proper label for an axis. */
if (label > 0 && isalpha(label)) {
- if (idim >= 0) {
- out_labels[idim--] = label;
- /* Calculate the min and max labels */
- if (label < *out_min_label) {
- *out_min_label = label;
- }
- if (label > *out_max_label) {
- *out_max_label = label;
- }
- /* If it's the first time we see this label, count it */
- if (out_label_counts[label] == 0) {
- (*out_num_labels)++;
- }
- out_label_counts[label]++;
- }
- else {
+ /* Check we don't exceed the operator dimensions. */
+ if (idim >= ndim) {
PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains "
- "too many subscripts for operand %d", iop);
+ "einstein sum subscripts string contains "
+ "too many subscripts for operand %d", iop);
return 0;
}
+
+ op_labels[idim++] = label;
+ if (label < *min_label) {
+ *min_label = label;
+ }
+ if (label > *max_label) {
+ *max_label = label;
+ }
+ label_counts[label]++;
}
- /* The end of the ellipsis */
+ /* The beginning of the ellipsis. */
else if (label == '.') {
- /* A valid ellipsis */
- if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') {
- ellipsis = 1;
- length = i-2;
- break;
- }
- else {
+ /* Check it's a proper ellipsis. */
+ if (ellipsis != -1 || i + 2 >= length
+ || subscripts[++i] != '.' || subscripts[++i] != '.') {
PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...') in "
- "operand %d", iop);
+ "einstein sum subscripts string contains a "
+ "'.' that is not part of an ellipsis ('...') "
+ "in operand %d", iop);
return 0;
-
}
+
+ ellipsis = idim;
}
else if (label != ' ') {
PyErr_Format(PyExc_ValueError,
- "invalid subscript '%c' in einstein sum "
- "subscripts string, subscripts must "
- "be letters", (char)label);
+ "invalid subscript '%c' in einstein sum "
+ "subscripts string, subscripts must "
+ "be letters", (char)label);
return 0;
}
}
- if (!ellipsis && idim != -1) {
- PyErr_Format(PyExc_ValueError,
- "operand has more dimensions than subscripts "
- "given in einstein sum, but no '...' ellipsis "
- "provided to broadcast the extra dimensions.");
- return 0;
- }
-
- /* Reduce ndim to just the dimensions left to fill at the beginning */
- ndim_left = idim+1;
- idim = 0;
-
- /*
- * If we stopped because of an ellipsis, start again from the beginning.
- * The length was truncated to end at the ellipsis in this case.
- */
- if (i > 0) {
- for (i = 0; i < length; ++i) {
- label = subscripts[i];
- /* A label for an axis */
- if (label > 0 && isalnum(label)) {
- if (idim < ndim_left) {
- out_labels[idim++] = label;
- /* Calculate the min and max labels */
- if (label < *out_min_label) {
- *out_min_label = label;
- }
- if (label > *out_max_label) {
- *out_max_label = label;
- }
- /* If it's the first time we see this label, count it */
- if (out_label_counts[label] == 0) {
- (*out_num_labels)++;
- }
- out_label_counts[label]++;
- }
- else {
- PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains "
- "too many subscripts for operand %d", iop);
- return 0;
- }
- }
- else if (label == '.') {
- PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...') in "
- "operand %d", iop);
- }
- else if (label != ' ') {
- PyErr_Format(PyExc_ValueError,
- "invalid subscript '%c' in einstein sum "
- "subscripts string, subscripts must "
- "be letters", (char)label);
- return 0;
- }
+ /* No ellipsis found, labels must match dimensions exactly. */
+ if (ellipsis == -1) {
+ if (idim != ndim) {
+ PyErr_Format(PyExc_ValueError,
+ "operand has more dimensions than subscripts "
+ "given in einstein sum, but no '...' ellipsis "
+ "provided to broadcast the extra dimensions.");
+ return 0;
}
}
-
- /* Set the remaining labels to 0 */
- while (idim < ndim_left) {
- out_labels[idim++] = 0;
+ /* Ellipsis found, may have to add broadcast dimensions. */
+ else if (idim < ndim) {
+ /* Move labels after ellipsis to the end. */
+ for (i = 0; i < idim - ellipsis; ++i) {
+ op_labels[ndim - i - 1] = op_labels[idim - i - 1];
+ }
+ /* Set all broadcast dimensions to zero. */
+ for (i = 0; i < ndim - idim; ++i) {
+ op_labels[ellipsis + i] = 0;
+ }
}
/*
@@ -1918,20 +1874,18 @@ parse_operand_subscripts(char *subscripts, int length,
* twos complement arithmetic the char is ok either way here, and
* later where it matters the char is cast to a signed char.
*/
- for (idim = 0; idim < ndim-1; ++idim) {
- char *next;
- /* If this is a proper label, find any duplicates of it */
- label = out_labels[idim];
+ for (idim = 0; idim < ndim - 1; ++idim) {
+ int label = op_labels[idim];
+ /* If it is a proper label, find any duplicates of it. */
if (label > 0) {
- /* Search for the next matching label */
- next = (char *)memchr(out_labels+idim+1, label,
- ndim-idim-1);
+ /* Search for the next matching label. */
+ char *next = memchr(op_labels + idim + 1, label, ndim - idim - 1);
+
while (next != NULL) {
- /* The offset from next to out_labels[idim] (negative) */
- *next = (char)((out_labels+idim)-next);
- /* Search for the next matching label */
- next = (char *)memchr(next+1, label,
- out_labels+ndim-1-next);
+ /* The offset from next to op_labels[idim] (negative). */
+ *next = (char)((op_labels + idim) - next);
+ /* Search for the next matching label. */
+ next = memchr(next + 1, label, op_labels + ndim - 1 - next);
}
}
}
@@ -1939,137 +1893,97 @@ parse_operand_subscripts(char *subscripts, int length,
return 1;
}
+
/*
- * Parses the subscripts for the output operand into an output
- * that requires 'ndim_broadcast' unlabeled dimensions, returning
- * the number of output dimensions. Returns -1 if there is an error.
+ * Parses the subscripts for the output operand into an output that
+ * includes 'ndim_broadcast' unlabeled dimensions, and returns the total
+ * number of output dimensions, or -1 if there is an error. Similarly
+ * to parse_operand_subscripts, the 'out_labels' array will have, for
+ * each dimension:
+ * - the ASCII code of the corresponding label;
+ * - zero for broadcast dimensions, if subscripts has an ellipsis.
*/
static int
parse_output_subscripts(char *subscripts, int length,
int ndim_broadcast,
- const char *label_counts,
- char *out_labels)
+ const char *label_counts, char *out_labels)
{
- int i, nlabels, label, idim, ndim, ndim_left;
+ int i, bdim;
+ int ndim = 0;
int ellipsis = 0;
- /* Count the labels, making sure they're all unique and valid */
- nlabels = 0;
+ /* Process all the output labels. */
for (i = 0; i < length; ++i) {
- label = subscripts[i];
- if (label > 0 && isalpha(label)) {
- /* Check if it occurs again */
- if (memchr(subscripts+i+1, label, length-i-1) == NULL) {
- /* Check that it was used in the inputs */
- if (label_counts[label] == 0) {
- PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string included "
- "output subscript '%c' which never appeared "
- "in an input", (char)label);
- return -1;
- }
+ int label = subscripts[i];
- nlabels++;
- }
- else {
+ /* A proper label for an axis. */
+ if (label > 0 && isalpha(label)) {
+ /* Check that it doesn't occur again. */
+ if (memchr(subscripts + i + 1, label, length - i - 1) != NULL) {
PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string includes "
- "output subscript '%c' multiple times",
- (char)label);
+ "einstein sum subscripts string includes "
+ "output subscript '%c' multiple times",
+ (char)label);
return -1;
}
- }
- else if (label != '.' && label != ' ') {
- PyErr_Format(PyExc_ValueError,
- "invalid subscript '%c' in einstein sum "
- "subscripts string, subscripts must "
- "be letters", (char)label);
- return -1;
- }
- }
-
- /* The number of output dimensions */
- ndim = ndim_broadcast + nlabels;
-
- /* Process the labels from the end until the ellipsis */
- idim = ndim-1;
- for (i = length-1; i >= 0; --i) {
- label = subscripts[i];
- /* A label for an axis */
- if (label != '.' && label != ' ') {
- if (idim >= 0) {
- out_labels[idim--] = label;
+ /* Check that it was used in the inputs. */
+ if (label_counts[label] == 0) {
+ PyErr_Format(PyExc_ValueError,
+ "einstein sum subscripts string included "
+ "output subscript '%c' which never appeared "
+ "in an input", (char)label);
+ return -1;
}
- else {
+ /* Check that there is room in out_labels for this label. */
+ if (ndim >= NPY_MAXDIMS) {
PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains "
- "too many output subscripts");
+ "einstein sum subscripts string contains "
+ "too many subscripts in the output");
return -1;
}
+
+ out_labels[ndim++] = label;
}
- /* The end of the ellipsis */
+ /* The beginning of the ellipsis. */
else if (label == '.') {
- /* A valid ellipsis */
- if (i >= 2 && subscripts[i-1] == '.' && subscripts[i-2] == '.') {
- ellipsis = 1;
- length = i-2;
- break;
- }
- else {
+ /* Check it is a proper ellipsis. */
+ if (ellipsis || i + 2 >= length
+ || subscripts[++i] != '.' || subscripts[++i] != '.') {
PyErr_SetString(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...') "
- "in the output");
+ "einstein sum subscripts string "
+ "contains a '.' that is not part of "
+ "an ellipsis ('...') in the output");
return -1;
-
}
- }
- }
-
- if (!ellipsis && idim != -1) {
- PyErr_SetString(PyExc_ValueError,
- "output has more dimensions than subscripts "
- "given in einstein sum, but no '...' ellipsis "
- "provided to broadcast the extra dimensions.");
- return 0;
- }
-
- /* Reduce ndim to just the dimensions left to fill at the beginning */
- ndim_left = idim+1;
- idim = 0;
-
- /*
- * If we stopped because of an ellipsis, start again from the beginning.
- * The length was truncated to end at the ellipsis in this case.
- */
- if (i > 0) {
- for (i = 0; i < length; ++i) {
- label = subscripts[i];
- if (label == '.') {
- PyErr_SetString(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...') "
- "in the output");
+ /* Check there is room in out_labels for broadcast dims. */
+ if (ndim + ndim_broadcast > NPY_MAXDIMS) {
+ PyErr_Format(PyExc_ValueError,
+ "einstein sum subscripts string contains "
+ "too many subscripts in the output");
return -1;
}
- /* A label for an axis */
- else if (label != ' ') {
- if (idim < ndim_left) {
- out_labels[idim++] = label;
- }
- else {
- PyErr_Format(PyExc_ValueError,
- "einstein sum subscripts string contains "
- "too many subscripts for the output");
- return -1;
- }
+
+ ellipsis = 1;
+ for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
+ out_labels[ndim++] = 0;
}
}
+ else if (label != ' ') {
+ PyErr_Format(PyExc_ValueError,
+ "invalid subscript '%c' in einstein sum "
+ "subscripts string, subscripts must "
+ "be letters", (char)label);
+ return -1;
+ }
}
- /* Set the remaining output labels to 0 */
- while (idim < ndim_left) {
- out_labels[idim++] = 0;
+ /* If no ellipsis was found there should be no broadcast dimensions. */
+ if (!ellipsis && ndim_broadcast > 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "output has more dimensions than subscripts "
+ "given in einstein sum, but no '...' ellipsis "
+ "provided to broadcast the extra dimensions.");
+ return -1;
}
return ndim;
@@ -2613,7 +2527,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
NPY_ORDER order, NPY_CASTING casting,
PyArrayObject *out)
{
- int iop, label, min_label = 127, max_label = 0, num_labels;
+ int iop, label, min_label = 127, max_label = 0;
char label_counts[128];
char op_labels[NPY_MAXARGS][NPY_MAXDIMS];
char output_labels[NPY_MAXDIMS], *iter_labels;
@@ -2644,7 +2558,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
/* Parse the subscripts string into label_counts and op_labels */
memset(label_counts, 0, sizeof(label_counts));
- num_labels = 0;
for (iop = 0; iop < nop; ++iop) {
int length = (int)strcspn(subscripts, ",-");
@@ -2664,7 +2577,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
if (!parse_operand_subscripts(subscripts, length,
PyArray_NDIM(op_in[iop]),
iop, op_labels[iop], label_counts,
- &min_label, &max_label, &num_labels)) {
+ &min_label, &max_label)) {
return NULL;
}
@@ -2698,21 +2611,18 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
}
/*
- * If there is no output signature, create one using each label
- * that appeared once, in alphabetical order
+ * If there is no output signature, fill output_labels and ndim_output
+ * using each label that appeared once, in alphabetical order.
*/
if (subscripts[0] == '\0') {
- char outsubscripts[NPY_MAXDIMS + 3];
- int length;
- /* If no output was specified, always broadcast left (like normal) */
- outsubscripts[0] = '.';
- outsubscripts[1] = '.';
- outsubscripts[2] = '.';
- length = 3;
+ /* If no output was specified, always broadcast left, as usual. */
+ for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
+ output_labels[ndim_output] = 0;
+ }
for (label = min_label; label <= max_label; ++label) {
if (label_counts[label] == 1) {
- if (length < NPY_MAXDIMS-1) {
- outsubscripts[length++] = label;
+ if (ndim_output < NPY_MAXDIMS) {
+ output_labels[ndim_output++] = label;
}
else {
PyErr_SetString(PyExc_ValueError,
@@ -2722,10 +2632,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
}
}
}
- /* Parse the output subscript string */
- ndim_output = parse_output_subscripts(outsubscripts, length,
- ndim_broadcast, label_counts,
- output_labels);
}
else {
if (subscripts[0] != '-' || subscripts[1] != '>') {
@@ -2736,13 +2642,13 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
}
subscripts += 2;
- /* Parse the output subscript string */
+ /* Parse the output subscript string. */
ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
ndim_broadcast, label_counts,
output_labels);
- }
- if (ndim_output < 0) {
- return NULL;
+ if (ndim_output < 0) {
+ return NULL;
+ }
}
if (out != NULL && PyArray_NDIM(out) != ndim_output) {