From f375d71ca101db9541b1e70476999b574634556d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Mar 2022 18:12:42 -0600 Subject: Properly restrict the input dtypes for the array_api trace, svdvals, and vecdot --- numpy/array_api/linalg.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'numpy/array_api') diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index cb04113b6..8398147ba 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -344,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -364,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trace') # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') return tensordot(x1, x2, axes=((axis,), (axis,))) -- cgit v1.2.1