diff options
Diffstat (limited to 'numpy/typing/_callable.py')
-rw-r--r-- | numpy/typing/_callable.py | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py index 5e14b708f..943441cf4 100644 --- a/numpy/typing/_callable.py +++ b/numpy/typing/_callable.py @@ -9,7 +9,7 @@ See the `Mypy documentation`_ on protocols for more details. """ import sys -from typing import Union, TypeVar, overload, Any +from typing import Union, TypeVar, overload, Any, NoReturn from numpy import ( _BoolLike, @@ -26,6 +26,7 @@ from numpy import ( signedinteger, int32, int64, + uint64, floating, float32, float64, @@ -45,6 +46,7 @@ else: HAVE_PROTOCOL = True if HAVE_PROTOCOL: + _IntType = TypeVar("_IntType", bound=integer) _NumberType = TypeVar("_NumberType", bound=number) _NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number) _GenericType_co = TypeVar("_GenericType_co", covariant=True, bound=generic) @@ -61,6 +63,14 @@ if HAVE_PROTOCOL: @overload def __call__(self, __other: _NumberType) -> _NumberType: ... + class _BoolBitOp(Protocol[_GenericType_co]): + @overload + def __call__(self, __other: _BoolLike) -> _GenericType_co: ... + @overload # platform dependent + def __call__(self, __other: int) -> Union[int32, int64]: ... + @overload + def __call__(self, __other: _IntType) -> _IntType: ... + class _BoolSub(Protocol): # Note that `__other: bool_` is absent here @overload # platform dependent @@ -103,6 +113,17 @@ if HAVE_PROTOCOL: @overload def __call__(self, __other: complex) -> complexfloating[floating]: ... + class _UnsignedIntBitOp(Protocol): + # The likes of `uint64 | np.signedinteger` will fail as there + # is no signed integer type large enough to hold a `uint64` + # See https://github.com/numpy/numpy/issues/2524 + @overload + def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ... + @overload + def __call__(self: uint64, __other: Union[int, signedinteger]) -> NoReturn: ... + @overload + def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + class _SignedIntOp(Protocol): @overload def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... @@ -111,6 +132,9 @@ if HAVE_PROTOCOL: @overload def __call__(self, __other: complex) -> complexfloating[floating]: ... + class _SignedIntBitOp(Protocol): + def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ... + class _FloatOp(Protocol): @overload def __call__(self, __other: _FloatLike) -> floating: ... @@ -125,12 +149,15 @@ if HAVE_PROTOCOL: else: _BoolOp = Any + _BoolBitOp = Any _BoolSub = Any _BoolTrueDiv = Any _TD64Div = Any _IntTrueDiv = Any _UnsignedIntOp = Any + _UnsignedIntBitOp = Any _SignedIntOp = Any + _SignedIntBitOp = Any _FloatOp = Any _ComplexOp = Any _NumberOp = Any |