summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2022-05-25 15:48:22 +0200
committerBas van Beek <43369155+BvB93@users.noreply.github.com>2022-05-26 16:24:53 +0200
commit4461ec48e573bb30d5f1de2a07680067e95861c9 (patch)
tree07a8761fe3c6e21273a0576fc366eb9586edd3f2 /numpy
parent7e15fd77ca6c09989f219acf432c98a1036d14c5 (diff)
downloadnumpy-4461ec48e573bb30d5f1de2a07680067e95861c9.tar.gz
MAINT: Adapt the `npt._GenericAlias` backport to Python 3.11 `types.GenericAlias` changes
Diffstat (limited to 'numpy')
-rw-r--r--numpy/_typing/_generic_alias.py49
1 files changed, 42 insertions, 7 deletions
diff --git a/numpy/_typing/_generic_alias.py b/numpy/_typing/_generic_alias.py
index 0541ad77f..d32814a72 100644
--- a/numpy/_typing/_generic_alias.py
+++ b/numpy/_typing/_generic_alias.py
@@ -64,7 +64,7 @@ def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
args.append(value)
cls = type(alias)
- return cls(alias.__origin__, tuple(args))
+ return cls(alias.__origin__, tuple(args), alias.__unpacked__)
class _GenericAlias:
@@ -80,7 +80,14 @@ class _GenericAlias:
"""
- __slots__ = ("__weakref__", "_origin", "_args", "_parameters", "_hash")
+ __slots__ = (
+ "__weakref__",
+ "_origin",
+ "_args",
+ "_parameters",
+ "_hash",
+ "_starred",
+ )
@property
def __origin__(self) -> type:
@@ -95,14 +102,27 @@ class _GenericAlias:
"""Type variables in the ``GenericAlias``."""
return super().__getattribute__("_parameters")
+ @property
+ def __unpacked__(self) -> bool:
+ return super().__getattribute__("_starred")
+
+ @property
+ def __typing_unpacked_tuple_args__(self) -> tuple[object, ...] | None:
+ # NOTE: This should return `__args__` if `__origin__` is a tuple,
+ # which should never be the case with how `_GenericAlias` is used
+ # within numpy
+ return None
+
def __init__(
self,
origin: type,
args: object | tuple[object, ...],
+ starred: bool = False,
) -> None:
self._origin = origin
self._args = args if isinstance(args, tuple) else (args,)
self._parameters = tuple(_parse_parameters(self.__args__))
+ self._starred = starred
@property
def __call__(self) -> type[Any]:
@@ -110,10 +130,10 @@ class _GenericAlias:
def __reduce__(self: _T) -> tuple[
type[_T],
- tuple[type[Any], tuple[object, ...]],
+ tuple[type[Any], tuple[object, ...], bool],
]:
cls = type(self)
- return cls, (self.__origin__, self.__args__)
+ return cls, (self.__origin__, self.__args__, self.__unpacked__)
def __mro_entries__(self, bases: Iterable[object]) -> tuple[type[Any]]:
return (self.__origin__,)
@@ -130,7 +150,11 @@ class _GenericAlias:
try:
return super().__getattribute__("_hash")
except AttributeError:
- self._hash: int = hash(self.__origin__) ^ hash(self.__args__)
+ self._hash: int = (
+ hash(self.__origin__) ^
+ hash(self.__args__) ^
+ hash(self.__unpacked__)
+ )
return super().__getattribute__("_hash")
def __instancecheck__(self, obj: object) -> NoReturn:
@@ -147,7 +171,8 @@ class _GenericAlias:
"""Return ``repr(self)``."""
args = ", ".join(_to_str(i) for i in self.__args__)
origin = _to_str(self.__origin__)
- return f"{origin}[{args}]"
+ prefix = "*" if self.__unpacked__ else ""
+ return f"{prefix}{origin}[{args}]"
def __getitem__(self: _T, key: object | tuple[object, ...]) -> _T:
"""Return ``self[key]``."""
@@ -169,9 +194,17 @@ class _GenericAlias:
return NotImplemented
return (
self.__origin__ == value.__origin__ and
- self.__args__ == value.__args__
+ self.__args__ == value.__args__ and
+ self.__unpacked__ == getattr(
+ value, "__unpacked__", self.__unpacked__
+ )
)
+ def __iter__(self: _T) -> Generator[_T, None, None]:
+ """Return ``iter(self)``."""
+ cls = type(self)
+ yield cls(self.__origin__, self.__args__, True)
+
_ATTR_EXCEPTIONS: ClassVar[frozenset[str]] = frozenset({
"__origin__",
"__args__",
@@ -181,6 +214,8 @@ class _GenericAlias:
"__reduce_ex__",
"__copy__",
"__deepcopy__",
+ "__unpacked__",
+ "__typing_unpacked_tuple_args__",
})
def __getattribute__(self, name: str) -> Any: