diff options
Diffstat (limited to 'numpy/core/einsumfunc.py')
-rw-r--r-- | numpy/core/einsumfunc.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index a1e2efdb4..c46ae173d 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -3,6 +3,7 @@ Implementation of optimized einsum. """ import itertools +import operator from numpy.core.multiarray import c_einsum from numpy.core.numeric import asanyarray, tensordot @@ -576,11 +577,13 @@ def _parse_einsum_input(operands): for s in sub: if s is Ellipsis: subscripts += "..." - elif isinstance(s, int): - subscripts += einsum_symbols[s] else: - raise TypeError("For this input type lists must contain " - "either int or Ellipsis") + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] if num != last: subscripts += "," @@ -589,11 +592,13 @@ def _parse_einsum_input(operands): for s in output_list: if s is Ellipsis: subscripts += "..." - elif isinstance(s, int): - subscripts += einsum_symbols[s] else: - raise TypeError("For this input type lists must contain " - "either int or Ellipsis") + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] # Check for proper "->" if ("-" in subscripts) or (">" in subscripts): invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1) |