summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorRyan <50784910+rlintott@users.noreply.github.com>2020-05-16 07:59:50 -0700
committerGitHub <noreply@github.com>2020-05-16 09:59:50 -0500
commit3ba84d2fedf30aba7d47f81e36bbf265745c4e1c (patch)
tree59dd79894caeb579502b3ff7412130211fa674a3 /numpy/core
parent50ce0fce70cac779919d97578381a9d762f42594 (diff)
downloadnumpy-3ba84d2fedf30aba7d47f81e36bbf265745c4e1c.tar.gz
BUG: numpy.einsum indexing arrays now accept numpy int type (gh-16080)
* Using PyArray_PyIntAsIntp helper function instead * TST: add tests for einsum numpy int and bool list subscripts Added tests to check that einsum accepts numpy int64 types and rejects bool. Rejecting bools is new behaviour in subscript lists. I changed ValueError to TypeError on line 2496 in multiarraymodule.c as it is more appropriate. I also modified einsumfunc.py to have the same behaviour as in the C file when checking subscript list. (Reject bools but accept anything else from operator.index()) Closes gh-15961
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/einsumfunc.py21
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c26
-rw-r--r--numpy/core/tests/test_einsum.py7
3 files changed, 33 insertions, 21 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index a1e2efdb4..c46ae173d 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -3,6 +3,7 @@ Implementation of optimized einsum.
"""
import itertools
+import operator
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
@@ -576,11 +577,13 @@ def _parse_einsum_input(operands):
for s in sub:
if s is Ellipsis:
subscripts += "..."
- elif isinstance(s, int):
- subscripts += einsum_symbols[s]
else:
- raise TypeError("For this input type lists must contain "
- "either int or Ellipsis")
+ try:
+ s = operator.index(s)
+ except TypeError as e:
+ raise TypeError("For this input type lists must contain "
+ "either int or Ellipsis") from e
+ subscripts += einsum_symbols[s]
if num != last:
subscripts += ","
@@ -589,11 +592,13 @@ def _parse_einsum_input(operands):
for s in output_list:
if s is Ellipsis:
subscripts += "..."
- elif isinstance(s, int):
- subscripts += einsum_symbols[s]
else:
- raise TypeError("For this input type lists must contain "
- "either int or Ellipsis")
+ try:
+ s = operator.index(s)
+ except TypeError as e:
+ raise TypeError("For this input type lists must contain "
+ "either int or Ellipsis") from e
+ subscripts += einsum_symbols[s]
# Check for proper "->"
if ("-" in subscripts) or (">" in subscripts):
invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 4c316052d..6915371d8 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -2438,7 +2438,6 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
}
size = PySequence_Size(obj);
-
for (i = 0; i < size; ++i) {
item = PySequence_Fast_GET_ITEM(obj, i);
/* Ellipsis */
@@ -2461,8 +2460,16 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
ellipsis = 1;
}
/* Subscript */
- else if (PyInt_Check(item) || PyLong_Check(item)) {
- long s = PyInt_AsLong(item);
+ else {
+ npy_intp s = PyArray_PyIntAsIntp(item);
+ /* Invalid */
+ if (error_converting(s)) {
+ PyErr_SetString(PyExc_TypeError,
+ "each subscript must be either an integer "
+ "or an ellipsis");
+ Py_DECREF(obj);
+ return -1;
+ }
npy_bool bad_input = 0;
if (subindex + 1 >= subsize) {
@@ -2472,7 +2479,7 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
return -1;
}
- if ( s < 0 ) {
+ if (s < 0) {
bad_input = 1;
}
else if (s < 26) {
@@ -2490,16 +2497,9 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
"subscript is not within the valid range [0, 52)");
Py_DECREF(obj);
return -1;
- }
- }
- /* Invalid */
- else {
- PyErr_SetString(PyExc_ValueError,
- "each subscript must be either an integer "
- "or an ellipsis");
- Py_DECREF(obj);
- return -1;
+ }
}
+
}
Py_DECREF(obj);
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 68491681a..da84735a0 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -274,6 +274,13 @@ class TestEinsum:
assert_equal(np.einsum(a, [0, 0], optimize=do_opt),
np.trace(a).astype(dtype))
+ # gh-15961: should accept numpy int64 type in subscript list
+ np_array = np.asarray([0, 0])
+ assert_equal(np.einsum(a, np_array, optimize=do_opt),
+ np.trace(a).astype(dtype))
+ assert_equal(np.einsum(a, list(np_array), optimize=do_opt),
+ np.trace(a).astype(dtype))
+
# multiply(a, b)
assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case
for n in range(1, 17):