summaryrefslogtreecommitdiff
path: root/numpy/_array_api/statistical_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-01-11 16:03:43 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-01-11 16:03:43 -0700
commit012343dec5599418b77512733fc5b8db6bc14c4c (patch)
tree13ef3a6bfdcd603036baaf8d65d1647eea749afc /numpy/_array_api/statistical_functions.py
parent33dc7bea24f1ab6c47047b49521e732caeb485d5 (diff)
downloadnumpy-012343dec5599418b77512733fc5b8db6bc14c4c.tar.gz
Add initial array_api sub-namespace
This is based on the function stubs from the array API test suite, and is currently based on the assumption that NumPy already follows the array API standard. Now it needs to be modified to fix it in the places where NumPy deviates (for example, different function names for inverse trigonometric functions).
Diffstat (limited to 'numpy/_array_api/statistical_functions.py')
-rw-r--r--numpy/_array_api/statistical_functions.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/numpy/_array_api/statistical_functions.py b/numpy/_array_api/statistical_functions.py
new file mode 100644
index 000000000..2cc712aea
--- /dev/null
+++ b/numpy/_array_api/statistical_functions.py
@@ -0,0 +1,29 @@
+def max(x, /, *, axis=None, keepdims=False):
+ from .. import max
+ return max(x, axis=axis, keepdims=keepdims)
+
+def mean(x, /, *, axis=None, keepdims=False):
+ from .. import mean
+ return mean(x, axis=axis, keepdims=keepdims)
+
+def min(x, /, *, axis=None, keepdims=False):
+ from .. import min
+ return min(x, axis=axis, keepdims=keepdims)
+
+def prod(x, /, *, axis=None, keepdims=False):
+ from .. import prod
+ return prod(x, axis=axis, keepdims=keepdims)
+
+def std(x, /, *, axis=None, correction=0.0, keepdims=False):
+ from .. import std
+ return std(x, axis=axis, correction=correction, keepdims=keepdims)
+
+def sum(x, /, *, axis=None, keepdims=False):
+ from .. import sum
+ return sum(x, axis=axis, keepdims=keepdims)
+
+def var(x, /, *, axis=None, correction=0.0, keepdims=False):
+ from .. import var
+ return var(x, axis=axis, correction=correction, keepdims=keepdims)
+
+__all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var']