summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/_typing/_ufunc.pyi46
1 files changed, 44 insertions, 2 deletions
diff --git a/numpy/_typing/_ufunc.pyi b/numpy/_typing/_ufunc.pyi
index ee0317cf9..9f8e0d4ed 100644
--- a/numpy/_typing/_ufunc.pyi
+++ b/numpy/_typing/_ufunc.pyi
@@ -14,6 +14,7 @@ from typing import (
TypeVar,
Literal,
SupportsIndex,
+ Protocol,
)
from numpy import ufunc, _CastingKind, _OrderKACF
@@ -33,6 +34,17 @@ _NTypes = TypeVar("_NTypes", bound=int)
_IDType = TypeVar("_IDType", bound=Any)
_NameType = TypeVar("_NameType", bound=str)
+
+class _SupportsArrayUFunc(Protocol):
+ def __array_ufunc__(
+ self,
+ ufunc: ufunc,
+ method: Literal["__call__", "reduce", "reduceat", "accumulate", "outer", "inner"],
+ *inputs: Any,
+ **kwargs: Any,
+ ) -> Any: ...
+
+
# NOTE: In reality `extobj` should be a length of list 3 containing an
# int, an int, and a callable, but there's no way to properly express
# non-homogenous lists.
@@ -100,10 +112,24 @@ class _UFunc_Nin1_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i
signature: str | _2Tuple[None | str] = ...,
extobj: list[Any] = ...,
) -> NDArray[Any]: ...
+ @overload
+ def __call__(
+ self,
+ __x1: _SupportsArrayUFunc,
+ out: None | NDArray[Any] | tuple[NDArray[Any]] = ...,
+ *,
+ where: None | _ArrayLikeBool_co = ...,
+ casting: _CastingKind = ...,
+ order: _OrderKACF = ...,
+ dtype: DTypeLike = ...,
+ subok: bool = ...,
+ signature: str | _2Tuple[None | str] = ...,
+ extobj: list[Any] = ...,
+ ) -> Any: ...
def at(
self,
- a: NDArray[Any],
+ a: _SupportsArrayUFunc,
indices: _ArrayLikeInt_co,
/,
) -> None: ...
@@ -280,6 +306,22 @@ class _UFunc_Nin1_Nout2(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i
signature: str | _3Tuple[None | str] = ...,
extobj: list[Any] = ...,
) -> _2Tuple[NDArray[Any]]: ...
+ @overload
+ def __call__(
+ self,
+ __x1: _SupportsArrayUFunc,
+ __out1: None | NDArray[Any] = ...,
+ __out2: None | NDArray[Any] = ...,
+ *,
+ out: _2Tuple[NDArray[Any]] = ...,
+ where: None | _ArrayLikeBool_co = ...,
+ casting: _CastingKind = ...,
+ order: _OrderKACF = ...,
+ dtype: DTypeLike = ...,
+ subok: bool = ...,
+ signature: str | _3Tuple[None | str] = ...,
+ extobj: list[Any] = ...,
+ ) -> _2Tuple[Any]: ...
class _UFunc_Nin2_Nout2(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: ignore[misc]
@property
@@ -355,7 +397,7 @@ class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type:
@property
def nargs(self) -> Literal[3]: ...
- # NOTE: In practice the only gufunc in the main name is `matmul`,
+ # NOTE: In practice the only gufunc in the main namespace is `matmul`,
# so we can use its signature here
@property
def signature(self) -> Literal["(n?,k),(k,m?)->(n?,m?)"]: ...