summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-02-03 00:50:40 +0100
committerBas van Beek <b.f.van.beek@vu.nl>2021-02-05 17:55:25 +0100
commitaf7106f76e3eb2ab6ada000ced512951f378b5fc (patch)
treefc83b8a7fd3d7a76571e83506a5da79bcd6228fb /numpy
parenta1640ad416c427d397695f51011000a1d7583f22 (diff)
downloadnumpy-af7106f76e3eb2ab6ada000ced512951f378b5fc.tar.gz
ENH: Add a plugin for exposing platform-specific extended-precision `np.number`s
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi3
-rw-r--r--numpy/typing/__init__.py21
-rw-r--r--numpy/typing/mypy_plugin.py60
3 files changed, 76 insertions, 8 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 0312dfad0..e5d5536b8 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -127,6 +127,9 @@ from numpy.typing._callable import (
_NumberOp,
_ComparisonOp,
)
+
+# NOTE: Numpy's mypy plugin is used for removing the types unavailable
+# to the specific platform
from numpy.typing._extended_precision import (
uint128 as uint128,
uint256 as uint256,
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index 8147789fb..8f5df483b 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -22,14 +22,19 @@ the two below:
Mypy plugin
-----------
-A mypy_ plugin is available for automatically assigning the (platform-dependent)
-precisions of certain `~numpy.number` subclasses, including the likes of
-`~numpy.int_`, `~numpy.intp` and `~numpy.longlong`. See the documentation on
-:ref:`scalar types <arrays.scalars.built-in>` for a comprehensive overview
-of the affected classes.
-
-Note that while usage of the plugin is completely optional, without it the
-precision of above-mentioned classes will be inferred as `~typing.Any`.
+A mypy_ plugin is distributed in `numpy.typing` for managing a number of
+platform-specific annotations. Its function can be split into to parts:
+
+* Assigning the (platform-dependent) precisions of certain `~numpy.number` subclasses,
+ including the likes of `~numpy.int_`, `~numpy.intp` and `~numpy.longlong`.
+ See the documentation on :ref:`scalar types <arrays.scalars.built-in>` for a
+ comprehensive overview of the affected classes. without the plugin the precision
+ of all relevant classes will be inferred as `~typing.Any`.
+* Removing all extended-precision `~numpy.number` subclasses that are unavailable
+ for the platform in question. Most notable this includes the likes of
+ `~numpy.float128` and `~numpy.complex256`. Without the plugin *all*
+ extended-precision types will, as far as mypy is concerned, be available
+ to all platforms.
To enable the plugin, one must add it to their mypy `configuration file`_:
diff --git a/numpy/typing/mypy_plugin.py b/numpy/typing/mypy_plugin.py
index 813ff16d2..901bf4fb1 100644
--- a/numpy/typing/mypy_plugin.py
+++ b/numpy/typing/mypy_plugin.py
@@ -10,6 +10,9 @@ try:
import mypy.types
from mypy.types import Type
from mypy.plugin import Plugin, AnalyzeTypeContext
+ from mypy.nodes import MypyFile, ImportFrom, Statement
+ from mypy.build import PRI_MED
+
_HookFunc = t.Callable[[AnalyzeTypeContext], Type]
MYPY_EX: t.Optional[ModuleNotFoundError] = None
except ModuleNotFoundError as ex:
@@ -39,10 +42,32 @@ def _get_precision_dict() -> t.Dict[str, str]:
return ret
+def _get_extended_precision_list() -> t.List[str]:
+ extended_types = [np.ulonglong, np.longlong, np.longdouble, np.clongdouble]
+ extended_names = {
+ "uint128",
+ "uint256",
+ "int128",
+ "int256",
+ "float80",
+ "float96",
+ "float128",
+ "float256",
+ "complex160",
+ "complex192",
+ "complex256",
+ "complex512",
+ }
+ return [i.__name__ for i in extended_types if i.__name__ in extended_names]
+
+
#: A dictionary mapping type-aliases in `numpy.typing._nbit` to
#: concrete `numpy.typing.NBitBase` subclasses.
_PRECISION_DICT: t.Final = _get_precision_dict()
+#: A list with the names of all extended precision `np.number` subclasses.
+_EXTENDED_PRECISION_LIST: t.Final = _get_extended_precision_list()
+
def _hook(ctx: AnalyzeTypeContext) -> Type:
"""Replace a type-alias with a concrete ``NBitBase`` subclass."""
@@ -53,14 +78,49 @@ def _hook(ctx: AnalyzeTypeContext) -> Type:
if t.TYPE_CHECKING or MYPY_EX is None:
+ def _index(iterable: t.Iterable[Statement], id: str) -> int:
+ """Identify the first ``ImportFrom`` instance the specified `id`."""
+ for i, value in enumerate(iterable):
+ if getattr(value, "id", None) == id:
+ return i
+ else:
+ raise ValueError("Failed to identify a `ImportFrom` instance "
+ f"with the following id: {id!r}")
+
class _NumpyPlugin(Plugin):
"""A plugin for assigning platform-specific `numpy.number` precisions."""
def get_type_analyze_hook(self, fullname: str) -> t.Optional[_HookFunc]:
+ """Set the precision of platform-specific `numpy.number` subclasses.
+
+ For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`.
+ """
if fullname in _PRECISION_DICT:
return _hook
return None
+ def get_additional_deps(self, file: MypyFile) -> t.List[t.Tuple[int, str, int]]:
+ """Import platform-specific extended-precision `numpy.number` subclasses.
+
+ For example: `numpy.float96`, `numpy.float128` and `numpy.complex256`.
+ """
+ ret = [(PRI_MED, file.fullname, -1)]
+ if file.fullname == "numpy":
+ # Import ONLY the extended precision types available to the
+ # platform in question
+ imports = ImportFrom(
+ "numpy.typing._extended_precision", 0,
+ names=[(v, v) for v in _EXTENDED_PRECISION_LIST],
+ )
+ imports.is_top_level = True
+
+ # Replace the much broader extended-precision import
+ # (defined in `numpy/__init__.pyi`) with a more specific one
+ for lst in [file.defs, file.imports]: # type: t.List[Statement]
+ i = _index(lst, "numpy.typing._extended_precision")
+ lst[i] = imports
+ return ret
+
def plugin(version: str) -> t.Type[_NumpyPlugin]:
"""An entry-point for mypy."""
return _NumpyPlugin