summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2020-10-03 01:06:56 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2020-10-17 18:05:12 +0200
commitc53797e838f002b14e0d33c9651bffecd9934404 (patch)
tree3975389ed2be99664d7046b62b7a962b39dc875a
parent7b0a764fee6e1614f3249e9082d8c4acf1dc62d5 (diff)
downloadnumpy-c53797e838f002b14e0d33c9651bffecd9934404.tar.gz
ENH: Added support for `number` precision
-rw-r--r--numpy/__init__.pyi224
-rw-r--r--numpy/typing/__init__.py80
-rw-r--r--numpy/typing/_callable.py107
3 files changed, 269 insertions, 142 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 2fff82d59..854d9e1ce 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -17,6 +17,11 @@ from numpy.typing import (
_NumberLike,
_SupportsDtype,
_VoidDtypeLike,
+ NBitBase,
+ _64Bit,
+ _32Bit,
+ _16Bit,
+ _8Bit,
)
from numpy.typing._callable import (
_BoolOp,
@@ -1606,13 +1611,15 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
# See https://github.com/numpy/numpy-stubs/pull/80 for more details.
+_NBit_co = TypeVar("_NBit_co", covariant=True, bound=NBitBase)
+
class generic(_ArrayOrScalarCommon):
@abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
@property
def base(self) -> None: ...
-class number(generic): # type: ignore
+class number(generic, Generic[_NBit_co]): # type: ignore
@property
def real(self: _ArraySelf) -> _ArraySelf: ...
@property
@@ -1699,12 +1706,12 @@ else:
_FloatValue = Union[None, _CharLike, SupportsFloat]
_ComplexValue = Union[None, _CharLike, SupportsFloat, SupportsComplex]
-class integer(number): # type: ignore
+class integer(number[_NBit_co]): # type: ignore
# NOTE: `__index__` is technically defined in the bottom-most
# sub-classes (`int64`, `uint32`, etc)
def __index__(self) -> int: ...
- __truediv__: _IntTrueDiv
- __rtruediv__: _IntTrueDiv
+ __truediv__: _IntTrueDiv[_NBit_co]
+ __rtruediv__: _IntTrueDiv[_NBit_co]
def __invert__(self: _IntType) -> _IntType: ...
# Ensure that objects annotated as `integer` support bit-wise operations
def __lshift__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
@@ -1718,39 +1725,33 @@ class integer(number): # type: ignore
def __xor__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
def __rxor__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
-class signedinteger(integer): # type: ignore
- __add__: _SignedIntOp
- __radd__: _SignedIntOp
- __sub__: _SignedIntOp
- __rsub__: _SignedIntOp
- __mul__: _SignedIntOp
- __rmul__: _SignedIntOp
- __floordiv__: _SignedIntOp
- __rfloordiv__: _SignedIntOp
- __pow__: _SignedIntOp
- __rpow__: _SignedIntOp
- __lshift__: _SignedIntBitOp
- __rlshift__: _SignedIntBitOp
- __rshift__: _SignedIntBitOp
- __rrshift__: _SignedIntBitOp
- __and__: _SignedIntBitOp
- __rand__: _SignedIntBitOp
- __xor__: _SignedIntBitOp
- __rxor__: _SignedIntBitOp
- __or__: _SignedIntBitOp
- __ror__: _SignedIntBitOp
-
-class int8(signedinteger):
- def __init__(self, __value: _IntValue = ...) -> None: ...
-
-class int16(signedinteger):
- def __init__(self, __value: _IntValue = ...) -> None: ...
-
-class int32(signedinteger):
- def __init__(self, __value: _IntValue = ...) -> None: ...
-
-class int64(signedinteger):
+class signedinteger(integer[_NBit_co]):
def __init__(self, __value: _IntValue = ...) -> None: ...
+ __add__: _SignedIntOp[_NBit_co]
+ __radd__: _SignedIntOp[_NBit_co]
+ __sub__: _SignedIntOp[_NBit_co]
+ __rsub__: _SignedIntOp[_NBit_co]
+ __mul__: _SignedIntOp[_NBit_co]
+ __rmul__: _SignedIntOp[_NBit_co]
+ __floordiv__: _SignedIntOp[_NBit_co]
+ __rfloordiv__: _SignedIntOp[_NBit_co]
+ __pow__: _SignedIntOp[_NBit_co]
+ __rpow__: _SignedIntOp[_NBit_co]
+ __lshift__: _SignedIntBitOp[_NBit_co]
+ __rlshift__: _SignedIntBitOp[_NBit_co]
+ __rshift__: _SignedIntBitOp[_NBit_co]
+ __rrshift__: _SignedIntBitOp[_NBit_co]
+ __and__: _SignedIntBitOp[_NBit_co]
+ __rand__: _SignedIntBitOp[_NBit_co]
+ __xor__: _SignedIntBitOp[_NBit_co]
+ __rxor__: _SignedIntBitOp[_NBit_co]
+ __or__: _SignedIntBitOp[_NBit_co]
+ __ror__: _SignedIntBitOp[_NBit_co]
+
+int8 = signedinteger[_8Bit]
+int16 = signedinteger[_16Bit]
+int32 = signedinteger[_32Bit]
+int64 = signedinteger[_64Bit]
class timedelta64(generic):
def __init__(
@@ -1765,98 +1766,85 @@ class timedelta64(generic):
def __mul__(self, other: Union[_FloatLike, _BoolLike]) -> timedelta64: ...
def __rmul__(self, other: Union[_FloatLike, _BoolLike]) -> timedelta64: ...
__truediv__: _TD64Div[float64]
- __floordiv__: _TD64Div[signedinteger]
+ __floordiv__: _TD64Div[int64]
def __rtruediv__(self, other: timedelta64) -> float64: ...
- def __rfloordiv__(self, other: timedelta64) -> signedinteger: ...
+ def __rfloordiv__(self, other: timedelta64) -> int64: ...
def __mod__(self, other: timedelta64) -> timedelta64: ...
-class unsignedinteger(integer): # type: ignore
+class unsignedinteger(integer[_NBit_co]): # type: ignore
# NOTE: `uint64 + signedinteger -> float64`
- __add__: _UnsignedIntOp
- __radd__: _UnsignedIntOp
- __sub__: _UnsignedIntOp
- __rsub__: _UnsignedIntOp
- __mul__: _UnsignedIntOp
- __rmul__: _UnsignedIntOp
- __floordiv__: _UnsignedIntOp
- __rfloordiv__: _UnsignedIntOp
- __pow__: _UnsignedIntOp
- __rpow__: _UnsignedIntOp
- __lshift__: _UnsignedIntBitOp
- __rlshift__: _UnsignedIntBitOp
- __rshift__: _UnsignedIntBitOp
- __rrshift__: _UnsignedIntBitOp
- __and__: _UnsignedIntBitOp
- __rand__: _UnsignedIntBitOp
- __xor__: _UnsignedIntBitOp
- __rxor__: _UnsignedIntBitOp
- __or__: _UnsignedIntBitOp
- __ror__: _UnsignedIntBitOp
-
-class uint8(unsignedinteger):
- def __init__(self, __value: _IntValue = ...) -> None: ...
-
-class uint16(unsignedinteger):
- def __init__(self, __value: _IntValue = ...) -> None: ...
-
-class uint32(unsignedinteger):
- def __init__(self, __value: _IntValue = ...) -> None: ...
-
-class uint64(unsignedinteger):
- def __init__(self, __value: _IntValue = ...) -> None: ...
-
-class inexact(number): ... # type: ignore
-
-class floating(inexact): # type: ignore
- __add__: _FloatOp
- __radd__: _FloatOp
- __sub__: _FloatOp
- __rsub__: _FloatOp
- __mul__: _FloatOp
- __rmul__: _FloatOp
- __truediv__: _FloatOp
- __rtruediv__: _FloatOp
- __floordiv__: _FloatOp
- __rfloordiv__: _FloatOp
- __pow__: _FloatOp
- __rpow__: _FloatOp
+ __add__: _UnsignedIntOp[_NBit_co]
+ __radd__: _UnsignedIntOp[_NBit_co]
+ __sub__: _UnsignedIntOp[_NBit_co]
+ __rsub__: _UnsignedIntOp[_NBit_co]
+ __mul__: _UnsignedIntOp[_NBit_co]
+ __rmul__: _UnsignedIntOp[_NBit_co]
+ __floordiv__: _UnsignedIntOp[_NBit_co]
+ __rfloordiv__: _UnsignedIntOp[_NBit_co]
+ __pow__: _UnsignedIntOp[_NBit_co]
+ __rpow__: _UnsignedIntOp[_NBit_co]
+ __lshift__: _UnsignedIntBitOp[_NBit_co]
+ __rlshift__: _UnsignedIntBitOp[_NBit_co]
+ __rshift__: _UnsignedIntBitOp[_NBit_co]
+ __rrshift__: _UnsignedIntBitOp[_NBit_co]
+ __and__: _UnsignedIntBitOp[_NBit_co]
+ __rand__: _UnsignedIntBitOp[_NBit_co]
+ __xor__: _UnsignedIntBitOp[_NBit_co]
+ __rxor__: _UnsignedIntBitOp[_NBit_co]
+ __or__: _UnsignedIntBitOp[_NBit_co]
+ __ror__: _UnsignedIntBitOp[_NBit_co]
+
+uint8 = unsignedinteger[_8Bit]
+uint16 = unsignedinteger[_16Bit]
+uint32 = unsignedinteger[_32Bit]
+uint64 = unsignedinteger[_64Bit]
+
+class inexact(number[_NBit_co]): ... # type: ignore
_IntType = TypeVar("_IntType", bound=integer)
_FloatType = TypeVar('_FloatType', bound=floating)
-class float16(floating):
+class floating(inexact[_NBit_co]):
def __init__(self, __value: _FloatValue = ...) -> None: ...
-
-class float32(floating):
- def __init__(self, __value: _FloatValue = ...) -> None: ...
-
-class float64(floating, float):
- def __init__(self, __value: _FloatValue = ...) -> None: ...
-
-class complexfloating(inexact, Generic[_FloatType]): # type: ignore
- @property
- def real(self) -> _FloatType: ... # type: ignore[override]
- @property
- def imag(self) -> _FloatType: ... # type: ignore[override]
- def __abs__(self) -> _FloatType: ... # type: ignore[override]
- __add__: _ComplexOp
- __radd__: _ComplexOp
- __sub__: _ComplexOp
- __rsub__: _ComplexOp
- __mul__: _ComplexOp
- __rmul__: _ComplexOp
- __truediv__: _ComplexOp
- __rtruediv__: _ComplexOp
- __floordiv__: _ComplexOp
- __rfloordiv__: _ComplexOp
- __pow__: _ComplexOp
- __rpow__: _ComplexOp
-
-class complex64(complexfloating[float32]):
- def __init__(self, __value: _ComplexValue = ...) -> None: ...
-
-class complex128(complexfloating[float64], complex):
+ __add__: _FloatOp[_NBit_co]
+ __radd__: _FloatOp[_NBit_co]
+ __sub__: _FloatOp[_NBit_co]
+ __rsub__: _FloatOp[_NBit_co]
+ __mul__: _FloatOp[_NBit_co]
+ __rmul__: _FloatOp[_NBit_co]
+ __truediv__: _FloatOp[_NBit_co]
+ __rtruediv__: _FloatOp[_NBit_co]
+ __floordiv__: _FloatOp[_NBit_co]
+ __rfloordiv__: _FloatOp[_NBit_co]
+ __pow__: _FloatOp[_NBit_co]
+ __rpow__: _FloatOp[_NBit_co]
+
+float16 = floating[_16Bit]
+float32 = floating[_32Bit]
+float64 = floating[_64Bit]
+
+class complexfloating(inexact[_NBit_co]):
def __init__(self, __value: _ComplexValue = ...) -> None: ...
+ @property
+ def real(self) -> floating[_NBit_co]: ... # type: ignore[override]
+ @property
+ def imag(self) -> floating[_NBit_co]: ... # type: ignore[override]
+ def __abs__(self) -> floating[_NBit_co]: ... # type: ignore[override]
+ __add__: _ComplexOp[_NBit_co]
+ __radd__: _ComplexOp[_NBit_co]
+ __sub__: _ComplexOp[_NBit_co]
+ __rsub__: _ComplexOp[_NBit_co]
+ __mul__: _ComplexOp[_NBit_co]
+ __rmul__: _ComplexOp[_NBit_co]
+ __truediv__: _ComplexOp[_NBit_co]
+ __rtruediv__: _ComplexOp[_NBit_co]
+ __floordiv__: _ComplexOp[_NBit_co]
+ __rfloordiv__: _ComplexOp[_NBit_co]
+ __pow__: _ComplexOp[_NBit_co]
+ __rpow__: _ComplexOp[_NBit_co]
+
+complex64 = complexfloating[_32Bit]
+complex128 = complexfloating[_64Bit]
class flexible(generic): ... # type: ignore
diff --git a/numpy/typing/__init__.py b/numpy/typing/__init__.py
index dafabd95a..5c6e5df99 100644
--- a/numpy/typing/__init__.py
+++ b/numpy/typing/__init__.py
@@ -89,7 +89,87 @@ Although this is valid Numpy code, the type checker will complain about it,
since its usage is discouraged.
Please see : https://numpy.org/devdocs/reference/arrays.dtypes.html
+NBitBase
+~~~~~~~~
+
+.. autoclass:: numpy.typing.NBitBase
+
"""
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import sys
+ if sys.version_info >= (3, 8):
+ from typing import final
+ else:
+ from typing_extensions import final
+else:
+ def final(f): return f
+
+
+@final # Dissallow the creation of arbitrary `NBitBase` subclasses
+class NBitBase:
+ """
+ An object representing `number` precision during static type checking.
+
+ Used exclusively for the purpose static type checking, `NBitBase`
+ represents the base of a hierachieral set of subclasses.
+ Each subsequent subclass is herein used for representing a lower level
+ of precision, _e.g._ `64Bit > 32Bit > 16Bit`.
+
+ Examples
+ --------
+ Below is a typical usage example: `NBitBase` is herein used for annotating a
+ function that takes a float and integer of arbitrary precision as arguments
+ and returns a new float of whichever precision is largest
+ (_e.g._ `np.float16 + np.int64 -> np.float64`).
+
+ >>> from typing import TypeVar, TYPE_CHECKING
+ >>> import numpy as np
+ >>> import numpy.typing as npt
+
+ >>> T = TypeVar("T", bound=npt.NBitBase)
+
+ >>> def add(a: "np.floating[T]", b: "np.integer[T]") -> "np.floating[T]":
+ ... return a + b
+
+ >>> a = np.float16()
+ >>> b = np.int64()
+ >>> out = add(a, b)
+
+ >>> if TYPE_CHECKING:
+ ... reveal_locals()
+ ... # note: Revealed local types are:
+ ... # note: a: numpy.floating[numpy.typing._16Bit*]
+ ... # note: b: numpy.signedinteger[numpy.typing._64Bit*]
+ ... # note: out: numpy.floating[numpy.typing._64Bit*]
+
+ """
+
+ def __init_subclass__(cls) -> None:
+ allowed_names = {
+ "NBitBase", "_256Bit", "_128Bit", "_96Bit", "_80Bit",
+ "_64Bit", "_32Bit", "_16Bit", "_8Bit",
+ }
+ if cls.__name__ not in allowed_names:
+ raise TypeError('cannot inherit from final class "NBitBase"')
+ super().__init_subclass__()
+
+
+# Silence errors about subclassing a `@final`-decorated class
+class _256Bit(NBitBase): ... # type: ignore[misc]
+class _128Bit(_256Bit): ... # type: ignore[misc]
+class _96Bit(_128Bit): ... # type: ignore[misc]
+class _80Bit(_96Bit): ... # type: ignore[misc]
+class _64Bit(_80Bit): ... # type: ignore[misc]
+class _32Bit(_64Bit): ... # type: ignore[misc]
+class _16Bit(_32Bit): ... # type: ignore[misc]
+class _8Bit(_16Bit): ... # type: ignore[misc]
+
+# Clean up the namespace
+del TYPE_CHECKING, final
+
from ._scalars import (
_CharLike,
_BoolLike,
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py
index 7c2ee86cb..bc58f2774 100644
--- a/numpy/typing/_callable.py
+++ b/numpy/typing/_callable.py
@@ -9,7 +9,7 @@ See the `Mypy documentation`_ on protocols for more details.
"""
import sys
-from typing import Union, TypeVar, overload, Any
+from typing import Union, TypeVar, overload, Any, TYPE_CHECKING, NoReturn
from numpy import (
generic,
@@ -25,6 +25,7 @@ from numpy import (
float32,
float64,
complexfloating,
+ complex64,
complex128,
)
from ._scalars import (
@@ -34,6 +35,7 @@ from ._scalars import (
_ComplexLike,
_NumberLike,
)
+from . import NBitBase, _64Bit
if sys.version_info >= (3, 8):
from typing import Protocol
@@ -46,7 +48,8 @@ else:
else:
HAVE_PROTOCOL = True
-if HAVE_PROTOCOL:
+if TYPE_CHECKING or HAVE_PROTOCOL:
+ _NBit_co = TypeVar("_NBit_co", covariant=True, bound=NBitBase)
_IntType = TypeVar("_IntType", bound=integer)
_NumberType = TypeVar("_NumberType", bound=number)
_NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number)
@@ -74,6 +77,8 @@ if HAVE_PROTOCOL:
class _BoolSub(Protocol):
# Note that `__other: bool_` is absent here
+ @overload
+ def __call__(self, __other: bool) -> NoReturn: ...
@overload # platform dependent
def __call__(self, __other: int) -> Union[int32, int64]: ...
@overload
@@ -97,51 +102,105 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: _FloatLike) -> timedelta64: ...
- class _IntTrueDiv(Protocol):
+ class _IntTrueDiv(Protocol[_NBit_co]): # type: ignore[misc]
+ @overload
+ def __call__(self, __other: bool) -> floating[_NBit_co]: ...
+ @overload
+ def __call__(self, __other: int) -> Union[float32, float64]: ...
+ @overload
+ def __call__(self, __other: float) -> float64: ...
@overload
- def __call__(self, __other: Union[_IntLike, float]) -> floating: ...
+ def __call__(self, __other: complex) -> complex128: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(self, __other: integer[_NBit_co]) -> floating[_NBit_co]: ...
- class _UnsignedIntOp(Protocol):
+ class _UnsignedIntOp(Protocol[_NBit_co]): # type: ignore[misc]
# NOTE: `uint64 + signedinteger -> float64`
@overload
- def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
+ def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ...
@overload
- def __call__(self, __other: Union[int, signedinteger]) -> Union[signedinteger, float64]: ...
+ def __call__(
+ self, __other: Union[int, signedinteger[Any]]
+ ) -> Union[signedinteger[Any], float64]: ...
+ @overload
+ def __call__(self, __other: float) -> float64: ...
@overload
- def __call__(self, __other: float) -> floating: ...
+ def __call__(self, __other: complex) -> complex128: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(
+ self, __other: unsignedinteger[_NBit_co]
+ ) -> unsignedinteger[_NBit_co]: ...
- class _UnsignedIntBitOp(Protocol):
+ class _UnsignedIntBitOp(Protocol[_NBit_co]): # type: ignore[misc]
# TODO: The likes of `uint64 | np.signedinteger` will fail as there
# is no signed integer type large enough to hold a `uint64`
# See https://github.com/numpy/numpy/issues/2524
@overload
- def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
+ def __call__(self, __other: bool) -> unsignedinteger[_NBit_co]: ...
+ @overload
+ def __call__(
+ self, __other: unsignedinteger[_NBit_co]
+ ) -> unsignedinteger[_NBit_co]: ...
+ @overload
+ def __call__(
+ self: _UnsignedIntBitOp[_64Bit],
+ __other: Union[int, signedinteger[Any]],
+ ) -> NoReturn: ...
+ @overload
+ def __call__(self, __other: int) -> Union[int32, int64]: ...
@overload
- def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+ def __call__(self, __other: signedinteger[Any]) -> signedinteger[Any]: ...
- class _SignedIntOp(Protocol):
+ class _SignedIntOp(Protocol[_NBit_co]): # type: ignore[misc]
+ @overload
+ def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ...
@overload
- def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+ def __call__(self, __other: int) -> Union[int32, int64]: ...
@overload
- def __call__(self, __other: float) -> floating: ...
+ def __call__(self, __other: float) -> float64: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(
+ self, __other: signedinteger[_NBit_co]
+ ) -> signedinteger[_NBit_co]: ...
- class _SignedIntBitOp(Protocol):
- def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+ class _SignedIntBitOp(Protocol[_NBit_co]): # type: ignore[misc]
+ @overload
+ def __call__(self, __other: bool) -> signedinteger[_NBit_co]: ...
+ @overload
+ def __call__(self, __other: int) -> Union[int32, int64]: ...
+ @overload
+ def __call__(
+ self, __other: signedinteger[_NBit_co]
+ ) -> signedinteger[_NBit_co]: ...
- class _FloatOp(Protocol):
+ class _FloatOp(Protocol[_NBit_co]): # type: ignore[misc]
@overload
- def __call__(self, __other: _FloatLike) -> floating: ...
+ def __call__(self, __other: bool) -> floating[_NBit_co]: ...
@overload
- def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ def __call__(self, __other: int) -> Union[float32, float64]: ...
+ @overload
+ def __call__(self, __other: float) -> float64: ...
+ @overload
+ def __call__(self, __other: complex) -> complex128: ...
+ @overload
+ def __call__(
+ self, __other: Union[integer[_NBit_co], floating[_NBit_co]]
+ ) -> floating[_NBit_co]: ...
- class _ComplexOp(Protocol):
- def __call__(self, __other: _ComplexLike) -> complexfloating[floating]: ...
+ class _ComplexOp(Protocol[_NBit_co]): # type: ignore[misc]
+ @overload
+ def __call__(self, __other: bool) -> complexfloating[_NBit_co, _NBit_co]: ...
+ @overload
+ def __call__(self, __other: int) -> Union[complex64, complex128]: ...
+ @overload
+ def __call__(self, __other: Union[float, complex]) -> complex128: ...
+ @overload
+ def __call__(
+ self,
+ __other: Union[integer[_NBit_co], floating[_NBit_co], complexfloating[_NBit_co]]
+ ) -> complexfloating[_NBit_co]: ...
class _NumberOp(Protocol):
def __call__(self, __other: _NumberLike) -> number: ...