diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/_typing/_ufunc.pyi | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/numpy/_typing/_ufunc.pyi b/numpy/_typing/_ufunc.pyi index ee0317cf9..9f8e0d4ed 100644 --- a/numpy/_typing/_ufunc.pyi +++ b/numpy/_typing/_ufunc.pyi @@ -14,6 +14,7 @@ from typing import ( TypeVar, Literal, SupportsIndex, + Protocol, ) from numpy import ufunc, _CastingKind, _OrderKACF @@ -33,6 +34,17 @@ _NTypes = TypeVar("_NTypes", bound=int) _IDType = TypeVar("_IDType", bound=Any) _NameType = TypeVar("_NameType", bound=str) + +class _SupportsArrayUFunc(Protocol): + def __array_ufunc__( + self, + ufunc: ufunc, + method: Literal["__call__", "reduce", "reduceat", "accumulate", "outer", "inner"], + *inputs: Any, + **kwargs: Any, + ) -> Any: ... + + # NOTE: In reality `extobj` should be a length of list 3 containing an # int, an int, and a callable, but there's no way to properly express # non-homogenous lists. @@ -100,10 +112,24 @@ class _UFunc_Nin1_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i signature: str | _2Tuple[None | str] = ..., extobj: list[Any] = ..., ) -> NDArray[Any]: ... + @overload + def __call__( + self, + __x1: _SupportsArrayUFunc, + out: None | NDArray[Any] | tuple[NDArray[Any]] = ..., + *, + where: None | _ArrayLikeBool_co = ..., + casting: _CastingKind = ..., + order: _OrderKACF = ..., + dtype: DTypeLike = ..., + subok: bool = ..., + signature: str | _2Tuple[None | str] = ..., + extobj: list[Any] = ..., + ) -> Any: ... def at( self, - a: NDArray[Any], + a: _SupportsArrayUFunc, indices: _ArrayLikeInt_co, /, ) -> None: ... @@ -280,6 +306,22 @@ class _UFunc_Nin1_Nout2(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i signature: str | _3Tuple[None | str] = ..., extobj: list[Any] = ..., ) -> _2Tuple[NDArray[Any]]: ... + @overload + def __call__( + self, + __x1: _SupportsArrayUFunc, + __out1: None | NDArray[Any] = ..., + __out2: None | NDArray[Any] = ..., + *, + out: _2Tuple[NDArray[Any]] = ..., + where: None | _ArrayLikeBool_co = ..., + casting: _CastingKind = ..., + order: _OrderKACF = ..., + dtype: DTypeLike = ..., + subok: bool = ..., + signature: str | _3Tuple[None | str] = ..., + extobj: list[Any] = ..., + ) -> _2Tuple[Any]: ... class _UFunc_Nin2_Nout2(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: ignore[misc] @property @@ -355,7 +397,7 @@ class _GUFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: @property def nargs(self) -> Literal[3]: ... - # NOTE: In practice the only gufunc in the main name is `matmul`, + # NOTE: In practice the only gufunc in the main namespace is `matmul`, # so we can use its signature here @property def signature(self) -> Literal["(n?,k),(k,m?)->(n?,m?)"]: ... |
