diff options
author | Bas van Beek <b.f.van.beek@vu.nl> | 2020-10-03 01:06:56 +0200 |
---|---|---|
committer | Bas van Beek <b.f.van.beek@vu.nl> | 2020-10-17 18:05:12 +0200 |
commit | c53797e838f002b14e0d33c9651bffecd9934404 (patch) | |
tree | 3975389ed2be99664d7046b62b7a962b39dc875a | |
parent | 7b0a764fee6e1614f3249e9082d8c4acf1dc62d5 (diff) | |
download | numpy-c53797e838f002b14e0d33c9651bffecd9934404.tar.gz |
ENH: Added support for `number` precision
-rw-r--r-- | numpy/__init__.pyi | 224 | ||||
-rw-r--r-- | numpy/typing/__init__.py | 80 | ||||
-rw-r--r-- | numpy/typing/_callable.py | 107 |
3 files changed, 269 insertions, 142 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 2fff82d59..854d9e1ce 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -17,6 +17,11 @@ from numpy.typing import ( _NumberLike, _SupportsDtype, _VoidDtypeLike, + NBitBase, + _64Bit, + _32Bit, + _16Bit, + _8Bit, ) from numpy.typing._callable import ( _BoolOp, @@ -1606,13 +1611,15 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): # See https://github.com/numpy/numpy-stubs/pull/80 for more details. +_NBit_co = TypeVar("_NBit_co", covariant=True, bound=NBitBase) + class generic(_ArrayOrScalarCommon): @abstractmethod def __init__(self, *args: Any, **kwargs: Any) -> None: ... @property def base(self) -> None: ... -class number(generic): # type: ignore +class number(generic, Generic[_NBit_co]): # type: ignore @property def real(self: _ArraySelf) -> _ArraySelf: ... @property @@ -1699,12 +1706,12 @@ else: _FloatValue = Union[None, _CharLike, SupportsFloat] _ComplexValue = Union[None, _CharLike, SupportsFloat, SupportsComplex] -class integer(number): # type: ignore +class integer(number[_NBit_co]): # type: ignore # NOTE: `__index__` is technically defined in the bottom-most # sub-classes (`int64`, `uint32`, etc) def __index__(self) -> int: ... - __truediv__: _IntTrueDiv - __rtruediv__: _IntTrueDiv + __truediv__: _IntTrueDiv[_NBit_co] + __rtruediv__: _IntTrueDiv[_NBit_co] def __invert__(self: _IntType) -> _IntType: ... # Ensure that objects annotated as `integer` support bit-wise operations def __lshift__(self, other: Union[_IntLike, _BoolLike]) -> integer: ... @@ -1718,39 +1725,33 @@ class integer(number): # type: ignore def __xor__(self, other: Union[_IntLike, _BoolLike]) -> integer: ... def __rxor__(self, other: Union[_IntLike, _BoolLike]) -> integer: ... -class signedinteger(integer): # type: ignore - __add__: _SignedIntOp - __radd__: _SignedIntOp - __sub__: _SignedIntOp - __rsub__: _SignedIntOp - __mul__: _SignedIntOp - __rmul__: _SignedIntOp - __floordiv__: _SignedIntOp - __rfloordiv__: _SignedIntOp - __pow__: _SignedIntOp - __rpow__: _SignedIntOp - __lshift__: _SignedIntBitOp - __rlshift__: _SignedIntBitOp - __rshift__: _SignedIntBitOp - __rrshift__: _SignedIntBitOp - __and__: _SignedIntBitOp - __rand__: _SignedIntBitOp - __xor__: _SignedIntBitOp - __rxor__: _SignedIntBitOp - __or__: _SignedIntBitOp - __ror__: _SignedIntBitOp - -class int8(signedinteger): - def __init__(self, __value: _IntValue = ...) -> None: ... - -class int16(signedinteger): - def __init__(self, __value: _IntValue = ...) -> None: ... - -class int32(signedinteger): - def __init__(self, __value: _IntValue = ...) -> None: ... - -class int64(signedinteger): +class signedinteger(integer[_NBit_co]): def __init__(self, __value: _IntValue = ...) -> None: ... + __add__: _SignedIntOp[_NBit_co] + __radd__: _SignedIntOp[_NBit_co] + __sub__: _SignedIntOp[_NBit_co] + __rsub__: _SignedIntOp[_NBit_co] + __mul__: _SignedIntOp[_NBit_co] + __rmul__: _SignedIntOp[_NBit_co] + __floordiv__: _SignedIntOp[_NBit_co] + __rfloordiv__: _SignedIntOp[_NBit_co] + __pow__: _SignedIntOp[_NBit_co] + __rpow__: _SignedIntOp[_NBit_co] + __lshift__: _SignedIntBitOp[_NBit_co] + __rlshift__: _SignedIntBitOp[_NBit_co] + __rshift__: _SignedIntBitOp[_NBit_co] + __rrshift__: _SignedIntBitOp[_NBit_co] + __and__: _SignedIntBitOp[_NBit_co] + __rand__: _SignedIntBitOp[_NBit_co] + __xor__: _SignedIntBitOp[_NBit_co] + __rxor__: _SignedIntBitOp[_NBit_co] + __or__: _SignedIntBitOp[_NBit_co] + __ror__: _SignedIntBitOp[_NBit_co] + +int8 = signedinteger[_8Bit] +int16 = signedinteger[_16Bit] +int32 = signedinteger[_32Bit] +int64 = signedinteger[_64Bit] class timedelta64(generic): def __init__( @@ -1765,98 +1766,85 @@ class timedelta64(generic): def __mul__(self, other: Union[_FloatLike, _BoolLike]) -> timedelta64: ... def __rmul__(self, other: Union[_FloatLike, _BoolLike]) -> timedelta64: ... __truediv__: _TD64Div[float64] - __floordiv__: _TD64Div[signedinteger] + __floordiv__: _TD64Div[int64] def __rtruediv__(self, other: timedelta64) -> float64: ... - def __rfloordiv__(self, other: timedelta64) -> signedinteger: ... + def __rfloordiv__(self, other: timedelta64) -> int64: ... def __mod__(self, other: timedelta64) -> timedelta64: ... -class unsignedinteger(integer): # type: ignore +class unsignedinteger(integer[_NBit_co]): # type: ignore # NOTE: `uint64 + signedinteger -> float64` - __add__: _UnsignedIntOp - __radd__: _UnsignedIntOp - __sub__: _UnsignedIntOp - __rsub__: _UnsignedIntOp - __mul__: _UnsignedIntOp - __rmul__: _UnsignedIntOp - __floordiv__: _UnsignedIntOp - __rfloordiv__: _UnsignedIntOp - __pow__: _UnsignedIntOp - __rpow__: _UnsignedIntOp - __lshift__: _UnsignedIntBitOp - __rlshift__: _UnsignedIntBitOp - __rshift__: _UnsignedIntBitOp - __rrshift__: _UnsignedIntBitOp - __and__: _UnsignedIntBitOp - __rand__: _UnsignedIntBitOp - __xor__: _UnsignedIntBitOp - __rxor__: _UnsignedIntBitOp - __or__: _UnsignedIntBitOp - __ror__: _UnsignedIntBitOp - -class uint8(unsignedinteger): - def __init__(self, __value: _IntValue = ...) -> None: ... - -class uint16(unsignedinteger): - def __init__(self, __value: _IntValue = ...) -> None: ... - -class uint32(unsignedinteger): - def __init__(self, __value: _IntValue = ...) -> None: ... - -class uint64(unsignedinteger): - def __init__(self, __value: _IntValue = ...) -> None: ... - -class inexact(number): ... # type: ignore - -class floating(inexact): # type: ignore - __add__: _FloatOp - __radd__: _FloatOp - __sub__: _FloatOp - __rsub__: _FloatOp - __mul__: _FloatOp - __rmul__: _FloatOp - __truediv__: _FloatOp - __rtruediv__: _FloatOp - __floordiv__: _FloatOp - __rfloordiv__: _FloatOp - __pow__: _FloatOp - __rpow__: _FloatOp + __add__: _UnsignedIntOp[_NBit_co] + __radd__: _UnsignedIntOp[_NBit_co] + __sub__: _UnsignedIntOp[_NBit_co] + __rsub__: _UnsignedIntOp[_NBit_co] + __mul__: _UnsignedIntOp[_NBit_co] + __rmul__: _UnsignedIntOp[_NBit_co] + __floordiv__: _UnsignedIntOp[_NBit_co] + __rfloordiv__: _UnsignedIntOp[_NBit_co] + __pow__: _UnsignedIntOp[_NBit_co] + __rpow__: _UnsignedIntOp[_NBit_co] + __lshift__: _UnsignedIntBitOp[_NBit_co] + __rlshift__: _UnsignedIntBitOp[_NBit_co] + __rshift__: _UnsignedIntBitOp[_NBit_co] + __rrshift__: _UnsignedIntBitOp[_NBit_co] + __and__: _UnsignedIntBitOp[_NBit_co] + __rand__: _UnsignedIntBitOp[_NBit_co] + __xor__: _UnsignedIntBitOp[_NBit_co] + __rxor__: _UnsignedIntBitOp[_NBit_co] + __or__: _UnsignedIntBitOp[_NBit_co] + __ror__: _UnsignedIntBitOp[_NBit_co] + +uint8 = unsignedinteger[_8Bit] +uint16 = unsignedinteger[_16Bit] +uint32 = unsignedinteger[_32Bit] +uint64 = unsignedinteger[_64Bit] + +class inexact(number[_NBit_co]): ... # type: ignore _IntType = TypeVar("_IntType", bound=integer) _FloatType = TypeVar('_FloatType', bound=floating) -class float16(floating): +class floating(inexact[_NBit_co]): def __init__(self, __value: _FloatValue = ...) -> None: ... - -class float32(floating): - def __init__(self, __value: _FloatValue = ...) -> None: ... - -class float64(floating, float): - def __init__(self, __value: _FloatValue = ...) -> None: ... - -class complexfloating(inexact, Generic[_FloatType]): # type: ignore - @property - def real(self) -> _FloatType: ... # type: ignore[override] - @property - def imag(self) -> _FloatType: ... # type: ignore[override] - def __abs__(self) -> _FloatType: ... # type: ignore[override] - __add__: _ComplexOp - __radd__: _ComplexOp - __sub__: _ComplexOp - __rsub__: _ComplexOp - __mul__: _ComplexOp - __rmul__: _ComplexOp - __truediv__: _ComplexOp - __rtruediv__: _ComplexOp - __floordiv__: _ComplexOp - __rfloordiv__: _ComplexOp - __pow__: _ComplexOp - __rpow__: _ComplexOp - -class complex64(complexfloating[float32]): - def __init__(self, __value: _ComplexValue = ...) -> None: ... - -class complex128(complexfloating[float64], complex): + __add__: _FloatOp[_NBit_co] + __radd__: _FloatOp[_NBit_co] + __sub__: _FloatOp[_NBit_co] + __rsub__: _FloatOp[_NBit_co] + __mul__: _FloatOp[_NBit_co] + __rmul__: _FloatOp[_NBit_co] + __truediv__: _FloatOp[_NBit_co] + __rtruediv__: _FloatOp[_NBit_co] + __floordiv__: _FloatOp[_NBit_co] + __rfloordiv__: _FloatOp[_NBit_co] + __pow__: _FloatOp[_NBit_co] + __rpow__: _FloatOp[_NBit_co] + +float16 = floating[_16Bit] +float32 = floating[_32Bit] +float64 = floating[_64Bit] + +class complexfloating(inexact[_NBit_co]): def __init__(self, __value: _ComplexValue = ...) -> None: ... + @property + def real(self) -> floating[_NBit_co]: ... # type: ignore[override] + @property + def imag(self) -> floating[_NBit_co]: ... # type: ignore[override] + def __abs__(self) -> floating[_NBit_co]: ... # type: ignore[override] + __add__: _ComplexOp[_NBit_co] + __radd__: _ComplexOp[_NBit_co] + __sub__: _ComplexOp[_NBit_co] + __rsub__: _ComplexOp[_NBit_co] + __mul__: _ComplexOp[_NBit_co] + __rmul__: _ComplexOp[_NBit_co] + __truediv__: _ComplexOp[_NBit_co] + __rtruediv__: _ComplexOp[_NBit_co] + __floordiv__: _ComplexOp[_NBit_co] + __rfloordiv__: _ComplexOp[_NBit_co] + __pow__: _ComplexOp[_NBit_co] + __rpow__: _ComplexOp[_NBit_co] + +complex64 = complexfloating[_32Bit] +complex128 = complexfloating[_64Bit] class flexible(generic): ... # type: ignore diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py index dafabd95a..5c6e5df99 100644 --- a/numpy/typing/__init__.py +++ b/numpy/typing/__init__.py @@ -89,7 +89,87 @@ Although this is valid Numpy code, the type checker will complain about it, since its usage is discouraged. Please see : https://numpy.org/devdocs/reference/arrays.dtypes.html +NBitBase +~~~~~~~~ + +.. autoclass:: numpy.typing.NBitBase + """ + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import sys + if sys.version_info >= (3, 8): + from typing import final + else: + from typing_extensions import final +else: + def final(f): return f + + +@final # Dissallow the creation of arbitrary `NBitBase` subclasses +class NBitBase: + """ + An object representing `number` precision during static type checking. + + Used exclusively for the purpose static type checking, `NBitBase` + represents the base of a hierachieral set of subclasses. + Each subsequent subclass is herein used for representing a lower level + of precision, _e.g._ `64Bit > 32Bit > 16Bit`. + + Examples + -------- + Below is a typical usage example: `NBitBase` is herein used for annotating a + function that takes a float and integer of arbitrary precision as arguments + and returns a new float of whichever precision is largest + (_e.g._ `np.float16 + np.int64 -> np.float64`). + + >>> from typing import TypeVar, TYPE_CHECKING + >>> import numpy as np + >>> import numpy.typing as npt + + >>> T = TypeVar("T", bound=npt.NBitBase) + + >>> def add(a: "np.floating[T]", b: "np.integer[T]") -> "np.floating[T]": + ... return a + b + + >>> a = np.float16() + >>> b = np.int64() + >>> out = add(a, b) + + >>> if TYPE_CHECKING: + ... reveal_locals() + ... # note: Revealed local types are: + ... # note: a: numpy.floating[numpy.typing._16Bit*] + ... # note: b: numpy.signedinteger[numpy.typing._64Bit*] + ... # note: out: numpy.floating[numpy.typing._64Bit*] + + """ + + def __init_subclass__(cls) -> None: + allowed_names = { + "NBitBase", "_256Bit", "_128Bit", "_96Bit", "_80Bit", + "_64Bit", "_32Bit", "_16Bit", "_8Bit", + } + if cls.__name__ not in allowed_names: + raise TypeError('cannot inherit from final class "NBitBase"') + super().__init_subclass__() + + +# Silence errors about subclassing a `@final`-decorated class +class _256Bit(NBitBase): ... # type: ignore[misc] +class _128Bit(_256Bit): ... # type: ignore[misc] +class _96Bit(_128Bit): ... # type: ignore[misc] +class _80Bit(_96Bit): ... # type: ignore[misc] +class _64Bit(_80Bit): ... # type: ignore[misc] +class _32Bit(_64Bit): ... # type: ignore[misc] +class _16Bit(_32Bit): ... # type: ignore[misc] +class _8Bit(_16Bit): ... # type: ignore[misc] + +# Clean up the namespace +del TYPE_CHECKING, final + from ._scalars import ( _CharLike, _BoolLike, diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py index 7c2ee86cb..bc58f2774 100644 --- a/numpy/typing/_callable.py +++ b/numpy/typing/_callable.py @@ -9,7 +9,7 @@ See the `Mypy documentation`_ on protocols for more details. """ import sys -from typing import Union, TypeVar, overload, Any +from typing import Union, TypeVar, overload, Any, TYPE_CHECKING, NoReturn from numpy import ( generic, @@ -25,6 +25,7 @@ from numpy import ( float32, float64, complexfloating, + complex64, complex128, ) from ._scalars import ( @@ -34,6 +35,7 @@ from ._scalars import ( _ComplexLike, _NumberLike, ) +from . import NBitBase, _64Bit if sys.version_info >= (3, 8): from typing import Protocol @@ -46,7 +48,8 @@ else: else: HAVE_PROTOCOL = True -if HAVE_PROTOCOL: +if TYPE_CHECKING or HAVE_PROTOCOL: + _NBit_co = TypeVar("_NBit_co", covariant=True, bound=NBitBase) _IntType = TypeVar("_IntType", bound=integer) _NumberType = TypeVar("_NumberType", bound=number) _NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number) @@ -74,6 +77,8 @@ if HAVE_PROTOCOL: class _BoolSub(Protocol): # Note that `__other: bool_` is absent here + @overload + def __call__(self, __other: bool) -> NoReturn: ... @overload # platform dependent def __call__(self, __other: int) -> Union[int32, int64]: ... @overload @@ -97,51 +102,105 @@ if HAVE_PROTOCOL: @overload def __call__(self, __other: _FloatLike) -> timedelta64: ... - class _IntTrueDiv(Protocol): + class _IntTrueDiv(Protocol[_NBit_co]): # type: ignore[misc] + @overload + def __call__(self, __other: bool) -> floating[_NBit_co]: ... + @overload + def __call__(self, __other: int) -> Union[float32, float64]: ... + @overload + def __call__(self, __other: float) -> float64: ... @overload - def __call__(self, __other: Union[_IntLike, float]) -> floating: ... + def __call__(self, __other: complex) -> complex128: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__(self, __other: integer[_NBit_co]) -> floating[_NBit_co]: ... - class _UnsignedIntOp(Protocol): + class _UnsignedIntOp(Protocol[_NBit_co]): # type: ignore[misc] # NOTE: `uint64 + signedinteger -> float64` @overload - def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ... + def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ... @overload - def __call__(self, __other: Union[int, signedinteger]) -> Union[signedinteger, float64]: ... + def __call__( + self, __other: Union[int, signedinteger[Any]] + ) -> Union[signedinteger[Any], float64]: ... + @overload + def __call__(self, __other: float) -> float64: ... @overload - def __call__(self, __other: float) -> floating: ... + def __call__(self, __other: complex) -> complex128: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__( + self, __other: unsignedinteger[_NBit_co] + ) -> unsignedinteger[_NBit_co]: ... - class _UnsignedIntBitOp(Protocol): + class _UnsignedIntBitOp(Protocol[_NBit_co]): # type: ignore[misc] # TODO: The likes of `uint64 | np.signedinteger` will fail as there # is no signed integer type large enough to hold a `uint64` # See https://github.com/numpy/numpy/issues/2524 @overload - def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ... + def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ... + @overload + def __call__( + self, __other: unsignedinteger[_NBit_co] + ) -> unsignedinteger[_NBit_co]: ... + @overload + def __call__( + self: _UnsignedIntBitOp[_64Bit], + __other: Union[int, signedinteger[Any]], + ) -> NoReturn: ... + @overload + def __call__(self, __other: int) -> Union[int32, int64]: ... @overload - def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + def __call__(self, __other: signedinteger[Any]) -> signedinteger[Any]: ... - class _SignedIntOp(Protocol): + class _SignedIntOp(Protocol[_NBit_co]): # type: ignore[misc] + @overload + def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ... @overload - def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + def __call__(self, __other: int) -> Union[int32, int64]: ... @overload - def __call__(self, __other: float) -> floating: ... + def __call__(self, __other: float) -> float64: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__(self, __other: complex) -> complex128: ... + @overload + def __call__( + self, __other: signedinteger[_NBit_co] + ) -> signedinteger[_NBit_co]: ... - class _SignedIntBitOp(Protocol): - def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + class _SignedIntBitOp(Protocol[_NBit_co]): # type: ignore[misc] + @overload + def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ... + @overload + def __call__(self, __other: int) -> Union[int32, int64]: ... + @overload + def __call__( + self, __other: signedinteger[_NBit_co] + ) -> signedinteger[_NBit_co]: ... - class _FloatOp(Protocol): + class _FloatOp(Protocol[_NBit_co]): # type: ignore[misc] @overload - def __call__(self, __other: _FloatLike) -> floating: ... + def __call__(self, __other: bool) -> floating[_NBit_co]: ... @overload - def __call__(self, __other: complex) -> complexfloating[floating]: ... + def __call__(self, __other: int) -> Union[float32, float64]: ... + @overload + def __call__(self, __other: float) -> float64: ... + @overload + def __call__(self, __other: complex) -> complex128: ... + @overload + def __call__( + self, __other: Union[integer[_NBit_co], floating[_NBit_co]] + ) -> floating[_NBit_co]: ... - class _ComplexOp(Protocol): - def __call__(self, __other: _ComplexLike) -> complexfloating[floating]: ... + class _ComplexOp(Protocol[_NBit_co]): # type: ignore[misc] + @overload + def __call__(self, __other: bool) -> complexfloating[_NBit_co, _NBit_co]: ... + @overload + def __call__(self, __other: int) -> Union[complex64, complex128]: ... + @overload + def __call__(self, __other: Union[float, complex]) -> complex128: ... + @overload + def __call__( + self, + __other: Union[integer[_NBit_co], floating[_NBit_co], complexfloating[_NBit_co]] + ) -> complexfloating[_NBit_co]: ... class _NumberOp(Protocol): def __call__(self, __other: _NumberLike) -> number: ... |