diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2022-06-09 11:55:32 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-09 11:55:32 -0600 |
commit | 5ea81775e0d2d13a4fec6bab882002e050bd3da1 (patch) | |
tree | d2dc068c7c06c22393841ddc0dd7ee011996249f /numpy/array_api/_statistical_functions.py | |
parent | 9d6b102617a92ad41964a5611c8dec87d3053a9f (diff) | |
parent | 5447192c8dc5d43d837b93ed5c5d47457e367669 (diff) | |
download | numpy-5ea81775e0d2d13a4fec6bab882002e050bd3da1.tar.gz |
Merge branch 'main' into bugfix_16492_segfault_on_pyfragments
Diffstat (limited to 'numpy/array_api/_statistical_functions.py')
-rw-r--r-- | numpy/array_api/_statistical_functions.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py new file mode 100644 index 000000000..5bc831ac2 --- /dev/null +++ b/numpy/array_api/_statistical_functions.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from ._dtypes import ( + _floating_dtypes, + _numeric_dtypes, +) +from ._array_object import Array +from ._creation_functions import asarray +from ._dtypes import float32, float64 + +from typing import TYPE_CHECKING, Optional, Tuple, Union + +if TYPE_CHECKING: + from ._typing import Dtype + +import numpy as np + + +def max( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in max") + return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) + + +def mean( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in mean") + return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) + + +def min( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in min") + return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) + + +def prod( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in prod") + # Note: sum() and prod() always upcast float32 to float64 for dtype=None + # We need to do so here before computing the product to avoid overflow + if dtype is None and x.dtype == float32: + dtype = float64 + return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) + + +def std( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: + # Note: the keyword argument correction is different here + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in std") + return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) + + +def sum( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in sum") + # Note: sum() and prod() always upcast integers to (u)int64 and float32 to + # float64 for dtype=None. `np.sum` does that too for integers, but not for + # float32, so we need to special-case it here + if dtype is None and x.dtype == float32: + dtype = float64 + return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) + + +def var( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: + # Note: the keyword argument correction is different here + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in var") + return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) |