summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-01-11 17:45:05 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-01-11 17:45:05 -0700
commit10427b0cb9895d9d1d55c95815d9f27732cfbeaa (patch)
tree659d79039495a658b4db8978544be20809ad3b2d /numpy/_array_api
parentc8efdbb72eab78ed2fc735d3078ef8534dcc6ef7 (diff)
downloadnumpy-10427b0cb9895d9d1d55c95815d9f27732cfbeaa.tar.gz
Correct some differing keyword arguments in the array API namespace
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/sorting_functions.py14
-rw-r--r--numpy/_array_api/statistical_functions.py6
2 files changed, 16 insertions, 4 deletions
diff --git a/numpy/_array_api/sorting_functions.py b/numpy/_array_api/sorting_functions.py
index 384ec08f9..b75387737 100644
--- a/numpy/_array_api/sorting_functions.py
+++ b/numpy/_array_api/sorting_functions.py
@@ -1,9 +1,19 @@
def argsort(x, /, *, axis=-1, descending=False, stable=True):
from .. import argsort
- return argsort(x, axis=axis, descending=descending, stable=stable)
+ from .. import flip
+ # Note: this keyword argument is different, and the default is different.
+ kind = 'stable' if stable else 'quicksort'
+ res = argsort(x, axis=axis, kind=kind)
+ if descending:
+ res = flip(res, axis=axis)
def sort(x, /, *, axis=-1, descending=False, stable=True):
from .. import sort
- return sort(x, axis=axis, descending=descending, stable=stable)
+ from .. import flip
+ # Note: this keyword argument is different, and the default is different.
+ kind = 'stable' if stable else 'quicksort'
+ res = sort(x, axis=axis, kind=kind)
+ if descending:
+ res = flip(res, axis=axis)
__all__ = ['argsort', 'sort']
diff --git a/numpy/_array_api/statistical_functions.py b/numpy/_array_api/statistical_functions.py
index 2cc712aea..b9180a863 100644
--- a/numpy/_array_api/statistical_functions.py
+++ b/numpy/_array_api/statistical_functions.py
@@ -16,7 +16,8 @@ def prod(x, /, *, axis=None, keepdims=False):
def std(x, /, *, axis=None, correction=0.0, keepdims=False):
from .. import std
- return std(x, axis=axis, correction=correction, keepdims=keepdims)
+ # Note: the keyword argument correction is different here
+ return std(x, axis=axis, ddof=correction, keepdims=keepdims)
def sum(x, /, *, axis=None, keepdims=False):
from .. import sum
@@ -24,6 +25,7 @@ def sum(x, /, *, axis=None, keepdims=False):
def var(x, /, *, axis=None, correction=0.0, keepdims=False):
from .. import var
- return var(x, axis=axis, correction=correction, keepdims=keepdims)
+ # Note: the keyword argument correction is different here
+ return var(x, axis=axis, ddof=correction, keepdims=keepdims)
__all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var']