diff options
author | Ryan <50784910+rlintott@users.noreply.github.com> | 2020-05-16 07:59:50 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-16 09:59:50 -0500 |
commit | 3ba84d2fedf30aba7d47f81e36bbf265745c4e1c (patch) | |
tree | 59dd79894caeb579502b3ff7412130211fa674a3 /numpy/core | |
parent | 50ce0fce70cac779919d97578381a9d762f42594 (diff) | |
download | numpy-3ba84d2fedf30aba7d47f81e36bbf265745c4e1c.tar.gz |
BUG: numpy.einsum indexing arrays now accept numpy int type (gh-16080)
* Using PyArray_PyIntAsIntp helper function instead
* TST: add tests for einsum numpy int and bool list subscripts
Added tests to check that einsum accepts numpy int64 types and
rejects bool. Rejecting bools is new behaviour in subscript lists.
I changed ValueError to TypeError on line 2496 in multiarraymodule.c
as it is more appropriate. I also modified einsumfunc.py to have the
same behaviour as in the C file when checking subscript list.
(Reject bools but accept anything else from operator.index())
Closes gh-15961
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/einsumfunc.py | 21 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 26 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 7 |
3 files changed, 33 insertions, 21 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index a1e2efdb4..c46ae173d 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -3,6 +3,7 @@ Implementation of optimized einsum. """ import itertools +import operator from numpy.core.multiarray import c_einsum from numpy.core.numeric import asanyarray, tensordot @@ -576,11 +577,13 @@ def _parse_einsum_input(operands): for s in sub: if s is Ellipsis: subscripts += "..." - elif isinstance(s, int): - subscripts += einsum_symbols[s] else: - raise TypeError("For this input type lists must contain " - "either int or Ellipsis") + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] if num != last: subscripts += "," @@ -589,11 +592,13 @@ def _parse_einsum_input(operands): for s in output_list: if s is Ellipsis: subscripts += "..." - elif isinstance(s, int): - subscripts += einsum_symbols[s] else: - raise TypeError("For this input type lists must contain " - "either int or Ellipsis") + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] # Check for proper "->" if ("-" in subscripts) or (">" in subscripts): invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1) diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 4c316052d..6915371d8 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -2438,7 +2438,6 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) } size = PySequence_Size(obj); - for (i = 0; i < size; ++i) { item = PySequence_Fast_GET_ITEM(obj, i); /* Ellipsis */ @@ -2461,8 +2460,16 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) ellipsis = 1; } /* Subscript */ - else if (PyInt_Check(item) || PyLong_Check(item)) { - long s = PyInt_AsLong(item); + else { + npy_intp s = PyArray_PyIntAsIntp(item); + /* Invalid */ + if (error_converting(s)) { + PyErr_SetString(PyExc_TypeError, + "each subscript must be either an integer " + "or an ellipsis"); + Py_DECREF(obj); + return -1; + } npy_bool bad_input = 0; if (subindex + 1 >= subsize) { @@ -2472,7 +2479,7 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) return -1; } - if ( s < 0 ) { + if (s < 0) { bad_input = 1; } else if (s < 26) { @@ -2490,16 +2497,9 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize) "subscript is not within the valid range [0, 52)"); Py_DECREF(obj); return -1; - } - } - /* Invalid */ - else { - PyErr_SetString(PyExc_ValueError, - "each subscript must be either an integer " - "or an ellipsis"); - Py_DECREF(obj); - return -1; + } } + } Py_DECREF(obj); diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 68491681a..da84735a0 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -274,6 +274,13 @@ class TestEinsum: assert_equal(np.einsum(a, [0, 0], optimize=do_opt), np.trace(a).astype(dtype)) + # gh-15961: should accept numpy int64 type in subscript list + np_array = np.asarray([0, 0]) + assert_equal(np.einsum(a, np_array, optimize=do_opt), + np.trace(a).astype(dtype)) + assert_equal(np.einsum(a, list(np_array), optimize=do_opt), + np.trace(a).astype(dtype)) + # multiply(a, b) assert_equal(np.einsum("..., ...", 3, 4), 12) # scalar case for n in range(1, 17): |