diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/typing/_add_docstring.py | 4 | ||||
| -rw-r--r-- | numpy/typing/_generic_alias.py | 5 | ||||
| -rw-r--r-- | numpy/typing/tests/test_generic_alias.py | 5 |
3 files changed, 9 insertions, 5 deletions
diff --git a/numpy/typing/_add_docstring.py b/numpy/typing/_add_docstring.py index 34dbdb0c6..56ef41cfd 100644 --- a/numpy/typing/_add_docstring.py +++ b/numpy/typing/_add_docstring.py @@ -114,7 +114,7 @@ add_newdoc('DTypeLike', 'typing.Union[...]', add_newdoc('NDArray', repr(NDArray), """ A :term:`generic <generic type>` version of - `np.ndarray[Any, np.dtype[~ScalarType]] <numpy.ndarray>`. + `np.ndarray[Any, np.dtype[+ScalarType]] <numpy.ndarray>`. Can be used during runtime for typing arrays with a given dtype and unspecified shape. @@ -127,7 +127,7 @@ add_newdoc('NDArray', repr(NDArray), >>> import numpy.typing as npt >>> print(npt.NDArray) - numpy.ndarray[typing.Any, numpy.dtype[~ScalarType]] + numpy.ndarray[typing.Any, numpy.dtype[+ScalarType]] >>> print(npt.NDArray[np.float64]) numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]] diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py index f98fca62e..68523827a 100644 --- a/numpy/typing/_generic_alias.py +++ b/numpy/typing/_generic_alias.py @@ -63,7 +63,8 @@ def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T: elif isinstance(i, _GenericAlias): value = _reconstruct_alias(i, parameters) elif hasattr(i, "__parameters__"): - value = i[next(parameters)] + prm_tup = tuple(next(parameters) for _ in i.__parameters__) + value = i[prm_tup] else: value = i args.append(value) @@ -195,7 +196,7 @@ if sys.version_info >= (3, 9): else: _GENERIC_ALIAS_TYPE = (_GenericAlias,) -ScalarType = TypeVar("ScalarType", bound=np.generic) +ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True) if TYPE_CHECKING: NDArray = np.ndarray[Any, np.dtype[ScalarType]] diff --git a/numpy/typing/tests/test_generic_alias.py b/numpy/typing/tests/test_generic_alias.py index 13072051a..0b9917439 100644 --- a/numpy/typing/tests/test_generic_alias.py +++ b/numpy/typing/tests/test_generic_alias.py @@ -10,7 +10,9 @@ import pytest import numpy as np from numpy.typing._generic_alias import _GenericAlias -ScalarType = TypeVar("ScalarType", bound=np.generic) +ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True) +T1 = TypeVar("T1") +T2 = TypeVar("T2") DType = _GenericAlias(np.dtype, (ScalarType,)) NDArray = _GenericAlias(np.ndarray, (Any, DType)) @@ -50,6 +52,7 @@ class TestGenericAlias: ("__getitem__", lambda n: n[np.float64]), ("__getitem__", lambda n: n[ScalarType][np.float64]), ("__getitem__", lambda n: n[Union[np.int64, ScalarType]][np.float64]), + ("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]), ("__eq__", lambda n: n == n), ("__ne__", lambda n: n != np.ndarray), ("__dir__", lambda n: dir(n)), |
