summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/typing/_add_docstring.py4
-rw-r--r--numpy/typing/_generic_alias.py5
-rw-r--r--numpy/typing/tests/test_generic_alias.py5
3 files changed, 9 insertions, 5 deletions
diff --git a/numpy/typing/_add_docstring.py b/numpy/typing/_add_docstring.py
index 34dbdb0c6..56ef41cfd 100644
--- a/numpy/typing/_add_docstring.py
+++ b/numpy/typing/_add_docstring.py
@@ -114,7 +114,7 @@ add_newdoc('DTypeLike', 'typing.Union[...]',
add_newdoc('NDArray', repr(NDArray),
"""
A :term:`generic <generic type>` version of
- `np.ndarray[Any, np.dtype[~ScalarType]] <numpy.ndarray>`.
+ `np.ndarray[Any, np.dtype[+ScalarType]] <numpy.ndarray>`.
Can be used during runtime for typing arrays with a given dtype
and unspecified shape.
@@ -127,7 +127,7 @@ add_newdoc('NDArray', repr(NDArray),
>>> import numpy.typing as npt
>>> print(npt.NDArray)
- numpy.ndarray[typing.Any, numpy.dtype[~ScalarType]]
+ numpy.ndarray[typing.Any, numpy.dtype[+ScalarType]]
>>> print(npt.NDArray[np.float64])
numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py
index f98fca62e..68523827a 100644
--- a/numpy/typing/_generic_alias.py
+++ b/numpy/typing/_generic_alias.py
@@ -63,7 +63,8 @@ def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
elif isinstance(i, _GenericAlias):
value = _reconstruct_alias(i, parameters)
elif hasattr(i, "__parameters__"):
- value = i[next(parameters)]
+ prm_tup = tuple(next(parameters) for _ in i.__parameters__)
+ value = i[prm_tup]
else:
value = i
args.append(value)
@@ -195,7 +196,7 @@ if sys.version_info >= (3, 9):
else:
_GENERIC_ALIAS_TYPE = (_GenericAlias,)
-ScalarType = TypeVar("ScalarType", bound=np.generic)
+ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
if TYPE_CHECKING:
NDArray = np.ndarray[Any, np.dtype[ScalarType]]
diff --git a/numpy/typing/tests/test_generic_alias.py b/numpy/typing/tests/test_generic_alias.py
index 13072051a..0b9917439 100644
--- a/numpy/typing/tests/test_generic_alias.py
+++ b/numpy/typing/tests/test_generic_alias.py
@@ -10,7 +10,9 @@ import pytest
import numpy as np
from numpy.typing._generic_alias import _GenericAlias
-ScalarType = TypeVar("ScalarType", bound=np.generic)
+ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
DType = _GenericAlias(np.dtype, (ScalarType,))
NDArray = _GenericAlias(np.ndarray, (Any, DType))
@@ -50,6 +52,7 @@ class TestGenericAlias:
("__getitem__", lambda n: n[np.float64]),
("__getitem__", lambda n: n[ScalarType][np.float64]),
("__getitem__", lambda n: n[Union[np.int64, ScalarType]][np.float64]),
+ ("__getitem__", lambda n: n[Union[T1, T2]][np.float32, np.float64]),
("__eq__", lambda n: n == n),
("__ne__", lambda n: n != np.ndarray),
("__dir__", lambda n: dir(n)),