summaryrefslogtreecommitdiff
path: root/numpy/array_api/_linear_algebra_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_linear_algebra_functions.py')
-rw-r--r--numpy/array_api/_linear_algebra_functions.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py
index 089081725..7a6c9846c 100644
--- a/numpy/array_api/_linear_algebra_functions.py
+++ b/numpy/array_api/_linear_algebra_functions.py
@@ -52,13 +52,12 @@ def tensordot(
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
-def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array:
- """
- Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
-
- See its docstring for more information.
- """
- return Array._new(np.transpose(x._array, axes=axes))
+# Note: this function is new in the array API spec. Unlike transpose, it only
+# transposes the last two axes.
+def matrix_transpose(x: Array, /) -> Array:
+ if x.ndim < 2:
+ raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+ return Array._new(np.swapaxes(x._array, -1, -2))
# Note: vecdot is not in NumPy