summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/__init__.py45
-rw-r--r--numpy/_array_api/constants.py3
-rw-r--r--numpy/_array_api/creation_functions.py45
-rw-r--r--numpy/_array_api/elementwise_functions.py221
-rw-r--r--numpy/_array_api/linear_algebra_functions.py91
-rw-r--r--numpy/_array_api/manipulation_functions.py29
-rw-r--r--numpy/_array_api/searching_functions.py17
-rw-r--r--numpy/_array_api/set_functions.py5
-rw-r--r--numpy/_array_api/sorting_functions.py9
-rw-r--r--numpy/_array_api/statistical_functions.py29
-rw-r--r--numpy/_array_api/utility_functions.py9
11 files changed, 503 insertions, 0 deletions
diff --git a/numpy/_array_api/__init__.py b/numpy/_array_api/__init__.py
new file mode 100644
index 000000000..878251e7c
--- /dev/null
+++ b/numpy/_array_api/__init__.py
@@ -0,0 +1,45 @@
+__all__ = []
+
+from .constants import e, inf, nan, pi
+
+__all__ += ['e', 'inf', 'nan', 'pi']
+
+from .creation_functions import arange, empty, empty_like, eye, full, full_like, linspace, ones, ones_like, zeros, zeros_like
+
+__all__ += ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like']
+
+from .elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc
+
+__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
+
+from .linear_algebra_functions import cross, det, diagonal, inv, norm, outer, trace, transpose
+
+__all__ += ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose']
+
+# from .linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose
+#
+# __all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
+
+from .manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
+
+__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack']
+
+from .searching_functions import argmax, argmin, nonzero, where
+
+__all__ += ['argmax', 'argmin', 'nonzero', 'where']
+
+from .set_functions import unique
+
+__all__ += ['unique']
+
+from .sorting_functions import argsort, sort
+
+__all__ += ['argsort', 'sort']
+
+from .statistical_functions import max, mean, min, prod, std, sum, var
+
+__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var']
+
+from .utility_functions import all, any
+
+__all__ += ['all', 'any']
diff --git a/numpy/_array_api/constants.py b/numpy/_array_api/constants.py
new file mode 100644
index 000000000..000777029
--- /dev/null
+++ b/numpy/_array_api/constants.py
@@ -0,0 +1,3 @@
+from .. import e, inf, nan, pi
+
+__all__ = ['e', 'inf', 'nan', 'pi']
diff --git a/numpy/_array_api/creation_functions.py b/numpy/_array_api/creation_functions.py
new file mode 100644
index 000000000..50b0bd252
--- /dev/null
+++ b/numpy/_array_api/creation_functions.py
@@ -0,0 +1,45 @@
+def arange(start, /, *, stop=None, step=1, dtype=None):
+ from .. import arange
+ return arange(start, stop=stop, step=step, dtype=dtype)
+
+def empty(shape, /, *, dtype=None):
+ from .. import empty
+ return empty(shape, dtype=dtype)
+
+def empty_like(x, /, *, dtype=None):
+ from .. import empty_like
+ return empty_like(x, dtype=dtype)
+
+def eye(N, /, *, M=None, k=0, dtype=None):
+ from .. import eye
+ return eye(N, M=M, k=k, dtype=dtype)
+
+def full(shape, fill_value, /, *, dtype=None):
+ from .. import full
+ return full(shape, fill_value, dtype=dtype)
+
+def full_like(x, fill_value, /, *, dtype=None):
+ from .. import full_like
+ return full_like(x, fill_value, dtype=dtype)
+
+def linspace(start, stop, num, /, *, dtype=None, endpoint=True):
+ from .. import linspace
+ return linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
+
+def ones(shape, /, *, dtype=None):
+ from .. import ones
+ return ones(shape, dtype=dtype)
+
+def ones_like(x, /, *, dtype=None):
+ from .. import ones_like
+ return ones_like(x, dtype=dtype)
+
+def zeros(shape, /, *, dtype=None):
+ from .. import zeros
+ return zeros(shape, dtype=dtype)
+
+def zeros_like(x, /, *, dtype=None):
+ from .. import zeros_like
+ return zeros_like(x, dtype=dtype)
+
+__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like']
diff --git a/numpy/_array_api/elementwise_functions.py b/numpy/_array_api/elementwise_functions.py
new file mode 100644
index 000000000..3e8349d29
--- /dev/null
+++ b/numpy/_array_api/elementwise_functions.py
@@ -0,0 +1,221 @@
+def abs(x, /):
+ from .. import abs
+ return abs(x)
+
+def acos(x, /):
+ from .. import acos
+ return acos(x)
+
+def acosh(x, /):
+ from .. import acosh
+ return acosh(x)
+
+def add(x1, x2, /):
+ from .. import add
+ return add(x1, x2)
+
+def asin(x, /):
+ from .. import asin
+ return asin(x)
+
+def asinh(x, /):
+ from .. import asinh
+ return asinh(x)
+
+def atan(x, /):
+ from .. import atan
+ return atan(x)
+
+def atan2(x1, x2, /):
+ from .. import atan2
+ return atan2(x1, x2)
+
+def atanh(x, /):
+ from .. import atanh
+ return atanh(x)
+
+def bitwise_and(x1, x2, /):
+ from .. import bitwise_and
+ return bitwise_and(x1, x2)
+
+def bitwise_left_shift(x1, x2, /):
+ from .. import bitwise_left_shift
+ return bitwise_left_shift(x1, x2)
+
+def bitwise_invert(x, /):
+ from .. import bitwise_invert
+ return bitwise_invert(x)
+
+def bitwise_or(x1, x2, /):
+ from .. import bitwise_or
+ return bitwise_or(x1, x2)
+
+def bitwise_right_shift(x1, x2, /):
+ from .. import bitwise_right_shift
+ return bitwise_right_shift(x1, x2)
+
+def bitwise_xor(x1, x2, /):
+ from .. import bitwise_xor
+ return bitwise_xor(x1, x2)
+
+def ceil(x, /):
+ from .. import ceil
+ return ceil(x)
+
+def cos(x, /):
+ from .. import cos
+ return cos(x)
+
+def cosh(x, /):
+ from .. import cosh
+ return cosh(x)
+
+def divide(x1, x2, /):
+ from .. import divide
+ return divide(x1, x2)
+
+def equal(x1, x2, /):
+ from .. import equal
+ return equal(x1, x2)
+
+def exp(x, /):
+ from .. import exp
+ return exp(x)
+
+def expm1(x, /):
+ from .. import expm1
+ return expm1(x)
+
+def floor(x, /):
+ from .. import floor
+ return floor(x)
+
+def floor_divide(x1, x2, /):
+ from .. import floor_divide
+ return floor_divide(x1, x2)
+
+def greater(x1, x2, /):
+ from .. import greater
+ return greater(x1, x2)
+
+def greater_equal(x1, x2, /):
+ from .. import greater_equal
+ return greater_equal(x1, x2)
+
+def isfinite(x, /):
+ from .. import isfinite
+ return isfinite(x)
+
+def isinf(x, /):
+ from .. import isinf
+ return isinf(x)
+
+def isnan(x, /):
+ from .. import isnan
+ return isnan(x)
+
+def less(x1, x2, /):
+ from .. import less
+ return less(x1, x2)
+
+def less_equal(x1, x2, /):
+ from .. import less_equal
+ return less_equal(x1, x2)
+
+def log(x, /):
+ from .. import log
+ return log(x)
+
+def log1p(x, /):
+ from .. import log1p
+ return log1p(x)
+
+def log2(x, /):
+ from .. import log2
+ return log2(x)
+
+def log10(x, /):
+ from .. import log10
+ return log10(x)
+
+def logical_and(x1, x2, /):
+ from .. import logical_and
+ return logical_and(x1, x2)
+
+def logical_not(x, /):
+ from .. import logical_not
+ return logical_not(x)
+
+def logical_or(x1, x2, /):
+ from .. import logical_or
+ return logical_or(x1, x2)
+
+def logical_xor(x1, x2, /):
+ from .. import logical_xor
+ return logical_xor(x1, x2)
+
+def multiply(x1, x2, /):
+ from .. import multiply
+ return multiply(x1, x2)
+
+def negative(x, /):
+ from .. import negative
+ return negative(x)
+
+def not_equal(x1, x2, /):
+ from .. import not_equal
+ return not_equal(x1, x2)
+
+def positive(x, /):
+ from .. import positive
+ return positive(x)
+
+def pow(x1, x2, /):
+ from .. import pow
+ return pow(x1, x2)
+
+def remainder(x1, x2, /):
+ from .. import remainder
+ return remainder(x1, x2)
+
+def round(x, /):
+ from .. import round
+ return round(x)
+
+def sign(x, /):
+ from .. import sign
+ return sign(x)
+
+def sin(x, /):
+ from .. import sin
+ return sin(x)
+
+def sinh(x, /):
+ from .. import sinh
+ return sinh(x)
+
+def square(x, /):
+ from .. import square
+ return square(x)
+
+def sqrt(x, /):
+ from .. import sqrt
+ return sqrt(x)
+
+def subtract(x1, x2, /):
+ from .. import subtract
+ return subtract(x1, x2)
+
+def tan(x, /):
+ from .. import tan
+ return tan(x)
+
+def tanh(x, /):
+ from .. import tanh
+ return tanh(x)
+
+def trunc(x, /):
+ from .. import trunc
+ return trunc(x)
+
+__all__ = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
diff --git a/numpy/_array_api/linear_algebra_functions.py b/numpy/_array_api/linear_algebra_functions.py
new file mode 100644
index 000000000..5da7ac17b
--- /dev/null
+++ b/numpy/_array_api/linear_algebra_functions.py
@@ -0,0 +1,91 @@
+# def cholesky():
+# from .. import cholesky
+# return cholesky()
+
+def cross(x1, x2, /, *, axis=-1):
+ from .. import cross
+ return cross(x1, x2, axis=axis)
+
+def det(x, /):
+ from .. import det
+ return det(x)
+
+def diagonal(x, /, *, axis1=0, axis2=1, offset=0):
+ from .. import diagonal
+ return diagonal(x, axis1=axis1, axis2=axis2, offset=offset)
+
+# def dot():
+# from .. import dot
+# return dot()
+#
+# def eig():
+# from .. import eig
+# return eig()
+#
+# def eigvalsh():
+# from .. import eigvalsh
+# return eigvalsh()
+#
+# def einsum():
+# from .. import einsum
+# return einsum()
+
+def inv(x):
+ from .. import inv
+ return inv(x)
+
+# def lstsq():
+# from .. import lstsq
+# return lstsq()
+#
+# def matmul():
+# from .. import matmul
+# return matmul()
+#
+# def matrix_power():
+# from .. import matrix_power
+# return matrix_power()
+#
+# def matrix_rank():
+# from .. import matrix_rank
+# return matrix_rank()
+
+def norm(x, /, *, axis=None, keepdims=False, ord=None):
+ from .. import norm
+ return norm(x, axis=axis, keepdims=keepdims, ord=ord)
+
+def outer(x1, x2, /):
+ from .. import outer
+ return outer(x1, x2)
+
+# def pinv():
+# from .. import pinv
+# return pinv()
+#
+# def qr():
+# from .. import qr
+# return qr()
+#
+# def slogdet():
+# from .. import slogdet
+# return slogdet()
+#
+# def solve():
+# from .. import solve
+# return solve()
+#
+# def svd():
+# from .. import svd
+# return svd()
+
+def trace(x, /, *, axis1=0, axis2=1, offset=0):
+ from .. import trace
+ return trace(x, axis1=axis1, axis2=axis2, offset=offset)
+
+def transpose(x, /, *, axes=None):
+ from .. import transpose
+ return transpose(x, axes=axes)
+
+# __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
+
+__all__ = ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose']
diff --git a/numpy/_array_api/manipulation_functions.py b/numpy/_array_api/manipulation_functions.py
new file mode 100644
index 000000000..1934e8e4e
--- /dev/null
+++ b/numpy/_array_api/manipulation_functions.py
@@ -0,0 +1,29 @@
+def concat(arrays, /, *, axis=0):
+ from .. import concat
+ return concat(arrays, axis=axis)
+
+def expand_dims(x, axis, /):
+ from .. import expand_dims
+ return expand_dims(x, axis)
+
+def flip(x, /, *, axis=None):
+ from .. import flip
+ return flip(x, axis=axis)
+
+def reshape(x, shape, /):
+ from .. import reshape
+ return reshape(x, shape)
+
+def roll(x, shift, /, *, axis=None):
+ from .. import roll
+ return roll(x, shift, axis=axis)
+
+def squeeze(x, /, *, axis=None):
+ from .. import squeeze
+ return squeeze(x, axis=axis)
+
+def stack(arrays, /, *, axis=0):
+ from .. import stack
+ return stack(arrays, axis=axis)
+
+__all__ = ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack']
diff --git a/numpy/_array_api/searching_functions.py b/numpy/_array_api/searching_functions.py
new file mode 100644
index 000000000..c4b6c58b5
--- /dev/null
+++ b/numpy/_array_api/searching_functions.py
@@ -0,0 +1,17 @@
+def argmax(x, /, *, axis=None, keepdims=False):
+ from .. import argmax
+ return argmax(x, axis=axis, keepdims=keepdims)
+
+def argmin(x, /, *, axis=None, keepdims=False):
+ from .. import argmin
+ return argmin(x, axis=axis, keepdims=keepdims)
+
+def nonzero(x, /):
+ from .. import nonzero
+ return nonzero(x)
+
+def where(condition, x1, x2, /):
+ from .. import where
+ return where(condition, x1, x2)
+
+__all__ = ['argmax', 'argmin', 'nonzero', 'where']
diff --git a/numpy/_array_api/set_functions.py b/numpy/_array_api/set_functions.py
new file mode 100644
index 000000000..f218f1187
--- /dev/null
+++ b/numpy/_array_api/set_functions.py
@@ -0,0 +1,5 @@
+def unique(x, /, *, return_counts=False, return_index=False, return_inverse=False, sorted=True):
+ from .. import unique
+ return unique(x, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse, sorted=sorted)
+
+__all__ = ['unique']
diff --git a/numpy/_array_api/sorting_functions.py b/numpy/_array_api/sorting_functions.py
new file mode 100644
index 000000000..384ec08f9
--- /dev/null
+++ b/numpy/_array_api/sorting_functions.py
@@ -0,0 +1,9 @@
+def argsort(x, /, *, axis=-1, descending=False, stable=True):
+ from .. import argsort
+ return argsort(x, axis=axis, descending=descending, stable=stable)
+
+def sort(x, /, *, axis=-1, descending=False, stable=True):
+ from .. import sort
+ return sort(x, axis=axis, descending=descending, stable=stable)
+
+__all__ = ['argsort', 'sort']
diff --git a/numpy/_array_api/statistical_functions.py b/numpy/_array_api/statistical_functions.py
new file mode 100644
index 000000000..2cc712aea
--- /dev/null
+++ b/numpy/_array_api/statistical_functions.py
@@ -0,0 +1,29 @@
+def max(x, /, *, axis=None, keepdims=False):
+ from .. import max
+ return max(x, axis=axis, keepdims=keepdims)
+
+def mean(x, /, *, axis=None, keepdims=False):
+ from .. import mean
+ return mean(x, axis=axis, keepdims=keepdims)
+
+def min(x, /, *, axis=None, keepdims=False):
+ from .. import min
+ return min(x, axis=axis, keepdims=keepdims)
+
+def prod(x, /, *, axis=None, keepdims=False):
+ from .. import prod
+ return prod(x, axis=axis, keepdims=keepdims)
+
+def std(x, /, *, axis=None, correction=0.0, keepdims=False):
+ from .. import std
+ return std(x, axis=axis, correction=correction, keepdims=keepdims)
+
+def sum(x, /, *, axis=None, keepdims=False):
+ from .. import sum
+ return sum(x, axis=axis, keepdims=keepdims)
+
+def var(x, /, *, axis=None, correction=0.0, keepdims=False):
+ from .. import var
+ return var(x, axis=axis, correction=correction, keepdims=keepdims)
+
+__all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var']
diff --git a/numpy/_array_api/utility_functions.py b/numpy/_array_api/utility_functions.py
new file mode 100644
index 000000000..eac0d4eaa
--- /dev/null
+++ b/numpy/_array_api/utility_functions.py
@@ -0,0 +1,9 @@
+def all(x, /, *, axis=None, keepdims=False):
+ from .. import all
+ return all(x, axis=axis, keepdims=keepdims)
+
+def any(x, /, *, axis=None, keepdims=False):
+ from .. import any
+ return any(x, axis=axis, keepdims=keepdims)
+
+__all__ = ['all', 'any']