diff options
Diffstat (limited to 'numpy/array_api/_linear_algebra_functions.py')
-rw-r--r-- | numpy/array_api/_linear_algebra_functions.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py index 089081725..7a6c9846c 100644 --- a/numpy/array_api/_linear_algebra_functions.py +++ b/numpy/array_api/_linear_algebra_functions.py @@ -52,13 +52,12 @@ def tensordot( return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) -def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: - """ - Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`. - - See its docstring for more information. - """ - return Array._new(np.transpose(x._array, axes=axes)) +# Note: this function is new in the array API spec. Unlike transpose, it only +# transposes the last two axes. +def matrix_transpose(x: Array, /) -> Array: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return Array._new(np.swapaxes(x._array, -1, -2)) # Note: vecdot is not in NumPy |