diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2023-05-17 15:02:06 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-17 15:02:06 -0600 |
commit | 126b46c7abdd7970d6deb78349de4b6bf6525e44 (patch) | |
tree | 363273c29d27efc89984983cf5b7a3b152aba0b3 /numpy/linalg/linalg.pyi | |
parent | d9b38d687cd513aa6688f7fba805a908c0ac3979 (diff) | |
parent | a4c249653ec9d063a67a6cde8123dca2defb8f8b (diff) | |
download | numpy-main.tar.gz |
ENH: Add namedtuple return types to linalg functions that return tuples
Diffstat (limited to 'numpy/linalg/linalg.pyi')
-rw-r--r-- | numpy/linalg/linalg.pyi | 65 |
1 files changed, 40 insertions, 25 deletions
diff --git a/numpy/linalg/linalg.pyi b/numpy/linalg/linalg.pyi index 20cdb708b..c0b2f29b2 100644 --- a/numpy/linalg/linalg.pyi +++ b/numpy/linalg/linalg.pyi @@ -6,6 +6,8 @@ from typing import ( Any, SupportsIndex, SupportsInt, + NamedTuple, + Generic, ) from numpy import ( @@ -31,12 +33,37 @@ from numpy._typing import ( _T = TypeVar("_T") _ArrayType = TypeVar("_ArrayType", bound=NDArray[Any]) +_SCT = TypeVar("_SCT", bound=generic, covariant=True) +_SCT2 = TypeVar("_SCT2", bound=generic, covariant=True) _2Tuple = tuple[_T, _T] _ModeKind = L["reduced", "complete", "r", "raw"] __all__: list[str] +class EigResult(NamedTuple): + eigenvalues: NDArray[Any] + eigenvectors: NDArray[Any] + +class EighResult(NamedTuple): + eigenvalues: NDArray[Any] + eigenvectors: NDArray[Any] + +class QRResult(NamedTuple): + Q: NDArray[Any] + R: NDArray[Any] + +class SlogdetResult(NamedTuple): + # TODO: `sign` and `logabsdet` are scalars for input 2D arrays and + # a `(x.ndim - 2)`` dimensionl arrays otherwise + sign: Any + logabsdet: Any + +class SVDResult(NamedTuple): + U: NDArray[Any] + S: NDArray[Any] + Vh: NDArray[Any] + @overload def tensorsolve( a: _ArrayLikeInt_co, @@ -110,11 +137,11 @@ def cholesky(a: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... def cholesky(a: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ... @overload -def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[float64]]: ... +def qr(a: _ArrayLikeInt_co, mode: _ModeKind = ...) -> QRResult: ... @overload -def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[floating[Any]]]: ... +def qr(a: _ArrayLikeFloat_co, mode: _ModeKind = ...) -> QRResult: ... @overload -def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ... +def qr(a: _ArrayLikeComplex_co, mode: _ModeKind = ...) -> QRResult: ... @overload def eigvals(a: _ArrayLikeInt_co) -> NDArray[float64] | NDArray[complex128]: ... @@ -129,27 +156,27 @@ def eigvalsh(a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[ def eigvalsh(a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ...) -> NDArray[floating[Any]]: ... @overload -def eig(a: _ArrayLikeInt_co) -> _2Tuple[NDArray[float64]] | _2Tuple[NDArray[complex128]]: ... +def eig(a: _ArrayLikeInt_co) -> EigResult: ... @overload -def eig(a: _ArrayLikeFloat_co) -> _2Tuple[NDArray[floating[Any]]] | _2Tuple[NDArray[complexfloating[Any, Any]]]: ... +def eig(a: _ArrayLikeFloat_co) -> EigResult: ... @overload -def eig(a: _ArrayLikeComplex_co) -> _2Tuple[NDArray[complexfloating[Any, Any]]]: ... +def eig(a: _ArrayLikeComplex_co) -> EigResult: ... @overload def eigh( a: _ArrayLikeInt_co, UPLO: L["L", "U", "l", "u"] = ..., -) -> tuple[NDArray[float64], NDArray[float64]]: ... +) -> EighResult: ... @overload def eigh( a: _ArrayLikeFloat_co, UPLO: L["L", "U", "l", "u"] = ..., -) -> tuple[NDArray[floating[Any]], NDArray[floating[Any]]]: ... +) -> EighResult: ... @overload def eigh( a: _ArrayLikeComplex_co, UPLO: L["L", "U", "l", "u"] = ..., -) -> tuple[NDArray[floating[Any]], NDArray[complexfloating[Any, Any]]]: ... +) -> EighResult: ... @overload def svd( @@ -157,33 +184,21 @@ def svd( full_matrices: bool = ..., compute_uv: L[True] = ..., hermitian: bool = ..., -) -> tuple[ - NDArray[float64], - NDArray[float64], - NDArray[float64], -]: ... +) -> SVDResult: ... @overload def svd( a: _ArrayLikeFloat_co, full_matrices: bool = ..., compute_uv: L[True] = ..., hermitian: bool = ..., -) -> tuple[ - NDArray[floating[Any]], - NDArray[floating[Any]], - NDArray[floating[Any]], -]: ... +) -> SVDResult: ... @overload def svd( a: _ArrayLikeComplex_co, full_matrices: bool = ..., compute_uv: L[True] = ..., hermitian: bool = ..., -) -> tuple[ - NDArray[complexfloating[Any, Any]], - NDArray[floating[Any]], - NDArray[complexfloating[Any, Any]], -]: ... +) -> SVDResult: ... @overload def svd( a: _ArrayLikeInt_co, @@ -231,7 +246,7 @@ def pinv( # TODO: Returns a 2-tuple of scalars for 2D arrays and # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise -def slogdet(a: _ArrayLikeComplex_co) -> _2Tuple[Any]: ... +def slogdet(a: _ArrayLikeComplex_co) -> SlogdetResult: ... # TODO: Returns a 2-tuple of scalars for 2D arrays and # a 2-tuple of `(a.ndim - 2)`` dimensionl arrays otherwise |