summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2021-09-27 14:43:16 +0200
committerBas van Beek <43369155+BvB93@users.noreply.github.com>2021-09-27 14:52:05 +0200
commita5beccfa3574f4fcb1b6030737b728e65803791f (patch)
tree4cc8eddcbced42cca5f62e8fa41cc5a31af8890d /numpy/array_api
parent7e8285ba48179b19b4d251ac3447569141c7e8f7 (diff)
downloadnumpy-a5beccfa3574f4fcb1b6030737b728e65803791f.tar.gz
MAINT: Fix invalid parameter types used in `Dtype`
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_typing.py27
1 files changed, 21 insertions, 6 deletions
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index 5f937a56c..f66279fbc 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -15,10 +15,12 @@ __all__ = [
"PyCapsule",
]
-from typing import Any, Literal, Sequence, Type, Union
+import sys
+from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING
-from . import (
- Array,
+from . import Array
+from numpy import (
+ dtype,
int8,
int16,
int32,
@@ -36,9 +38,22 @@ from . import (
NestedSequence = Sequence[Sequence[Any]]
Device = Literal["cpu"]
-Dtype = Type[
- Union[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]
-]
+if TYPE_CHECKING or sys.version_info >= (3, 9):
+ Dtype = dtype[Union[
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ ]]
+else:
+ Dtype = dtype
+
SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any