summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-05-24 14:45:49 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-05-27 17:24:04 +0200
commit66fa0481c0dd85d4f90dc56afe3f1b42b96945ad (patch)
tree127daaf75e88eeb5e55bc876e2844f2970ec9700
parent89da72353f5e282a36a8e9ad9012400dbe452ced (diff)
downloadnumpy-66fa0481c0dd85d4f90dc56afe3f1b42b96945ad.tar.gz
ENH: Add a global constant to `numpy.typing` denoting whether or not `typing_extensions` is available
-rw-r--r--numpy/typing/__init__.py12
-rw-r--r--numpy/typing/_array_like.py15
-rw-r--r--numpy/typing/_callable.py16
-rw-r--r--numpy/typing/_char_codes.py16
-rw-r--r--numpy/typing/_dtype_like.py14
-rw-r--r--numpy/typing/_shape.py9
6 files changed, 39 insertions, 43 deletions
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index 1bfdf07ae..80b4b18d4 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -164,6 +164,18 @@ API
from typing import TYPE_CHECKING, List
if TYPE_CHECKING:
+ # typing_extensions is always available when type-checking
+ from typing_extensions import Literal as L
+ _HAS_TYPING_EXTENSIONS: L[True]
+else:
+ try:
+ import typing_extensions
+ except ImportError:
+ _HAS_TYPING_EXTENSIONS = False
+ else:
+ _HAS_TYPING_EXTENSIONS = True
+
+if TYPE_CHECKING:
import sys
if sys.version_info >= (3, 8):
from typing import final
diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py
index 2283c98d7..2b823ecc0 100644
--- a/numpy/typing/_array_like.py
+++ b/numpy/typing/_array_like.py
@@ -21,23 +21,20 @@ from numpy import (
bytes_,
)
+from . import _HAS_TYPING_EXTENSIONS
+from ._dtype_like import DTypeLike
+
if sys.version_info >= (3, 8):
from typing import Protocol
- HAVE_PROTOCOL = True
-else:
- try:
- from typing_extensions import Protocol
- except ImportError:
- HAVE_PROTOCOL = False
- else:
- HAVE_PROTOCOL = True
+elif _HAS_TYPING_EXTENSIONS:
+ from typing_extensions import Protocol
_T = TypeVar("_T")
_ScalarType = TypeVar("_ScalarType", bound=generic)
_DType = TypeVar("_DType", bound="dtype[Any]")
_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")
-if TYPE_CHECKING or HAVE_PROTOCOL:
+if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS:
# The `_SupportsArray` protocol only cares about the default dtype
# (i.e. `dtype=None` or no `dtype` parameter at all) of the to-be returned
# array.
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py
index 8f838f1ae..bb39559f4 100644
--- a/numpy/typing/_callable.py
+++ b/numpy/typing/_callable.py
@@ -45,21 +45,15 @@ from ._scalars import (
_FloatLike_co,
_NumberLike_co,
)
-from . import NBitBase
+from . import NBitBase, _HAS_TYPING_EXTENSIONS
from ._generic_alias import NDArray
if sys.version_info >= (3, 8):
from typing import Protocol
- HAVE_PROTOCOL = True
-else:
- try:
- from typing_extensions import Protocol
- except ImportError:
- HAVE_PROTOCOL = False
- else:
- HAVE_PROTOCOL = True
-
-if TYPE_CHECKING or HAVE_PROTOCOL:
+elif _HAS_TYPING_EXTENSIONS:
+ from typing_extensions import Protocol
+
+if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS:
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_2Tuple = Tuple[_T1, _T1]
diff --git a/numpy/typing/_char_codes.py b/numpy/typing/_char_codes.py
index 6b6f7ae88..24d39c62e 100644
--- a/numpy/typing/_char_codes.py
+++ b/numpy/typing/_char_codes.py
@@ -1,18 +1,14 @@
import sys
from typing import Any, TYPE_CHECKING
+from . import _HAS_TYPING_EXTENSIONS
+
if sys.version_info >= (3, 8):
from typing import Literal
- HAVE_LITERAL = True
-else:
- try:
- from typing_extensions import Literal
- except ImportError:
- HAVE_LITERAL = False
- else:
- HAVE_LITERAL = True
-
-if TYPE_CHECKING or HAVE_LITERAL:
+elif _HAS_TYPING_EXTENSIONS:
+ from typing_extensions import Literal
+
+if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS:
_BoolCodes = Literal["?", "=?", "<?", ">?", "bool", "bool_", "bool8"]
_UInt8Codes = Literal["uint8", "u1", "=u1", "<u1", ">u1"]
diff --git a/numpy/typing/_dtype_like.py b/numpy/typing/_dtype_like.py
index a41e2f358..3903105c4 100644
--- a/numpy/typing/_dtype_like.py
+++ b/numpy/typing/_dtype_like.py
@@ -2,18 +2,14 @@ import sys
from typing import Any, List, Sequence, Tuple, Union, Type, TypeVar, TYPE_CHECKING
import numpy as np
+
+from . import _HAS_TYPING_EXTENSIONS
from ._shape import _ShapeLike
if sys.version_info >= (3, 8):
from typing import Protocol, TypedDict
- HAVE_PROTOCOL = True
-else:
- try:
- from typing_extensions import Protocol, TypedDict
- except ImportError:
- HAVE_PROTOCOL = False
- else:
- HAVE_PROTOCOL = True
+elif _HAS_TYPING_EXTENSIONS:
+ from typing_extensions import Protocol, TypedDict
from ._char_codes import (
_BoolCodes,
@@ -59,7 +55,7 @@ from ._char_codes import (
_DTypeLikeNested = Any # TODO: wait for support for recursive types
-if TYPE_CHECKING or HAVE_PROTOCOL:
+if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS:
# Mandatory keys
class _DTypeDictBase(TypedDict):
names: Sequence[str]
diff --git a/numpy/typing/_shape.py b/numpy/typing/_shape.py
index b720c3ffc..0742be8a9 100644
--- a/numpy/typing/_shape.py
+++ b/numpy/typing/_shape.py
@@ -1,13 +1,14 @@
import sys
from typing import Sequence, Tuple, Union
+from . import _HAS_TYPING_EXTENSIONS
+
if sys.version_info >= (3, 8):
from typing import SupportsIndex
+elif _HAS_TYPING_EXTENSIONS:
+ from typing_extensions import SupportsIndex
else:
- try:
- from typing_extensions import SupportsIndex
- except ImportError:
- SupportsIndex = NotImplemented
+ SupportsIndex = NotImplemented
_Shape = Tuple[int, ...]