summaryrefslogtreecommitdiff
path: root/numpy/array_api/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/linalg.py')
-rw-r--r--numpy/array_api/linalg.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py
index 8d7ba659e..f422e1c27 100644
--- a/numpy/array_api/linalg.py
+++ b/numpy/array_api/linalg.py
@@ -89,7 +89,6 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array:
return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
-# Note: the keyword argument name upper is different from np.linalg.eigh
def eigh(x: Array, /) -> EighResult:
"""
Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.
@@ -106,7 +105,6 @@ def eigh(x: Array, /) -> EighResult:
return EighResult(*map(Array._new, np.linalg.eigh(x._array)))
-# Note: the keyword argument name upper is different from np.linalg.eigvalsh
def eigvalsh(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.
@@ -346,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# np.linalg.svd(compute_uv=False).
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svdvals')
return Array._new(np.linalg.svd(x._array, compute_uv=False))
# Note: tensordot is the numpy top-level namespace but not in np.linalg
@@ -366,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in trace')
# Note: trace always operates on the last two axes, whereas np.trace
# operates on the first two axes by default
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
# Note: vecdot is not in NumPy
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in vecdot')
return tensordot(x1, x2, axes=((axis,), (axis,)))
@@ -380,7 +384,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
# The type for ord should be Optional[Union[int, float, Literal[np.inf,
# -np.inf]]] but Literal does not support floating-point literals.
-def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
+def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.