summaryrefslogtreecommitdiff
path: root/numpy/typing/_callable.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/typing/_callable.py')
-rw-r--r--numpy/typing/_callable.py29
1 files changed, 28 insertions, 1 deletions
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py
index 5e14b708f..943441cf4 100644
--- a/numpy/typing/_callable.py
+++ b/numpy/typing/_callable.py
@@ -9,7 +9,7 @@ See the `Mypy documentation`_ on protocols for more details.
"""
import sys
-from typing import Union, TypeVar, overload, Any
+from typing import Union, TypeVar, overload, Any, NoReturn
from numpy import (
_BoolLike,
@@ -26,6 +26,7 @@ from numpy import (
signedinteger,
int32,
int64,
+ uint64,
floating,
float32,
float64,
@@ -45,6 +46,7 @@ else:
HAVE_PROTOCOL = True
if HAVE_PROTOCOL:
+ _IntType = TypeVar("_IntType", bound=integer)
_NumberType = TypeVar("_NumberType", bound=number)
_NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number)
_GenericType_co = TypeVar("_GenericType_co", covariant=True, bound=generic)
@@ -61,6 +63,14 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: _NumberType) -> _NumberType: ...
+ class _BoolBitOp(Protocol[_GenericType_co]):
+ @overload
+ def __call__(self, __other: _BoolLike) -> _GenericType_co: ...
+ @overload # platform dependent
+ def __call__(self, __other: int) -> Union[int32, int64]: ...
+ @overload
+ def __call__(self, __other: _IntType) -> _IntType: ...
+
class _BoolSub(Protocol):
# Note that `__other: bool_` is absent here
@overload # platform dependent
@@ -103,6 +113,17 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ class _UnsignedIntBitOp(Protocol):
+ # The likes of `uint64 | np.signedinteger` will fail as there
+ # is no signed integer type large enough to hold a `uint64`
+ # See https://github.com/numpy/numpy/issues/2524
+ @overload
+ def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
+ @overload
+ def __call__(self: uint64, __other: Union[int, signedinteger]) -> NoReturn: ...
+ @overload
+ def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+
class _SignedIntOp(Protocol):
@overload
def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
@@ -111,6 +132,9 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ class _SignedIntBitOp(Protocol):
+ def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+
class _FloatOp(Protocol):
@overload
def __call__(self, __other: _FloatLike) -> floating: ...
@@ -125,12 +149,15 @@ if HAVE_PROTOCOL:
else:
_BoolOp = Any
+ _BoolBitOp = Any
_BoolSub = Any
_BoolTrueDiv = Any
_TD64Div = Any
_IntTrueDiv = Any
_UnsignedIntOp = Any
+ _UnsignedIntBitOp = Any
_SignedIntOp = Any
+ _SignedIntBitOp = Any
_FloatOp = Any
_ComplexOp = Any
_NumberOp = Any