summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-05-25 15:29:22 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-05-25 15:29:22 +0200
commit5847a1b2d4252582cd4648972780e5428c42ba1e (patch)
tree547900a2afba771dc33eb46ef6ddc834db3e914c
parentf9a1f3199623d48a4aeaacdcbae6a1851e125cb4 (diff)
downloadnumpy-5847a1b2d4252582cd4648972780e5428c42ba1e.tar.gz
BUG: Fixed an issue wherein `_GenericAlias.__getitem__` would raise for underlying types with >1 parameters
-rw-r--r--numpy/typing/_generic_alias.py3
-rw-r--r--numpy/typing/tests/test_generic_alias.py3
2 files changed, 5 insertions, 1 deletions
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py
index f98fca62e..d83979aaf 100644
--- a/numpy/typing/_generic_alias.py
+++ b/numpy/typing/_generic_alias.py
@@ -63,7 +63,8 @@ def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
elif isinstance(i, _GenericAlias):
value = _reconstruct_alias(i, parameters)
elif hasattr(i, "__parameters__"):
- value = i[next(parameters)]
+ prm_tup = tuple(next(parameters) for _ in i.__parameters__)
+ value = i[prm_tup]
else:
value = i
args.append(value)
diff --git a/numpy/typing/tests/test_generic_alias.py b/numpy/typing/tests/test_generic_alias.py
index 13072051a..5f86c4001 100644
--- a/numpy/typing/tests/test_generic_alias.py
+++ b/numpy/typing/tests/test_generic_alias.py
@@ -11,6 +11,8 @@ import numpy as np
from numpy.typing._generic_alias import _GenericAlias
ScalarType = TypeVar("ScalarType", bound=np.generic)
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
DType = _GenericAlias(np.dtype, (ScalarType,))
NDArray = _GenericAlias(np.ndarray, (Any, DType))
@@ -50,6 +52,7 @@ class TestGenericAlias:
("__getitem__", lambda n: n[np.float64]),
("__getitem__", lambda n: n[ScalarType][np.float64]),
("__getitem__", lambda n: n[Union[np.int64, ScalarType]][np.float64]),
+ ("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]),
("__eq__", lambda n: n == n),
("__ne__", lambda n: n != np.ndarray),
("__dir__", lambda n: dir(n)),