diff options
Diffstat (limited to 'numpy/array_api/linalg.py')
-rw-r--r-- | numpy/array_api/linalg.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index 8d7ba659e..f422e1c27 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -89,7 +89,6 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) -# Note: the keyword argument name upper is different from np.linalg.eigh def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`. @@ -106,7 +105,6 @@ def eigh(x: Array, /) -> EighResult: return EighResult(*map(Array._new, np.linalg.eigh(x._array))) -# Note: the keyword argument name upper is different from np.linalg.eigvalsh def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`. @@ -346,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -366,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trace') # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') return tensordot(x1, x2, axes=((axis,), (axis,))) @@ -380,7 +384,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. -def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: +def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`. |