summaryrefslogtreecommitdiff
path: root/numpy/typing/mypy_plugin.py
blob: 3418701675e55005263bef58962efb5ae06e5dcb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""A module containing `numpy`-specific plugins for mypy."""

import typing as t

import numpy as np

import mypy.types
from mypy.types import Type
from mypy.plugin import Plugin, AnalyzeTypeContext

#: Intentionally empty: nothing here is meant for ``from ... import *``;
#: mypy reaches this module through the :func:`plugin` entry point.
__all__: t.List[str] = list()

#: Signature shared by the type-analyze hooks this plugin hands to mypy.
HookFunc = t.Callable[[AnalyzeTypeContext], Type]


def _get_precision_dict() -> t.Dict[str, str]:
    """Map each ``numpy.typing._nbit`` alias to a concrete ``numpy._<n>Bit``.

    The bit count ``n`` is eight times the storage size (``itemsize``) of the
    matching numpy scalar type on the running platform. For example, an
    ``np.intc`` occupying 4 bytes maps to ``"numpy._32Bit"``.

    Returns
    -------
    dict of str to str
        Keys are fully qualified alias names such as
        ``"numpy.typing._nbit._NBitIntC"``. Values are the names of the
        corresponding ``NBitBase`` subclasses, e.g. ``"numpy._32Bit"``.
    """
    names = [
        ("_NBitByte", np.byte),
        ("_NBitShort", np.short),
        ("_NBitIntC", np.intc),
        ("_NBitIntP", np.intp),
        ("_NBitInt", np.int_),
        ("_NBitLongLong", np.longlong),

        ("_NBitHalf", np.half),
        ("_NBitSingle", np.single),
        ("_NBitDouble", np.double),
        ("_NBitLongDouble", np.longdouble),
    ]
    ret: t.Dict[str, str] = {}
    for name, typ in names:
        # Use ``itemsize``, not ``alignment``: they differ on some ABIs.
        # On 32-bit x86 Linux, 8-byte types such as ``np.double`` are only
        # 4-byte aligned, and ``np.longdouble`` is 12 bytes with 4-byte
        # alignment. numpy names its precisions by storage size
        # (``float96``/``float128``).
        n: int = 8 * np.dtype(typ).itemsize
        ret[f"numpy.typing._nbit.{name}"] = f"numpy._{n}Bit"
    return ret


#: Lookup table, built once at import time, from every alias in
#: `numpy.typing._nbit` to the name of the concrete `numpy.typing.NBitBase`
#: subclass that matches this platform.
_PRECISION_DICT: t.Dict[str, str] = _get_precision_dict()


def _hook(ctx: AnalyzeTypeContext) -> Type:
    """Swap a ``numpy.typing._nbit`` alias for its concrete ``NBitBase`` class.

    The looked-up class name comes from ``_PRECISION_DICT`` and is resolved
    to a type through the analyzer API carried in ``ctx``. A ``KeyError``
    propagates if the alias has no entry.
    """
    unbound, _context, analyzer = ctx
    # Keep only the last dotted component of the referenced name and rebuild
    # the canonical key (presumably because user code may spell the alias
    # with a different qualifying prefix -- confirm).
    short_name = unbound.name.rsplit(".", 1)[-1]
    concrete_name = _PRECISION_DICT["numpy.typing._nbit." + short_name]
    return analyzer.named_type(concrete_name)


class _NumpyPlugin(Plugin):
    """mypy plugin pinning `numpy.number` precisions to the host platform."""

    def get_type_analyze_hook(self, fullname: str) -> t.Optional[HookFunc]:
        """Return the alias-replacing hook for the precision aliases.

        Parameters
        ----------
        fullname : str
            Fully qualified name of the type mypy is about to analyze.

        Returns
        -------
        HookFunc or None
            ``_hook`` when ``fullname`` is a key of ``_PRECISION_DICT``.
            Otherwise ``None``, which leaves the type to mypy's default
            handling.
        """
        is_precision_alias = fullname in _PRECISION_DICT
        return _hook if is_precision_alias else None


def plugin(version: str) -> t.Type[_NumpyPlugin]:
    """Entry point for mypy: hand back the numpy plugin class.

    ``version`` (presumably the calling mypy's version string) is not
    inspected; the same class is returned every time.
    """
    plugin_cls = _NumpyPlugin
    return plugin_cls