diff options
| author | Aaron Meurer <asmeurer@gmail.com> | 2021-01-11 16:03:43 -0700 |
|---|---|---|
| committer | Aaron Meurer <asmeurer@gmail.com> | 2021-01-11 16:03:43 -0700 |
| commit | 012343dec5599418b77512733fc5b8db6bc14c4c (patch) | |
| tree | 13ef3a6bfdcd603036baaf8d65d1647eea749afc /numpy/_array_api | |
| parent | 33dc7bea24f1ab6c47047b49521e732caeb485d5 (diff) | |
| download | numpy-012343dec5599418b77512733fc5b8db6bc14c4c.tar.gz | |
Add initial array_api sub-namespace
This is based on the function stubs from the array API test suite, and is
currently based on the assumption that NumPy already follows the array API
standard. Now it needs to be modified to fix it in the places where NumPy
deviates (for example, different function names for inverse trigonometric
functions).
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/__init__.py | 45 | ||||
| -rw-r--r-- | numpy/_array_api/constants.py | 3 | ||||
| -rw-r--r-- | numpy/_array_api/creation_functions.py | 45 | ||||
| -rw-r--r-- | numpy/_array_api/elementwise_functions.py | 221 | ||||
| -rw-r--r-- | numpy/_array_api/linear_algebra_functions.py | 91 | ||||
| -rw-r--r-- | numpy/_array_api/manipulation_functions.py | 29 | ||||
| -rw-r--r-- | numpy/_array_api/searching_functions.py | 17 | ||||
| -rw-r--r-- | numpy/_array_api/set_functions.py | 5 | ||||
| -rw-r--r-- | numpy/_array_api/sorting_functions.py | 9 | ||||
| -rw-r--r-- | numpy/_array_api/statistical_functions.py | 29 | ||||
| -rw-r--r-- | numpy/_array_api/utility_functions.py | 9 |
11 files changed, 503 insertions, 0 deletions
diff --git a/numpy/_array_api/__init__.py b/numpy/_array_api/__init__.py new file mode 100644 index 000000000..878251e7c --- /dev/null +++ b/numpy/_array_api/__init__.py @@ -0,0 +1,45 @@ +__all__ = [] + +from .constants import e, inf, nan, pi + +__all__ += ['e', 'inf', 'nan', 'pi'] + +from .creation_functions import arange, empty, empty_like, eye, full, full_like, linspace, ones, ones_like, zeros, zeros_like + +__all__ += ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like'] + +from .elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc + +__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] + +from .linear_algebra_functions import cross, det, diagonal, inv, norm, outer, trace, transpose + +__all__ += ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose'] + +# from .linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose +# +# __all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose'] + +from .manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack + +__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack'] + +from .searching_functions import argmax, argmin, nonzero, where + +__all__ += ['argmax', 'argmin', 'nonzero', 'where'] + +from .set_functions import unique + +__all__ += ['unique'] + +from .sorting_functions import argsort, sort + +__all__ += ['argsort', 'sort'] + +from .statistical_functions import max, mean, min, prod, std, sum, var + +__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] + +from .utility_functions import all, any + +__all__ += ['all', 'any'] diff --git a/numpy/_array_api/constants.py b/numpy/_array_api/constants.py new file mode 100644 index 000000000..000777029 --- /dev/null +++ b/numpy/_array_api/constants.py @@ -0,0 +1,3 @@ +from .. import e, inf, nan, pi + +__all__ = ['e', 'inf', 'nan', 'pi'] diff --git a/numpy/_array_api/creation_functions.py b/numpy/_array_api/creation_functions.py new file mode 100644 index 000000000..50b0bd252 --- /dev/null +++ b/numpy/_array_api/creation_functions.py @@ -0,0 +1,45 @@ +def arange(start, /, *, stop=None, step=1, dtype=None): + from .. import arange + return arange(start, stop=stop, step=step, dtype=dtype) + +def empty(shape, /, *, dtype=None): + from .. import empty + return empty(shape, dtype=dtype) + +def empty_like(x, /, *, dtype=None): + from .. import empty_like + return empty_like(x, dtype=dtype) + +def eye(N, /, *, M=None, k=0, dtype=None): + from .. import eye + return eye(N, M=M, k=k, dtype=dtype) + +def full(shape, fill_value, /, *, dtype=None): + from .. import full + return full(shape, fill_value, dtype=dtype) + +def full_like(x, fill_value, /, *, dtype=None): + from .. import full_like + return full_like(x, fill_value, dtype=dtype) + +def linspace(start, stop, num, /, *, dtype=None, endpoint=True): + from .. import linspace + return linspace(start, stop, num, dtype=dtype, endpoint=endpoint) + +def ones(shape, /, *, dtype=None): + from .. import ones + return ones(shape, dtype=dtype) + +def ones_like(x, /, *, dtype=None): + from .. import ones_like + return ones_like(x, dtype=dtype) + +def zeros(shape, /, *, dtype=None): + from .. import zeros + return zeros(shape, dtype=dtype) + +def zeros_like(x, /, *, dtype=None): + from .. import zeros_like + return zeros_like(x, dtype=dtype) + +__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like'] diff --git a/numpy/_array_api/elementwise_functions.py b/numpy/_array_api/elementwise_functions.py new file mode 100644 index 000000000..3e8349d29 --- /dev/null +++ b/numpy/_array_api/elementwise_functions.py @@ -0,0 +1,221 @@ +def abs(x, /): + from .. import abs + return abs(x) + +def acos(x, /): + from .. import acos + return acos(x) + +def acosh(x, /): + from .. import acosh + return acosh(x) + +def add(x1, x2, /): + from .. import add + return add(x1, x2) + +def asin(x, /): + from .. import asin + return asin(x) + +def asinh(x, /): + from .. import asinh + return asinh(x) + +def atan(x, /): + from .. import atan + return atan(x) + +def atan2(x1, x2, /): + from .. import atan2 + return atan2(x1, x2) + +def atanh(x, /): + from .. import atanh + return atanh(x) + +def bitwise_and(x1, x2, /): + from .. import bitwise_and + return bitwise_and(x1, x2) + +def bitwise_left_shift(x1, x2, /): + from .. import bitwise_left_shift + return bitwise_left_shift(x1, x2) + +def bitwise_invert(x, /): + from .. import bitwise_invert + return bitwise_invert(x) + +def bitwise_or(x1, x2, /): + from .. import bitwise_or + return bitwise_or(x1, x2) + +def bitwise_right_shift(x1, x2, /): + from .. import bitwise_right_shift + return bitwise_right_shift(x1, x2) + +def bitwise_xor(x1, x2, /): + from .. import bitwise_xor + return bitwise_xor(x1, x2) + +def ceil(x, /): + from .. import ceil + return ceil(x) + +def cos(x, /): + from .. import cos + return cos(x) + +def cosh(x, /): + from .. import cosh + return cosh(x) + +def divide(x1, x2, /): + from .. import divide + return divide(x1, x2) + +def equal(x1, x2, /): + from .. import equal + return equal(x1, x2) + +def exp(x, /): + from .. import exp + return exp(x) + +def expm1(x, /): + from .. import expm1 + return expm1(x) + +def floor(x, /): + from .. import floor + return floor(x) + +def floor_divide(x1, x2, /): + from .. import floor_divide + return floor_divide(x1, x2) + +def greater(x1, x2, /): + from .. import greater + return greater(x1, x2) + +def greater_equal(x1, x2, /): + from .. import greater_equal + return greater_equal(x1, x2) + +def isfinite(x, /): + from .. import isfinite + return isfinite(x) + +def isinf(x, /): + from .. import isinf + return isinf(x) + +def isnan(x, /): + from .. import isnan + return isnan(x) + +def less(x1, x2, /): + from .. import less + return less(x1, x2) + +def less_equal(x1, x2, /): + from .. import less_equal + return less_equal(x1, x2) + +def log(x, /): + from .. import log + return log(x) + +def log1p(x, /): + from .. import log1p + return log1p(x) + +def log2(x, /): + from .. import log2 + return log2(x) + +def log10(x, /): + from .. import log10 + return log10(x) + +def logical_and(x1, x2, /): + from .. import logical_and + return logical_and(x1, x2) + +def logical_not(x, /): + from .. import logical_not + return logical_not(x) + +def logical_or(x1, x2, /): + from .. import logical_or + return logical_or(x1, x2) + +def logical_xor(x1, x2, /): + from .. import logical_xor + return logical_xor(x1, x2) + +def multiply(x1, x2, /): + from .. import multiply + return multiply(x1, x2) + +def negative(x, /): + from .. import negative + return negative(x) + +def not_equal(x1, x2, /): + from .. import not_equal + return not_equal(x1, x2) + +def positive(x, /): + from .. import positive + return positive(x) + +def pow(x1, x2, /): + from .. import pow + return pow(x1, x2) + +def remainder(x1, x2, /): + from .. import remainder + return remainder(x1, x2) + +def round(x, /): + from .. import round + return round(x) + +def sign(x, /): + from .. import sign + return sign(x) + +def sin(x, /): + from .. import sin + return sin(x) + +def sinh(x, /): + from .. import sinh + return sinh(x) + +def square(x, /): + from .. import square + return square(x) + +def sqrt(x, /): + from .. import sqrt + return sqrt(x) + +def subtract(x1, x2, /): + from .. import subtract + return subtract(x1, x2) + +def tan(x, /): + from .. import tan + return tan(x) + +def tanh(x, /): + from .. import tanh + return tanh(x) + +def trunc(x, /): + from .. import trunc + return trunc(x) + +__all__ = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc'] diff --git a/numpy/_array_api/linear_algebra_functions.py b/numpy/_array_api/linear_algebra_functions.py new file mode 100644 index 000000000..5da7ac17b --- /dev/null +++ b/numpy/_array_api/linear_algebra_functions.py @@ -0,0 +1,91 @@ +# def cholesky(): +# from .. import cholesky +# return cholesky() + +def cross(x1, x2, /, *, axis=-1): + from .. import cross + return cross(x1, x2, axis=axis) + +def det(x, /): + from .. import det + return det(x) + +def diagonal(x, /, *, axis1=0, axis2=1, offset=0): + from .. import diagonal + return diagonal(x, axis1=axis1, axis2=axis2, offset=offset) + +# def dot(): +# from .. import dot +# return dot() +# +# def eig(): +# from .. import eig +# return eig() +# +# def eigvalsh(): +# from .. import eigvalsh +# return eigvalsh() +# +# def einsum(): +# from .. import einsum +# return einsum() + +def inv(x): + from .. import inv + return inv(x) + +# def lstsq(): +# from .. import lstsq +# return lstsq() +# +# def matmul(): +# from .. import matmul +# return matmul() +# +# def matrix_power(): +# from .. import matrix_power +# return matrix_power() +# +# def matrix_rank(): +# from .. import matrix_rank +# return matrix_rank() + +def norm(x, /, *, axis=None, keepdims=False, ord=None): + from .. import norm + return norm(x, axis=axis, keepdims=keepdims, ord=ord) + +def outer(x1, x2, /): + from .. import outer + return outer(x1, x2) + +# def pinv(): +# from .. import pinv +# return pinv() +# +# def qr(): +# from .. import qr +# return qr() +# +# def slogdet(): +# from .. import slogdet +# return slogdet() +# +# def solve(): +# from .. import solve +# return solve() +# +# def svd(): +# from .. import svd +# return svd() + +def trace(x, /, *, axis1=0, axis2=1, offset=0): + from .. import trace + return trace(x, axis1=axis1, axis2=axis2, offset=offset) + +def transpose(x, /, *, axes=None): + from .. import transpose + return transpose(x, axes=axes) + +# __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose'] + +__all__ = ['cross', 'det', 'diagonal', 'inv', 'norm', 'outer', 'trace', 'transpose'] diff --git a/numpy/_array_api/manipulation_functions.py b/numpy/_array_api/manipulation_functions.py new file mode 100644 index 000000000..1934e8e4e --- /dev/null +++ b/numpy/_array_api/manipulation_functions.py @@ -0,0 +1,29 @@ +def concat(arrays, /, *, axis=0): + from .. import concat + return concat(arrays, axis=axis) + +def expand_dims(x, axis, /): + from .. import expand_dims + return expand_dims(x, axis) + +def flip(x, /, *, axis=None): + from .. import flip + return flip(x, axis=axis) + +def reshape(x, shape, /): + from .. import reshape + return reshape(x, shape) + +def roll(x, shift, /, *, axis=None): + from .. import roll + return roll(x, shift, axis=axis) + +def squeeze(x, /, *, axis=None): + from .. import squeeze + return squeeze(x, axis=axis) + +def stack(arrays, /, *, axis=0): + from .. import stack + return stack(arrays, axis=axis) + +__all__ = ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack'] diff --git a/numpy/_array_api/searching_functions.py b/numpy/_array_api/searching_functions.py new file mode 100644 index 000000000..c4b6c58b5 --- /dev/null +++ b/numpy/_array_api/searching_functions.py @@ -0,0 +1,17 @@ +def argmax(x, /, *, axis=None, keepdims=False): + from .. import argmax + return argmax(x, axis=axis, keepdims=keepdims) + +def argmin(x, /, *, axis=None, keepdims=False): + from .. import argmin + return argmin(x, axis=axis, keepdims=keepdims) + +def nonzero(x, /): + from .. import nonzero + return nonzero(x) + +def where(condition, x1, x2, /): + from .. import where + return where(condition, x1, x2) + +__all__ = ['argmax', 'argmin', 'nonzero', 'where'] diff --git a/numpy/_array_api/set_functions.py b/numpy/_array_api/set_functions.py new file mode 100644 index 000000000..f218f1187 --- /dev/null +++ b/numpy/_array_api/set_functions.py @@ -0,0 +1,5 @@ +def unique(x, /, *, return_counts=False, return_index=False, return_inverse=False, sorted=True): + from .. import unique + return unique(x, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse, sorted=sorted) + +__all__ = ['unique'] diff --git a/numpy/_array_api/sorting_functions.py b/numpy/_array_api/sorting_functions.py new file mode 100644 index 000000000..384ec08f9 --- /dev/null +++ b/numpy/_array_api/sorting_functions.py @@ -0,0 +1,9 @@ +def argsort(x, /, *, axis=-1, descending=False, stable=True): + from .. import argsort + return argsort(x, axis=axis, descending=descending, stable=stable) + +def sort(x, /, *, axis=-1, descending=False, stable=True): + from .. import sort + return sort(x, axis=axis, descending=descending, stable=stable) + +__all__ = ['argsort', 'sort'] diff --git a/numpy/_array_api/statistical_functions.py b/numpy/_array_api/statistical_functions.py new file mode 100644 index 000000000..2cc712aea --- /dev/null +++ b/numpy/_array_api/statistical_functions.py @@ -0,0 +1,29 @@ +def max(x, /, *, axis=None, keepdims=False): + from .. import max + return max(x, axis=axis, keepdims=keepdims) + +def mean(x, /, *, axis=None, keepdims=False): + from .. import mean + return mean(x, axis=axis, keepdims=keepdims) + +def min(x, /, *, axis=None, keepdims=False): + from .. import min + return min(x, axis=axis, keepdims=keepdims) + +def prod(x, /, *, axis=None, keepdims=False): + from .. import prod + return prod(x, axis=axis, keepdims=keepdims) + +def std(x, /, *, axis=None, correction=0.0, keepdims=False): + from .. import std + return std(x, axis=axis, correction=correction, keepdims=keepdims) + +def sum(x, /, *, axis=None, keepdims=False): + from .. import sum + return sum(x, axis=axis, keepdims=keepdims) + +def var(x, /, *, axis=None, correction=0.0, keepdims=False): + from .. import var + return var(x, axis=axis, correction=correction, keepdims=keepdims) + +__all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] diff --git a/numpy/_array_api/utility_functions.py b/numpy/_array_api/utility_functions.py new file mode 100644 index 000000000..eac0d4eaa --- /dev/null +++ b/numpy/_array_api/utility_functions.py @@ -0,0 +1,9 @@ +def all(x, /, *, axis=None, keepdims=False): + from .. import all + return all(x, axis=axis, keepdims=keepdims) + +def any(x, /, *, axis=None, keepdims=False): + from .. import any + return any(x, axis=axis, keepdims=keepdims) + +__all__ = ['all', 'any'] |
