summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-06-04 16:13:42 -0600
committerGitHub <noreply@github.com>2021-06-04 16:13:42 -0600
commitffa5ece3547b6e973167aadde82f8d6b4e4380d3 (patch)
treec7518f9d7032fc54d43d8ee0172c90cec2d0effc
parent6790873334b143117f4e8d1f515def8c7fdeb9fb (diff)
parentcca815760e1873de46eeba6fb717e63615f2cd53 (diff)
downloadnumpy-ffa5ece3547b6e973167aadde82f8d6b4e4380d3.tar.gz
Merge pull request #19172 from BvB93/generic_alias
BUG: Fixed an issue wherein `_GenericAlias` could raise for non-iterable parameters
-rw-r--r--numpy/typing/_generic_alias.py17
-rw-r--r--numpy/typing/tests/test_generic_alias.py6
2 files changed, 18 insertions, 5 deletions
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py
index 0d30f54ca..8d65ef855 100644
--- a/numpy/typing/_generic_alias.py
+++ b/numpy/typing/_generic_alias.py
@@ -93,7 +93,7 @@ class _GenericAlias:
return super().__getattribute__("_origin")
@property
- def __args__(self) -> Tuple[Any, ...]:
+ def __args__(self) -> Tuple[object, ...]:
return super().__getattribute__("_args")
@property
@@ -101,16 +101,23 @@ class _GenericAlias:
"""Type variables in the ``GenericAlias``."""
return super().__getattribute__("_parameters")
- def __init__(self, origin: type, args: Any) -> None:
+ def __init__(
+ self,
+ origin: type,
+ args: object | Tuple[object, ...],
+ ) -> None:
self._origin = origin
self._args = args if isinstance(args, tuple) else (args,)
- self._parameters = tuple(_parse_parameters(args))
+ self._parameters = tuple(_parse_parameters(self.__args__))
@property
def __call__(self) -> type:
return self.__origin__
- def __reduce__(self: _T) -> Tuple[Type[_T], Tuple[type, Tuple[Any, ...]]]:
+ def __reduce__(self: _T) -> Tuple[
+ Type[_T],
+ Tuple[type, Tuple[object, ...]],
+ ]:
cls = type(self)
return cls, (self.__origin__, self.__args__)
@@ -148,7 +155,7 @@ class _GenericAlias:
origin = _to_str(self.__origin__)
return f"{origin}[{args}]"
- def __getitem__(self: _T, key: Any) -> _T:
+ def __getitem__(self: _T, key: object | Tuple[object, ...]) -> _T:
"""Return ``self[key]``."""
key_tup = key if isinstance(key, tuple) else (key,)
diff --git a/numpy/typing/tests/test_generic_alias.py b/numpy/typing/tests/test_generic_alias.py
index 0b9917439..27afe3927 100644
--- a/numpy/typing/tests/test_generic_alias.py
+++ b/numpy/typing/tests/test_generic_alias.py
@@ -41,6 +41,12 @@ class TestGenericAlias:
@pytest.mark.parametrize("name,func", [
("__init__", lambda n: n),
+ ("__init__", lambda n: _GenericAlias(np.ndarray, Any)),
+ ("__init__", lambda n: _GenericAlias(np.ndarray, (Any,))),
+ ("__init__", lambda n: _GenericAlias(np.ndarray, (Any, Any))),
+ ("__init__", lambda n: _GenericAlias(np.ndarray, T1)),
+ ("__init__", lambda n: _GenericAlias(np.ndarray, (T1,))),
+ ("__init__", lambda n: _GenericAlias(np.ndarray, (T1, T2))),
("__origin__", lambda n: n.__origin__),
("__args__", lambda n: n.__args__),
("__parameters__", lambda n: n.__parameters__),