From 3ba84d2fedf30aba7d47f81e36bbf265745c4e1c Mon Sep 17 00:00:00 2001 From: Ryan <50784910+rlintott@users.noreply.github.com> Date: Sat, 16 May 2020 07:59:50 -0700 Subject: BUG: numpy.einsum indexing arrays now accept numpy int type (gh-16080) * Using PyArray_PyIntAsIntp helper function instead * TST: add tests for einsum numpy int and bool list subscripts Added tests to check that einsum accepts numpy int64 types and rejects bool. Rejecting bools is new behaviour in subscript lists. I changed ValueError to TypeError on line 2496 in multiarraymodule.c as it is more appropriate. I also modified einsumfunc.py to have the same behaviour as in the C file when checking subscript list. (Reject bools but accept anything else from operator.index()) Closes gh-15961 --- numpy/core/einsumfunc.py | 21 +++++++++++++-------- numpy/core/src/multiarray/multiarraymodule.c | 26 +++++++++++++------------- numpy/core/tests/test_einsum.py | 7 +++++++ 3 files changed, 33 insertions(+), 21 deletions(-) (limited to 'numpy') diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index a1e2efdb4..c46ae173d 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -3,6 +3,7 @@ Implementation of optimized einsum. """ import itertools +import operator from numpy.core.multiarray import c_einsum from numpy.core.numeric import asanyarray, tensordot @@ -576,11 +577,13 @@ def _parse_einsum_input(operands): for s in sub: if s is Ellipsis: subscripts += "..." - elif isinstance(s, int): - subscripts += einsum_symbols[s] else: - raise TypeError("For this input type lists must contain " - "either int or Ellipsis") + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] if num != last: subscripts += "," @@ -589,11 +592,13 @@ def _parse_einsum_input(operands): for s in output_list: if s is Ellipsis: subscripts += "..." - elif isinstance(s, int): - subscripts += einsum_symbols[s] else: - raise TypeError("For this input type lists must contain " - "either int or Ellipsis") + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] # Check for proper "->" if ("-" in subscripts) or (">" in subscripts): invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1) diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 4c316052d..6915371d8 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -2438,7 +2438,6 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) } size = PySequence_Size(obj); - for (i = 0; i < size; ++i) { item = PySequence_Fast_GET_ITEM(obj, i); /* Ellipsis */ @@ -2461,8 +2460,16 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) ellipsis = 1; } /* Subscript */ - else if (PyInt_Check(item) || PyLong_Check(item)) { - long s = PyInt_AsLong(item); + else { + npy_intp s = PyArray_PyIntAsIntp(item); + /* Invalid */ + if (error_converting(s)) { + PyErr_SetString(PyExc_TypeError, + "each subscript must be either an integer " + "or an ellipsis"); + Py_DECREF(obj); + return -1; + } npy_bool bad_input = 0; if (subindex + 1 >= subsize) { @@ -2472,7 +2479,7 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) return -1; } - if ( s < 0 ) { + if (s < 0) { bad_input = 1; } else if (s < 26) { @@ -2490,16 +2497,9 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) "subscript is not within the valid range [0, 52)"); Py_DECREF(obj); return -1; - } - } - /* Invalid */ - else { - PyErr_SetString(PyExc_ValueError, - "each subscript must be either an integer " - "or an ellipsis"); - Py_DECREF(obj); - return -1; + } } + } Py_DECREF(obj); diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 68491681a..da84735a0 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -274,6 +274,13 @@ class TestEinsum: assert_equal(np.einsum(a, [0, 0], optimize=do_opt), np.trace(a).astype(dtype)) + # gh-15961: should accept numpy int64 type in subscript list + np_array = np.asarray([0, 0]) + assert_equal(np.einsum(a, np_array, optimize=do_opt), + np.trace(a).astype(dtype)) + assert_equal(np.einsum(a, list(np_array), optimize=do_opt), + np.trace(a).astype(dtype)) + # multiply(a, b) assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case for n in range(1, 17): -- cgit v1.2.1