diff options
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/__init__.py | 45 | ||||
| -rw-r--r-- | numpy/_array_api/constants.py | 3 | ||||
| -rw-r--r-- | numpy/_array_api/creation_functions.py | 45 | ||||
| -rw-r--r-- | numpy/_array_api/elementwise_functions.py | 221 | ||||
| -rw-r--r-- | numpy/_array_api/linear_algebra_functions.py | 91 | ||||
| -rw-r--r-- | numpy/_array_api/manipulation_functions.py | 29 | ||||
| -rw-r--r-- | numpy/_array_api/searching_functions.py | 17 | ||||
| -rw-r--r-- | numpy/_array_api/set_functions.py | 5 | ||||
| -rw-r--r-- | numpy/_array_api/sorting_functions.py | 9 | ||||
| -rw-r--r-- | numpy/_array_api/statistical_functions.py | 29 | ||||
| -rw-r--r-- | numpy/_array_api/utility_functions.py | 9 |
11 files changed, 503 insertions, 0 deletions
diff --git a/numpy/_array_api/__init__.py b/numpy/_array_api/__init__.py new file mode 100644 index 000000000..878251e7c --- /dev/null +++ b/numpy/_array_api/__init__.py @@ -0,0 +1,45 @@ +__all__ = [] + +from .constants import e, inf, nan, pi + +__all__ += ['e', 'inf', 'nan', 'pi'] + +from .creation_functions import arange, empty, empty_like, eye, full, full_like, linspace, ones, ones_like, zeros, zeros_like + +__all__ += ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like'] + +from .elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc + +__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] + +from .linear_algebra_functions import cross, det, diagonal, inv, norm, outer, trace, transpose + +__all__ += ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose'] + +# from .linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose +# +# __all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose'] + +from .manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack + +__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack'] + +from .searching_functions import argmax, argmin, nonzero, where + +__all__ += ['argmax', 'argmin', 'nonzero', 'where'] + +from .set_functions import unique + +__all__ += ['unique'] + +from .sorting_functions import argsort, sort + +__all__ += ['argsort', 'sort'] + +from .statistical_functions import max, mean, min, prod, std, sum, var + +__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] + +from .utility_functions import all, any + +__all__ += ['all', 'any'] diff --git a/numpy/_array_api/constants.py b/numpy/_array_api/constants.py new file mode 100644 index 000000000..000777029 --- /dev/null +++ b/numpy/_array_api/constants.py @@ -0,0 +1,3 @@ +from .. import e, inf, nan, pi + +__all__ = ['e', 'inf', 'nan', 'pi'] diff --git a/numpy/_array_api/creation_functions.py b/numpy/_array_api/creation_functions.py new file mode 100644 index 000000000..50b0bd252 --- /dev/null +++ b/numpy/_array_api/creation_functions.py @@ -0,0 +1,45 @@ +def arange(start, /, *, stop=None, step=1, dtype=None): + from .. import arange + return arange(start, stop=stop, step=step, dtype=dtype) + +def empty(shape, /, *, dtype=None): + from .. import empty + return empty(shape, dtype=dtype) + +def empty_like(x, /, *, dtype=None): + from .. import empty_like + return empty_like(x, dtype=dtype) + +def eye(N, /, *, M=None, k=0, dtype=None): + from .. import eye + return eye(N, M=M, k=k, dtype=dtype) + +def full(shape, fill_value, /, *, dtype=None): + from .. import full + return full(shape, fill_value, dtype=dtype) + +def full_like(x, fill_value, /, *, dtype=None): + from .. import full_like + return full_like(x, fill_value, dtype=dtype) + +def linspace(start, stop, num, /, *, dtype=None, endpoint=True): + from .. import linspace + return linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + +def ones(shape, /, *, dtype=None): + from .. import ones + return ones(shape, dtype=dtype) + +def ones_like(x, /, *, dtype=None): + from .. import ones_like + return ones_like(x, dtype=dtype) + +def zeros(shape, /, *, dtype=None): + from .. import zeros + return zeros(shape, dtype=dtype) + +def zeros_like(x, /, *, dtype=None): + from .. import zeros_like + return zeros_like(x, dtype=dtype) + +__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like'] diff --git a/numpy/_array_api/elementwise_functions.py b/numpy/_array_api/elementwise_functions.py new file mode 100644 index 000000000..3e8349d29 --- /dev/null +++ b/numpy/_array_api/elementwise_functions.py @@ -0,0 +1,221 @@ +def abs(x, /): + from .. import abs + return abs(x) + +def acos(x, /): + from .. import acos + return acos(x) + +def acosh(x, /): + from .. import acosh + return acosh(x) + +def add(x1, x2, /): + from .. import add + return add(x1, x2) + +def asin(x, /): + from .. import asin + return asin(x) + +def asinh(x, /): + from .. import asinh + return asinh(x) + +def atan(x, /): + from .. import atan + return atan(x) + +def atan2(x1, x2, /): + from .. import atan2 + return atan2(x1, x2) + +def atanh(x, /): + from .. import atanh + return atanh(x) + +def bitwise_and(x1, x2, /): + from .. import bitwise_and + return bitwise_and(x1, x2) + +def bitwise_left_shift(x1, x2, /): + from .. import bitwise_left_shift + return bitwise_left_shift(x1, x2) + +def bitwise_invert(x, /): + from .. import bitwise_invert + return bitwise_invert(x) + +def bitwise_or(x1, x2, /): + from .. import bitwise_or + return bitwise_or(x1, x2) + +def bitwise_right_shift(x1, x2, /): + from .. import bitwise_right_shift + return bitwise_right_shift(x1, x2) + +def bitwise_xor(x1, x2, /): + from .. import bitwise_xor + return bitwise_xor(x1, x2) + +def ceil(x, /): + from .. import ceil + return ceil(x) + +def cos(x, /): + from .. import cos + return cos(x) + +def cosh(x, /): + from .. import cosh + return cosh(x) + +def divide(x1, x2, /): + from .. import divide + return divide(x1, x2) + +def equal(x1, x2, /): + from .. import equal + return equal(x1, x2) + +def exp(x, /): + from .. import exp + return exp(x) + +def expm1(x, /): + from .. import expm1 + return expm1(x) + +def floor(x, /): + from .. import floor + return floor(x) + +def floor_divide(x1, x2, /): + from .. import floor_divide + return floor_divide(x1, x2) + +def greater(x1, x2, /): + from .. import greater + return greater(x1, x2) + +def greater_equal(x1, x2, /): + from .. import greater_equal + return greater_equal(x1, x2) + +def isfinite(x, /): + from .. import isfinite + return isfinite(x) + +def isinf(x, /): + from .. import isinf + return isinf(x) + +def isnan(x, /): + from .. import isnan + return isnan(x) + +def less(x1, x2, /): + from .. import less + return less(x1, x2) + +def less_equal(x1, x2, /): + from .. import less_equal + return less_equal(x1, x2) + +def log(x, /): + from .. import log + return log(x) + +def log1p(x, /): + from .. import log1p + return log1p(x) + +def log2(x, /): + from .. import log2 + return log2(x) + +def log10(x, /): + from .. import log10 + return log10(x) + +def logical_and(x1, x2, /): + from .. import logical_and + return logical_and(x1, x2) + +def logical_not(x, /): + from .. import logical_not + return logical_not(x) + +def logical_or(x1, x2, /): + from .. import logical_or + return logical_or(x1, x2) + +def logical_xor(x1, x2, /): + from .. import logical_xor + return logical_xor(x1, x2) + +def multiply(x1, x2, /): + from .. import multiply + return multiply(x1, x2) + +def negative(x, /): + from .. import negative + return negative(x) + +def not_equal(x1, x2, /): + from .. import not_equal + return not_equal(x1, x2) + +def positive(x, /): + from .. import positive + return positive(x) + +def pow(x1, x2, /): + from .. import pow + return pow(x1, x2) + +def remainder(x1, x2, /): + from .. import remainder + return remainder(x1, x2) + +def round(x, /): + from .. import round + return round(x) + +def sign(x, /): + from .. import sign + return sign(x) + +def sin(x, /): + from .. import sin + return sin(x) + +def sinh(x, /): + from .. import sinh + return sinh(x) + +def square(x, /): + from .. import square + return square(x) + +def sqrt(x, /): + from .. import sqrt + return sqrt(x) + +def subtract(x1, x2, /): + from .. import subtract + return subtract(x1, x2) + +def tan(x, /): + from .. import tan + return tan(x) + +def tanh(x, /): + from .. import tanh + return tanh(x) + +def trunc(x, /): + from .. import trunc + return trunc(x) + +__all__ = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] diff --git a/numpy/_array_api/linear_algebra_functions.py b/numpy/_array_api/linear_algebra_functions.py new file mode 100644 index 000000000..5da7ac17b --- /dev/null +++ b/numpy/_array_api/linear_algebra_functions.py @@ -0,0 +1,91 @@ +# def cholesky(): +# from .. import cholesky +# return cholesky() + +def cross(x1, x2, /, *, axis=-1): + from .. import cross + return cross(x1, x2, axis=axis) + +def det(x, /): + from .. import det + return det(x) + +def diagonal(x, /, *, axis1=0, axis2=1, offset=0): + from .. import diagonal + return diagonal(x, axis1=axis1, axis2=axis2, offset=offset) + +# def dot(): +# from .. import dot +# return dot() +# +# def eig(): +# from .. import eig +# return eig() +# +# def eigvalsh(): +# from .. import eigvalsh +# return eigvalsh() +# +# def einsum(): +# from .. import einsum +# return einsum() + +def inv(x): + from .. import inv + return inv(x) + +# def lstsq(): +# from .. import lstsq +# return lstsq() +# +# def matmul(): +# from .. import matmul +# return matmul() +# +# def matrix_power(): +# from .. import matrix_power +# return matrix_power() +# +# def matrix_rank(): +# from .. import matrix_rank +# return matrix_rank() + +def norm(x, /, *, axis=None, keepdims=False, ord=None): + from .. import norm + return norm(x, axis=axis, keepdims=keepdims, ord=ord) + +def outer(x1, x2, /): + from .. import outer + return outer(x1, x2) + +# def pinv(): +# from .. import pinv +# return pinv() +# +# def qr(): +# from .. import qr +# return qr() +# +# def slogdet(): +# from .. import slogdet +# return slogdet() +# +# def solve(): +# from .. import solve +# return solve() +# +# def svd(): +# from .. import svd +# return svd() + +def trace(x, /, *, axis1=0, axis2=1, offset=0): + from .. import trace + return trace(x, axis1=axis1, axis2=axis2, offset=offset) + +def transpose(x, /, *, axes=None): + from .. import transpose + return transpose(x, axes=axes) + +# __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose'] + +__all__ = ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose'] diff --git a/numpy/_array_api/manipulation_functions.py b/numpy/_array_api/manipulation_functions.py new file mode 100644 index 000000000..1934e8e4e --- /dev/null +++ b/numpy/_array_api/manipulation_functions.py @@ -0,0 +1,29 @@ +def concat(arrays, /, *, axis=0): + from .. import concat + return concat(arrays, axis=axis) + +def expand_dims(x, axis, /): + from .. import expand_dims + return expand_dims(x, axis) + +def flip(x, /, *, axis=None): + from .. import flip + return flip(x, axis=axis) + +def reshape(x, shape, /): + from .. import reshape + return reshape(x, shape) + +def roll(x, shift, /, *, axis=None): + from .. import roll + return roll(x, shift, axis=axis) + +def squeeze(x, /, *, axis=None): + from .. import squeeze + return squeeze(x, axis=axis) + +def stack(arrays, /, *, axis=0): + from .. import stack + return stack(arrays, axis=axis) + +__all__ = ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack'] diff --git a/numpy/_array_api/searching_functions.py b/numpy/_array_api/searching_functions.py new file mode 100644 index 000000000..c4b6c58b5 --- /dev/null +++ b/numpy/_array_api/searching_functions.py @@ -0,0 +1,17 @@ +def argmax(x, /, *, axis=None, keepdims=False): + from .. import argmax + return argmax(x, axis=axis, keepdims=keepdims) + +def argmin(x, /, *, axis=None, keepdims=False): + from .. import argmin + return argmin(x, axis=axis, keepdims=keepdims) + +def nonzero(x, /): + from .. import nonzero + return nonzero(x) + +def where(condition, x1, x2, /): + from .. import where + return where(condition, x1, x2) + +__all__ = ['argmax', 'argmin', 'nonzero', 'where'] diff --git a/numpy/_array_api/set_functions.py b/numpy/_array_api/set_functions.py new file mode 100644 index 000000000..f218f1187 --- /dev/null +++ b/numpy/_array_api/set_functions.py @@ -0,0 +1,5 @@ +def unique(x, /, *, return_counts=False, return_index=False, return_inverse=False, sorted=True): + from .. import unique + return unique(x, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse, sorted=sorted) + +__all__ = ['unique'] diff --git a/numpy/_array_api/sorting_functions.py b/numpy/_array_api/sorting_functions.py new file mode 100644 index 000000000..384ec08f9 --- /dev/null +++ b/numpy/_array_api/sorting_functions.py @@ -0,0 +1,9 @@ +def argsort(x, /, *, axis=-1, descending=False, stable=True): + from .. import argsort + return argsort(x, axis=axis, descending=descending, stable=stable) + +def sort(x, /, *, axis=-1, descending=False, stable=True): + from .. import sort + return sort(x, axis=axis, descending=descending, stable=stable) + +__all__ = ['argsort', 'sort'] diff --git a/numpy/_array_api/statistical_functions.py b/numpy/_array_api/statistical_functions.py new file mode 100644 index 000000000..2cc712aea --- /dev/null +++ b/numpy/_array_api/statistical_functions.py @@ -0,0 +1,29 @@ +def max(x, /, *, axis=None, keepdims=False): + from .. import max + return max(x, axis=axis, keepdims=keepdims) + +def mean(x, /, *, axis=None, keepdims=False): + from .. import mean + return mean(x, axis=axis, keepdims=keepdims) + +def min(x, /, *, axis=None, keepdims=False): + from .. import min + return min(x, axis=axis, keepdims=keepdims) + +def prod(x, /, *, axis=None, keepdims=False): + from .. import prod + return prod(x, axis=axis, keepdims=keepdims) + +def std(x, /, *, axis=None, correction=0.0, keepdims=False): + from .. import std + return std(x, axis=axis, correction=correction, keepdims=keepdims) + +def sum(x, /, *, axis=None, keepdims=False): + from .. import sum + return sum(x, axis=axis, keepdims=keepdims) + +def var(x, /, *, axis=None, correction=0.0, keepdims=False): + from .. import var + return var(x, axis=axis, correction=correction, keepdims=keepdims) + +__all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] diff --git a/numpy/_array_api/utility_functions.py b/numpy/_array_api/utility_functions.py new file mode 100644 index 000000000..eac0d4eaa --- /dev/null +++ b/numpy/_array_api/utility_functions.py @@ -0,0 +1,9 @@ +def all(x, /, *, axis=None, keepdims=False): + from .. import all + return all(x, axis=axis, keepdims=keepdims) + +def any(x, /, *, axis=None, keepdims=False): + from .. import any + return any(x, axis=axis, keepdims=keepdims) + +__all__ = ['all', 'any'] |
