diff options
author | Aaron Meurer <asmeurer@gmail.com> | 2021-11-14 14:35:06 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-14 22:35:06 +0100 |
commit | a1813504ad44b70fb139181a9df8465bcb22e24d (patch) | |
tree | 1c638079fe4e675976d73a0fb120f2f05d2ec522 /numpy/array_api/_linear_algebra_functions.py | |
parent | b8a0f339dcd90c134e1cc3e19d06348069af685b (diff) | |
download | numpy-a1813504ad44b70fb139181a9df8465bcb22e24d.tar.gz |
ENH: Add the linalg extension to the array_api submodule (#19980)
Diffstat (limited to 'numpy/array_api/_linear_algebra_functions.py')
-rw-r--r-- | numpy/array_api/_linear_algebra_functions.py | 67 |
1 files changed, 0 insertions, 67 deletions
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py deleted file mode 100644 index 7a6c9846c..000000000 --- a/numpy/array_api/_linear_algebra_functions.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from ._array_object import Array -from ._dtypes import _numeric_dtypes, _result_type - -from typing import Optional, Sequence, Tuple, Union - -import numpy as np - -# einsum is not yet implemented in the array API spec. - -# def einsum(): -# """ -# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`. -# -# See its docstring for more information. -# """ -# return np.einsum() - - -def matmul(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`. - - See its docstring for more information. - """ - # Note: the restriction to numeric dtypes only is different from - # np.matmul. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in matmul") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - - return Array._new(np.matmul(x1._array, x2._array)) - - -# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot( - x1: Array, - x2: Array, - /, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, -) -> Array: - # Note: the restriction to numeric dtypes only is different from - # np.tensordot. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in tensordot") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - - return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) - - -# Note: this function is new in the array API spec. Unlike transpose, it only -# transposes the last two axes. -def matrix_transpose(x: Array, /) -> Array: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return Array._new(np.swapaxes(x._array, -1, -2)) - - -# Note: vecdot is not in NumPy -def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array: - if axis is None: - axis = -1 - return tensordot(x1, x2, axes=((axis,), (axis,))) |