summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2022-05-27 08:28:54 +0300
committerGitHub <noreply@github.com>2022-05-27 08:28:54 +0300
commit18c3671429f56fb07106b91a0eaf37a9e3e83809 (patch)
tree8e01d56529740a84166898d6d449208c123624ee /numpy
parentb101756ac02e390d605b2febcded30a1da50cc2c (diff)
parent6a947e53660ea4cce7f1d5fa91b19cb9f6f86b51 (diff)
downloadnumpy-18c3671429f56fb07106b91a0eaf37a9e3e83809.tar.gz
Merge pull request #21605 from BvB93/alias
MAINT: Adapt the `npt._GenericAlias` backport to Python 3.11 `types.GenericAlias` changes
Diffstat (limited to 'numpy')
-rw-r--r--numpy/_typing/_generic_alias.py49
-rw-r--r--numpy/typing/tests/test_generic_alias.py43
2 files changed, 82 insertions, 10 deletions
diff --git a/numpy/_typing/_generic_alias.py b/numpy/_typing/_generic_alias.py
index 0541ad77f..d32814a72 100644
--- a/numpy/_typing/_generic_alias.py
+++ b/numpy/_typing/_generic_alias.py
@@ -64,7 +64,7 @@ def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
args.append(value)
cls = type(alias)
- return cls(alias.__origin__, tuple(args))
+ return cls(alias.__origin__, tuple(args), alias.__unpacked__)
class _GenericAlias:
@@ -80,7 +80,14 @@ class _GenericAlias:
"""
- __slots__ = ("__weakref__", "_origin", "_args", "_parameters", "_hash")
+ __slots__ = (
+ "__weakref__",
+ "_origin",
+ "_args",
+ "_parameters",
+ "_hash",
+ "_starred",
+ )
@property
def __origin__(self) -> type:
@@ -95,14 +102,27 @@ class _GenericAlias:
"""Type variables in the ``GenericAlias``."""
return super().__getattribute__("_parameters")
+ @property
+ def __unpacked__(self) -> bool:
+ return super().__getattribute__("_starred")
+
+ @property
+ def __typing_unpacked_tuple_args__(self) -> tuple[object, ...] | None:
+ # NOTE: This should return `__args__` if `__origin__` is a tuple,
+ # which should never be the case with how `_GenericAlias` is used
+ # within numpy
+ return None
+
def __init__(
self,
origin: type,
args: object | tuple[object, ...],
+ starred: bool = False,
) -> None:
self._origin = origin
self._args = args if isinstance(args, tuple) else (args,)
self._parameters = tuple(_parse_parameters(self.__args__))
+ self._starred = starred
@property
def __call__(self) -> type[Any]:
@@ -110,10 +130,10 @@ class _GenericAlias:
def __reduce__(self: _T) -> tuple[
type[_T],
- tuple[type[Any], tuple[object, ...]],
+ tuple[type[Any], tuple[object, ...], bool],
]:
cls = type(self)
- return cls, (self.__origin__, self.__args__)
+ return cls, (self.__origin__, self.__args__, self.__unpacked__)
def __mro_entries__(self, bases: Iterable[object]) -> tuple[type[Any]]:
return (self.__origin__,)
@@ -130,7 +150,11 @@ class _GenericAlias:
try:
return super().__getattribute__("_hash")
except AttributeError:
- self._hash: int = hash(self.__origin__) ^ hash(self.__args__)
+ self._hash: int = (
+ hash(self.__origin__) ^
+ hash(self.__args__) ^
+ hash(self.__unpacked__)
+ )
return super().__getattribute__("_hash")
def __instancecheck__(self, obj: object) -> NoReturn:
@@ -147,7 +171,8 @@ class _GenericAlias:
"""Return ``repr(self)``."""
args = ", ".join(_to_str(i) for i in self.__args__)
origin = _to_str(self.__origin__)
- return f"{origin}[{args}]"
+ prefix = "*" if self.__unpacked__ else ""
+ return f"{prefix}{origin}[{args}]"
def __getitem__(self: _T, key: object | tuple[object, ...]) -> _T:
"""Return ``self[key]``."""
@@ -169,9 +194,17 @@ class _GenericAlias:
return NotImplemented
return (
self.__origin__ == value.__origin__ and
- self.__args__ == value.__args__
+ self.__args__ == value.__args__ and
+ self.__unpacked__ == getattr(
+ value, "__unpacked__", self.__unpacked__
+ )
)
+ def __iter__(self: _T) -> Generator[_T, None, None]:
+ """Return ``iter(self)``."""
+ cls = type(self)
+ yield cls(self.__origin__, self.__args__, True)
+
_ATTR_EXCEPTIONS: ClassVar[frozenset[str]] = frozenset({
"__origin__",
"__args__",
@@ -181,6 +214,8 @@ class _GenericAlias:
"__reduce_ex__",
"__copy__",
"__deepcopy__",
+ "__unpacked__",
+ "__typing_unpacked_tuple_args__",
})
def __getattribute__(self, name: str) -> Any:
diff --git a/numpy/typing/tests/test_generic_alias.py b/numpy/typing/tests/test_generic_alias.py
index ae55ef439..093e12109 100644
--- a/numpy/typing/tests/test_generic_alias.py
+++ b/numpy/typing/tests/test_generic_alias.py
@@ -10,6 +10,7 @@ from typing import TypeVar, Any, Union, Callable
import pytest
import numpy as np
from numpy._typing._generic_alias import _GenericAlias
+from typing_extensions import Unpack
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
T1 = TypeVar("T1")
@@ -55,8 +56,6 @@ class TestGenericAlias:
("__origin__", lambda n: n.__origin__),
("__args__", lambda n: n.__args__),
("__parameters__", lambda n: n.__parameters__),
- ("__reduce__", lambda n: n.__reduce__()[1:]),
- ("__reduce_ex__", lambda n: n.__reduce_ex__(1)[1:]),
("__mro_entries__", lambda n: n.__mro_entries__([object])),
("__hash__", lambda n: hash(n)),
("__repr__", lambda n: repr(n)),
@@ -66,7 +65,6 @@ class TestGenericAlias:
("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]),
("__eq__", lambda n: n == n),
("__ne__", lambda n: n != np.ndarray),
- ("__dir__", lambda n: dir(n)),
("__call__", lambda n: n((1,), np.int64, BUFFER)),
("__call__", lambda n: n(shape=(1,), dtype=np.int64, buffer=BUFFER)),
("subclassing", lambda n: _get_subclass_mro(n)),
@@ -100,6 +98,45 @@ class TestGenericAlias:
value_ref = func(NDArray_ref)
assert value == value_ref
+ def test_dir(self) -> None:
+ value = dir(NDArray)
+ if sys.version_info < (3, 9):
+ return
+
+ # A number attributes only exist in `types.GenericAlias` in >= 3.11
+ if sys.version_info < (3, 11, 0, "beta", 3):
+ value.remove("__typing_unpacked_tuple_args__")
+ if sys.version_info < (3, 11, 0, "beta", 1):
+ value.remove("__unpacked__")
+ assert value == dir(NDArray_ref)
+
+ @pytest.mark.parametrize("name,func,dev_version", [
+ ("__iter__", lambda n: len(list(n)), ("beta", 1)),
+ ("__iter__", lambda n: next(iter(n)), ("beta", 1)),
+ ("__unpacked__", lambda n: n.__unpacked__, ("beta", 1)),
+ ("Unpack", lambda n: Unpack[n], ("beta", 1)),
+
+ # The right operand should now have `__unpacked__ = True`,
+ # and they are thus now longer equivalent
+ ("__ne__", lambda n: n != next(iter(n)), ("beta", 1)),
+
+ # >= beta3 stuff
+ ("__typing_unpacked_tuple_args__",
+ lambda n: n.__typing_unpacked_tuple_args__, ("beta", 3)),
+ ])
+ def test_py311_features(
+ self,
+ name: str,
+ func: FuncType,
+ dev_version: tuple[str, int],
+ ) -> None:
+ """Test Python 3.11 features."""
+ value = func(NDArray)
+
+ if sys.version_info >= (3, 11, 0, *dev_version):
+ value_ref = func(NDArray_ref)
+ assert value == value_ref
+
def test_weakref(self) -> None:
"""Test ``__weakref__``."""
value = weakref.ref(NDArray)()