summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/einsumfunc.py21
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c26
-rw-r--r--numpy/core/tests/test_einsum.py7
3 files changed, 33 insertions, 21 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index a1e2efdb4..c46ae173d 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -3,6 +3,7 @@ Implementation of optimized einsum.
"""
import itertools
+import operator
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
@@ -576,11 +577,13 @@ def _parse_einsum_input(operands):
for s in sub:
if s is Ellipsis:
subscripts += "..."
- elif isinstance(s, int):
- subscripts += einsum_symbols[s]
else:
- raise TypeError("For this input type lists must contain "
- "either int or Ellipsis")
+ try:
+ s = operator.index(s)
+ except TypeError as e:
+ raise TypeError("For this input type lists must contain "
+ "either int or Ellipsis") from e
+ subscripts += einsum_symbols[s]
if num != last:
subscripts += ","
@@ -589,11 +592,13 @@ def _parse_einsum_input(operands):
for s in output_list:
if s is Ellipsis:
subscripts += "..."
- elif isinstance(s, int):
- subscripts += einsum_symbols[s]
else:
- raise TypeError("For this input type lists must contain "
- "either int or Ellipsis")
+ try:
+ s = operator.index(s)
+ except TypeError as e:
+ raise TypeError("For this input type lists must contain "
+ "either int or Ellipsis") from e
+ subscripts += einsum_symbols[s]
# Check for proper "->"
if ("-" in subscripts) or (">" in subscripts):
invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 4c316052d..6915371d8 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -2438,7 +2438,6 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
}
size = PySequence_Size(obj);
-
for (i = 0; i < size; ++i) {
item = PySequence_Fast_GET_ITEM(obj, i);
/* Ellipsis */
@@ -2461,8 +2460,16 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
ellipsis = 1;
}
/* Subscript */
- else if (PyInt_Check(item) || PyLong_Check(item)) {
- long s = PyInt_AsLong(item);
+ else {
+ npy_intp s = PyArray_PyIntAsIntp(item);
+ /* Invalid */
+ if (error_converting(s)) {
+ PyErr_SetString(PyExc_TypeError,
+ "each subscript must be either an integer "
+ "or an ellipsis");
+ Py_DECREF(obj);
+ return -1;
+ }
npy_bool bad_input = 0;
if (subindex + 1 >= subsize) {
@@ -2472,7 +2479,7 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
return -1;
}
- if ( s < 0 ) {
+ if (s < 0) {
bad_input = 1;
}
else if (s < 26) {
@@ -2490,16 +2497,9 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
"subscript is not within the valid range [0, 52)");
Py_DECREF(obj);
return -1;
- }
- }
- /* Invalid */
- else {
- PyErr_SetString(PyExc_ValueError,
- "each subscript must be either an integer "
- "or an ellipsis");
- Py_DECREF(obj);
- return -1;
+ }
}
+
}
Py_DECREF(obj);
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 68491681a..da84735a0 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -274,6 +274,13 @@ class TestEinsum:
assert_equal(np.einsum(a, [0, 0], optimize=do_opt),
np.trace(a).astype(dtype))
+ # gh-15961: should accept numpy int64 type in subscript list
+ np_array = np.asarray([0, 0])
+ assert_equal(np.einsum(a, np_array, optimize=do_opt),
+ np.trace(a).astype(dtype))
+ assert_equal(np.einsum(a, list(np_array), optimize=do_opt),
+ np.trace(a).astype(dtype))
+
# multiply(a, b)
assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case
for n in range(1, 17):