diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-06-04 16:13:42 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-04 16:13:42 -0600 |
commit | ffa5ece3547b6e973167aadde82f8d6b4e4380d3 (patch) | |
tree | c7518f9d7032fc54d43d8ee0172c90cec2d0effc | |
parent | 6790873334b143117f4e8d1f515def8c7fdeb9fb (diff) | |
parent | cca815760e1873de46eeba6fb717e63615f2cd53 (diff) | |
download | numpy-ffa5ece3547b6e973167aadde82f8d6b4e4380d3.tar.gz |
Merge pull request #19172 from BvB93/generic_alias
BUG: Fixed an issue wherein `_GenericAlias` could raise for non-iterable parameters
-rw-r--r-- | numpy/typing/_generic_alias.py | 17 | ||||
-rw-r--r-- | numpy/typing/tests/test_generic_alias.py | 6 |
2 files changed, 18 insertions, 5 deletions
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py index 0d30f54ca..8d65ef855 100644 --- a/numpy/typing/_generic_alias.py +++ b/numpy/typing/_generic_alias.py @@ -93,7 +93,7 @@ class _GenericAlias: return super().__getattribute__("_origin") @property - def __args__(self) -> Tuple[Any, ...]: + def __args__(self) -> Tuple[object, ...]: return super().__getattribute__("_args") @property @@ -101,16 +101,23 @@ class _GenericAlias: """Type variables in the ``GenericAlias``.""" return super().__getattribute__("_parameters") - def __init__(self, origin: type, args: Any) -> None: + def __init__( + self, + origin: type, + args: object | Tuple[object, ...], + ) -> None: self._origin = origin self._args = args if isinstance(args, tuple) else (args,) - self._parameters = tuple(_parse_parameters(args)) + self._parameters = tuple(_parse_parameters(self.__args__)) @property def __call__(self) -> type: return self.__origin__ - def __reduce__(self: _T) -> Tuple[Type[_T], Tuple[type, Tuple[Any, ...]]]: + def __reduce__(self: _T) -> Tuple[ + Type[_T], + Tuple[type, Tuple[object, ...]], + ]: cls = type(self) return cls, (self.__origin__, self.__args__) @@ -148,7 +155,7 @@ class _GenericAlias: origin = _to_str(self.__origin__) return f"{origin}[{args}]" - def __getitem__(self: _T, key: Any) -> _T: + def __getitem__(self: _T, key: object | Tuple[object, ...]) -> _T: """Return ``self[key]``.""" key_tup = key if isinstance(key, tuple) else (key,) diff --git a/numpy/typing/tests/test_generic_alias.py b/numpy/typing/tests/test_generic_alias.py index 0b9917439..27afe3927 100644 --- a/numpy/typing/tests/test_generic_alias.py +++ b/numpy/typing/tests/test_generic_alias.py @@ -41,6 +41,12 @@ class TestGenericAlias: @pytest.mark.parametrize("name,func", [ ("__init__", lambda n: n), + ("__init__", lambda n: _GenericAlias(np.ndarray, Any)), + ("__init__", lambda n: _GenericAlias(np.ndarray, (Any,))), + ("__init__", lambda n: _GenericAlias(np.ndarray, (Any, Any))), + ("__init__", lambda n: _GenericAlias(np.ndarray, T1)), + ("__init__", lambda n: _GenericAlias(np.ndarray, (T1,))), + ("__init__", lambda n: _GenericAlias(np.ndarray, (T1, T2))), ("__origin__", lambda n: n.__origin__), ("__args__", lambda n: n.__args__), ("__parameters__", lambda n: n.__parameters__), |