summaryrefslogtreecommitdiff
path: root/numpy/typing/_generic_alias.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/typing/_generic_alias.py')
-rw-r--r--numpy/typing/_generic_alias.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py
index 8d65ef855..932f12dd0 100644
--- a/numpy/typing/_generic_alias.py
+++ b/numpy/typing/_generic_alias.py
@@ -51,7 +51,7 @@ def _parse_parameters(args: Iterable[Any]) -> Generator[TypeVar, None, None]:
def _reconstruct_alias(alias: _T, parameters: Iterator[TypeVar]) -> _T:
- """Recursivelly replace all typevars with those from `parameters`.
+ """Recursively replace all typevars with those from `parameters`.
Helper function for `_GenericAlias.__getitem__`.
@@ -205,12 +205,9 @@ else:
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
-if TYPE_CHECKING:
+if TYPE_CHECKING or sys.version_info >= (3, 9):
_DType = np.dtype[ScalarType]
NDArray = np.ndarray[Any, np.dtype[ScalarType]]
-elif sys.version_info >= (3, 9):
- _DType = types.GenericAlias(np.dtype, (ScalarType,))
- NDArray = types.GenericAlias(np.ndarray, (Any, _DType))
else:
_DType = _GenericAlias(np.dtype, (ScalarType,))
NDArray = _GenericAlias(np.ndarray, (Any, _DType))