summaryrefslogtreecommitdiff
path: root/numpy/array_api/_linear_algebra_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-11-14 14:35:06 -0700
committerGitHub <noreply@github.com>2021-11-14 22:35:06 +0100
commita1813504ad44b70fb139181a9df8465bcb22e24d (patch)
tree1c638079fe4e675976d73a0fb120f2f05d2ec522 /numpy/array_api/_linear_algebra_functions.py
parentb8a0f339dcd90c134e1cc3e19d06348069af685b (diff)
downloadnumpy-a1813504ad44b70fb139181a9df8465bcb22e24d.tar.gz
ENH: Add the linalg extension to the array_api submodule (#19980)
Diffstat (limited to 'numpy/array_api/_linear_algebra_functions.py')
-rw-r--r--numpy/array_api/_linear_algebra_functions.py67
1 files changed, 0 insertions, 67 deletions
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py
deleted file mode 100644
index 7a6c9846c..000000000
--- a/numpy/array_api/_linear_algebra_functions.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import annotations
-
-from ._array_object import Array
-from ._dtypes import _numeric_dtypes, _result_type
-
-from typing import Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-# einsum is not yet implemented in the array API spec.
-
-# def einsum():
-# """
-# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`.
-#
-# See its docstring for more information.
-# """
-# return np.einsum()
-
-
-def matmul(x1: Array, x2: Array, /) -> Array:
- """
- Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
-
- See its docstring for more information.
- """
- # Note: the restriction to numeric dtypes only is different from
- # np.matmul.
- if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError("Only numeric dtypes are allowed in matmul")
- # Call result type here just to raise on disallowed type combinations
- _result_type(x1.dtype, x2.dtype)
-
- return Array._new(np.matmul(x1._array, x2._array))
-
-
-# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
-def tensordot(
- x1: Array,
- x2: Array,
- /,
- *,
- axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
-) -> Array:
- # Note: the restriction to numeric dtypes only is different from
- # np.tensordot.
- if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError("Only numeric dtypes are allowed in tensordot")
- # Call result type here just to raise on disallowed type combinations
- _result_type(x1.dtype, x2.dtype)
-
- return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
-
-
-# Note: this function is new in the array API spec. Unlike transpose, it only
-# transposes the last two axes.
-def matrix_transpose(x: Array, /) -> Array:
- if x.ndim < 2:
- raise ValueError("x must be at least 2-dimensional for matrix_transpose")
- return Array._new(np.swapaxes(x._array, -1, -2))
-
-
-# Note: vecdot is not in NumPy
-def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array:
- if axis is None:
- axis = -1
- return tensordot(x1, x2, axes=((axis,), (axis,)))