diff options
| author | Aaron Meurer <asmeurer@gmail.com> | 2021-01-11 17:45:05 -0700 |
|---|---|---|
| committer | Aaron Meurer <asmeurer@gmail.com> | 2021-01-11 17:45:05 -0700 |
| commit | 10427b0cb9895d9d1d55c95815d9f27732cfbeaa (patch) | |
| tree | 659d79039495a658b4db8978544be20809ad3b2d /numpy/_array_api | |
| parent | c8efdbb72eab78ed2fc735d3078ef8534dcc6ef7 (diff) | |
| download | numpy-10427b0cb9895d9d1d55c95815d9f27732cfbeaa.tar.gz | |
Correct some differing keyword arguments in the array API namespace
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/sorting_functions.py | 14 | ||||
| -rw-r--r-- | numpy/_array_api/statistical_functions.py | 6 |
2 files changed, 16 insertions, 4 deletions
diff --git a/numpy/_array_api/sorting_functions.py b/numpy/_array_api/sorting_functions.py index 384ec08f9..b75387737 100644 --- a/numpy/_array_api/sorting_functions.py +++ b/numpy/_array_api/sorting_functions.py @@ -1,9 +1,19 @@ def argsort(x, /, *, axis=-1, descending=False, stable=True): from .. import argsort - return argsort(x, axis=axis, descending=descending, stable=stable) + from .. import flip + # Note: this keyword argument is different, and the default is different. + kind = 'stable' if stable else 'quicksort' + res = argsort(x, axis=axis, kind=kind) + if descending: + res = flip(res, axis=axis) def sort(x, /, *, axis=-1, descending=False, stable=True): from .. import sort - return sort(x, axis=axis, descending=descending, stable=stable) + from .. import flip + # Note: this keyword argument is different, and the default is different. + kind = 'stable' if stable else 'quicksort' + res = sort(x, axis=axis, kind=kind) + if descending: + res = flip(res, axis=axis) __all__ = ['argsort', 'sort'] diff --git a/numpy/_array_api/statistical_functions.py b/numpy/_array_api/statistical_functions.py index 2cc712aea..b9180a863 100644 --- a/numpy/_array_api/statistical_functions.py +++ b/numpy/_array_api/statistical_functions.py @@ -16,7 +16,8 @@ def prod(x, /, *, axis=None, keepdims=False): def std(x, /, *, axis=None, correction=0.0, keepdims=False): from .. import std - return std(x, axis=axis, correction=correction, keepdims=keepdims) + # Note: the keyword argument correction is different here + return std(x, axis=axis, ddof=correction, keepdims=keepdims) def sum(x, /, *, axis=None, keepdims=False): from .. import sum @@ -24,6 +25,7 @@ def sum(x, /, *, axis=None, keepdims=False): def var(x, /, *, axis=None, correction=0.0, keepdims=False): from .. import var - return var(x, axis=axis, correction=correction, keepdims=keepdims) + # Note: the keyword argument correction is different here + return var(x, axis=axis, ddof=correction, keepdims=keepdims) __all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var'] |
