summaryrefslogtreecommitdiff
path: root/numpy/_array_api/linear_algebra_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-01-11 16:03:43 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-01-11 16:03:43 -0700
commit012343dec5599418b77512733fc5b8db6bc14c4c (patch)
tree13ef3a6bfdcd603036baaf8d65d1647eea749afc /numpy/_array_api/linear_algebra_functions.py
parent33dc7bea24f1ab6c47047b49521e732caeb485d5 (diff)
downloadnumpy-012343dec5599418b77512733fc5b8db6bc14c4c.tar.gz
Add initial array_api sub-namespace
This is based on the function stubs from the array API test suite, and is currently based on the assumption that NumPy already follows the array API standard. Now it needs to be modified to fix it in the places where NumPy deviates (for example, different function names for inverse trigonometric functions).
Diffstat (limited to 'numpy/_array_api/linear_algebra_functions.py')
-rw-r--r--numpy/_array_api/linear_algebra_functions.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/numpy/_array_api/linear_algebra_functions.py b/numpy/_array_api/linear_algebra_functions.py
new file mode 100644
index 000000000..5da7ac17b
--- /dev/null
+++ b/numpy/_array_api/linear_algebra_functions.py
@@ -0,0 +1,91 @@
+# def cholesky():
+# from .. import cholesky
+# return cholesky()
+
+def cross(x1, x2, /, *, axis=-1):
+ from .. import cross
+ return cross(x1, x2, axis=axis)
+
+def det(x, /):
+ from .. import det
+ return det(x)
+
+def diagonal(x, /, *, axis1=0, axis2=1, offset=0):
+ from .. import diagonal
+ return diagonal(x, axis1=axis1, axis2=axis2, offset=offset)
+
+# def dot():
+# from .. import dot
+# return dot()
+#
+# def eig():
+# from .. import eig
+# return eig()
+#
+# def eigvalsh():
+# from .. import eigvalsh
+# return eigvalsh()
+#
+# def einsum():
+# from .. import einsum
+# return einsum()
+
+def inv(x):
+ from .. import inv
+ return inv(x)
+
+# def lstsq():
+# from .. import lstsq
+# return lstsq()
+#
+# def matmul():
+# from .. import matmul
+# return matmul()
+#
+# def matrix_power():
+# from .. import matrix_power
+# return matrix_power()
+#
+# def matrix_rank():
+# from .. import matrix_rank
+# return matrix_rank()
+
+def norm(x, /, *, axis=None, keepdims=False, ord=None):
+ from .. import norm
+ return norm(x, axis=axis, keepdims=keepdims, ord=ord)
+
+def outer(x1, x2, /):
+ from .. import outer
+ return outer(x1, x2)
+
+# def pinv():
+# from .. import pinv
+# return pinv()
+#
+# def qr():
+# from .. import qr
+# return qr()
+#
+# def slogdet():
+# from .. import slogdet
+# return slogdet()
+#
+# def solve():
+# from .. import solve
+# return solve()
+#
+# def svd():
+# from .. import svd
+# return svd()
+
+def trace(x, /, *, axis1=0, axis2=1, offset=0):
+ from .. import trace
+ return trace(x, axis1=axis1, axis2=axis2, offset=offset)
+
+def transpose(x, /, *, axes=None):
+ from .. import transpose
+ return transpose(x, axes=axes)
+
+# __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
+
+__all__ = ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose']