summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-05-24 15:01:48 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-05-27 17:24:05 +0200
commit44e72e771bff812c00e48bd6c94d90f9f46a7254 (patch)
tree306a8de927ed49eae752e0e65fe376953f19e651
parentc646ab78d780a9bb06cdd8d7fba6b29092aa7507 (diff)
downloadnumpy-44e72e771bff812c00e48bd6c94d90f9f46a7254.tar.gz
STY: Use `GenericAlias` to get rid of string-based literal expressions
-rw-r--r--numpy/typing/__init__.py1
-rw-r--r--numpy/typing/_dtype_like.py63
-rw-r--r--numpy/typing/_generic_alias.py1
3 files changed, 37 insertions, 28 deletions
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index ba41cc68a..f27e6ff1a 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -360,6 +360,7 @@ from ._array_like import (
)
from ._generic_alias import (
NDArray as NDArray,
+ _DType,
_GenericAlias,
)
diff --git a/numpy/typing/_dtype_like.py b/numpy/typing/_dtype_like.py
index 5d5dadc2e..405cc4a3c 100644
--- a/numpy/typing/_dtype_like.py
+++ b/numpy/typing/_dtype_like.py
@@ -5,12 +5,18 @@ import numpy as np
from . import _HAS_TYPING_EXTENSIONS
from ._shape import _ShapeLike
+from ._generic_alias import _DType as DType
if sys.version_info >= (3, 8):
from typing import Protocol, TypedDict
elif _HAS_TYPING_EXTENSIONS:
from typing_extensions import Protocol, TypedDict
+if sys.version_info >= (3, 9):
+ from types import GenericAlias
+else:
+ from ._generic_alias import _GenericAlias as GenericAlias
+
from ._char_codes import (
_BoolCodes,
_UInt8Codes,
@@ -54,6 +60,7 @@ from ._char_codes import (
)
_DTypeLikeNested = Any # TODO: wait for support for recursive types
+_DType_co = TypeVar("_DType_co", covariant=True, bound=DType[Any])
if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS:
# Mandatory keys
@@ -68,8 +75,6 @@ if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS:
itemsize: int
aligned: bool
- _DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype)
-
# A protocol for anything with the dtype attribute
class _SupportsDType(Protocol[_DType_co]):
@property
@@ -77,7 +82,9 @@ if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS:
else:
_DTypeDict = NotImplemented
- _SupportsDType = NotImplemented
+
+ class _SupportsDType: ...
+ _SupportsDType = GenericAlias(_SupportsDType, _DType_co)
# Would create a dtype[np.void]
@@ -102,13 +109,13 @@ _VoidDTypeLike = Union[
# Anything that can be coerced into numpy.dtype.
# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
DTypeLike = Union[
- np.dtype,
+ DType[Any],
# default data type (float64)
None,
# array-scalar types and generic types
- type, # TODO: enumerate these when we add type hints for numpy scalars
+ Type[Any], # TODO: enumerate these when we add type hints for numpy scalars
# anything with a dtype attribute
- "_SupportsDType[np.dtype[Any]]",
+ _SupportsDType[DType[Any]],
# character codes, type strings or comma-separated fields, e.g., 'float64'
str,
_VoidDTypeLike,
@@ -126,14 +133,14 @@ DTypeLike = Union[
_DTypeLikeBool = Union[
Type[bool],
Type[np.bool_],
- "np.dtype[np.bool_]",
- "_SupportsDType[np.dtype[np.bool_]]",
+ DType[np.bool_],
+ _SupportsDType[DType[np.bool_]],
_BoolCodes,
]
_DTypeLikeUInt = Union[
Type[np.unsignedinteger],
- "np.dtype[np.unsignedinteger]",
- "_SupportsDType[np.dtype[np.unsignedinteger]]",
+ DType[np.unsignedinteger],
+ _SupportsDType[DType[np.unsignedinteger]],
_UInt8Codes,
_UInt16Codes,
_UInt32Codes,
@@ -148,8 +155,8 @@ _DTypeLikeUInt = Union[
_DTypeLikeInt = Union[
Type[int],
Type[np.signedinteger],
- "np.dtype[np.signedinteger]",
- "_SupportsDType[np.dtype[np.signedinteger]]",
+ DType[np.signedinteger],
+ _SupportsDType[DType[np.signedinteger]],
_Int8Codes,
_Int16Codes,
_Int32Codes,
@@ -164,8 +171,8 @@ _DTypeLikeInt = Union[
_DTypeLikeFloat = Union[
Type[float],
Type[np.floating],
- "np.dtype[np.floating]",
- "_SupportsDType[np.dtype[np.floating]]",
+ DType[np.floating],
+ _SupportsDType[DType[np.floating]],
_Float16Codes,
_Float32Codes,
_Float64Codes,
@@ -177,8 +184,8 @@ _DTypeLikeFloat = Union[
_DTypeLikeComplex = Union[
Type[complex],
Type[np.complexfloating],
- "np.dtype[np.complexfloating]",
- "_SupportsDType[np.dtype[np.complexfloating]]",
+ DType[np.complexfloating],
+ _SupportsDType[DType[np.complexfloating]],
_Complex64Codes,
_Complex128Codes,
_CSingleCodes,
@@ -187,41 +194,41 @@ _DTypeLikeComplex = Union[
]
_DTypeLikeDT64 = Union[
Type[np.timedelta64],
- "np.dtype[np.timedelta64]",
- "_SupportsDType[np.dtype[np.timedelta64]]",
+ DType[np.timedelta64],
+ _SupportsDType[DType[np.timedelta64]],
_TD64Codes,
]
_DTypeLikeTD64 = Union[
Type[np.datetime64],
- "np.dtype[np.datetime64]",
- "_SupportsDType[np.dtype[np.datetime64]]",
+ DType[np.datetime64],
+ _SupportsDType[DType[np.datetime64]],
_DT64Codes,
]
_DTypeLikeStr = Union[
Type[str],
Type[np.str_],
- "np.dtype[np.str_]",
- "_SupportsDType[np.dtype[np.str_]]",
+ DType[np.str_],
+ _SupportsDType[DType[np.str_]],
_StrCodes,
]
_DTypeLikeBytes = Union[
Type[bytes],
Type[np.bytes_],
- "np.dtype[np.bytes_]",
- "_SupportsDType[np.dtype[np.bytes_]]",
+ DType[np.bytes_],
+ _SupportsDType[DType[np.bytes_]],
_BytesCodes,
]
_DTypeLikeVoid = Union[
Type[np.void],
- "np.dtype[np.void]",
- "_SupportsDType[np.dtype[np.void]]",
+ DType[np.void],
+ _SupportsDType[DType[np.void]],
_VoidCodes,
_VoidDTypeLike,
]
_DTypeLikeObject = Union[
type,
- "np.dtype[np.object_]",
- "_SupportsDType[np.dtype[np.object_]]",
+ DType[np.object_],
+ _SupportsDType[DType[np.object_]],
_ObjectCodes,
]
diff --git a/numpy/typing/_generic_alias.py b/numpy/typing/_generic_alias.py
index 68523827a..0d30f54ca 100644
--- a/numpy/typing/_generic_alias.py
+++ b/numpy/typing/_generic_alias.py
@@ -199,6 +199,7 @@ else:
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
if TYPE_CHECKING:
+ _DType = np.dtype[ScalarType]
NDArray = np.ndarray[Any, np.dtype[ScalarType]]
elif sys.version_info >= (3, 9):
_DType = types.GenericAlias(np.dtype, (ScalarType,))