summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorRoy Smart <roytsmart@gmail.com>2022-12-02 16:09:33 -0700
committerRoy Smart <roytsmart@gmail.com>2022-12-05 15:53:32 -0700
commit91432a36a3611c2374ea9e2d45592f0ac5e71adb (patch)
treec8f5686e420e3365bc7af432476c353ef289218c /numpy/lib
parentf74b3720ca6148574da40ed96d4712e221ee84bd (diff)
downloadnumpy-91432a36a3611c2374ea9e2d45592f0ac5e71adb.tar.gz
BUG: `keepdims=True` is ignored if `out` is not `None` in `numpy.median()`, `numpy.percentile()`, and `numpy.quantile()`.
Closes #22714, #22544.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py51
-rw-r--r--numpy/lib/nanfunctions.py14
-rw-r--r--numpy/lib/tests/test_function_base.py50
-rw-r--r--numpy/lib/tests/test_nanfunctions.py58
4 files changed, 145 insertions, 28 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 0ab49fa11..35a3b3543 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3689,7 +3689,7 @@ def msort(a):
return b
-def _ureduce(a, func, **kwargs):
+def _ureduce(a, func, keepdims=False, **kwargs):
"""
Internal Function.
Call `func` with `a` as first argument swapping the axes to use extended
@@ -3717,13 +3717,20 @@ def _ureduce(a, func, **kwargs):
"""
a = np.asanyarray(a)
axis = kwargs.get('axis', None)
+ out = kwargs.get('out', None)
+
+ if keepdims is np._NoValue:
+ keepdims = False
+
+ nd = a.ndim
if axis is not None:
- keepdim = list(a.shape)
- nd = a.ndim
axis = _nx.normalize_axis_tuple(axis, nd)
- for ax in axis:
- keepdim[ax] = 1
+ if keepdims:
+ if out is not None:
+ index_out = tuple(
+ 0 if i in axis else slice(None) for i in range(nd))
+ kwargs['out'] = out[(Ellipsis, ) + index_out]
if len(axis) == 1:
kwargs['axis'] = axis[0]
@@ -3736,12 +3743,27 @@ def _ureduce(a, func, **kwargs):
# merge reduced axis
a = a.reshape(a.shape[:nkeep] + (-1,))
kwargs['axis'] = -1
- keepdim = tuple(keepdim)
else:
- keepdim = (1,) * a.ndim
+ if keepdims:
+ if out is not None:
+ index_out = (0, ) * nd
+ kwargs['out'] = out[(Ellipsis, ) + index_out]
r = func(a, **kwargs)
- return r, keepdim
+
+ if out is not None:
+ return out
+
+ if keepdims:
+ if axis is None:
+ index_r = (np.newaxis, ) * nd
+ else:
+ index_r = tuple(
+ np.newaxis if i in axis else slice(None)
+ for i in range(nd))
+ r = r[(Ellipsis, ) + index_r]
+
+ return r
def _median_dispatcher(
@@ -3831,12 +3853,8 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
>>> assert not np.all(a==b)
"""
- r, k = _ureduce(a, func=_median, axis=axis, out=out,
+ return _ureduce(a, func=_median, keepdims=keepdims, axis=axis, out=out,
overwrite_input=overwrite_input)
- if keepdims:
- return r.reshape(k)
- else:
- return r
def _median(a, axis=None, out=None, overwrite_input=False):
@@ -4452,17 +4470,14 @@ def _quantile_unchecked(a,
method="linear",
keepdims=False):
"""Assumes that q is in [0, 1], and is an ndarray"""
- r, k = _ureduce(a,
+ return _ureduce(a,
func=_quantile_ureduce_func,
q=q,
+ keepdims=keepdims,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method)
- if keepdims:
- return r.reshape(q.shape + k)
- else:
- return r
def _quantile_is_valid(q):
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 3814c0727..ae2dfa165 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -1214,12 +1214,9 @@ def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=np._NoValu
if a.size == 0:
return np.nanmean(a, axis, out=out, keepdims=keepdims)
- r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
+ return function_base._ureduce(a, func=_nanmedian, keepdims=keepdims,
+ axis=axis, out=out,
overwrite_input=overwrite_input)
- if keepdims and keepdims is not np._NoValue:
- return r.reshape(k)
- else:
- return r
def _nanpercentile_dispatcher(
@@ -1556,17 +1553,14 @@ def _nanquantile_unchecked(
# so deal them upfront
if a.size == 0:
return np.nanmean(a, axis, out=out, keepdims=keepdims)
- r, k = function_base._ureduce(a,
+ return function_base._ureduce(a,
func=_nanquantile_ureduce_func,
q=q,
+ keepdims=keepdims,
axis=axis,
out=out,
overwrite_input=overwrite_input,
method=method)
- if keepdims and keepdims is not np._NoValue:
- return r.reshape(q.shape + k)
- else:
- return r
def _nanquantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index c5b31ebf4..1bb4c4efa 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -25,6 +25,7 @@ from numpy.lib import (
i0, insert, interp, kaiser, meshgrid, msort, piecewise, place, rot90,
select, setxor1d, sinc, trapz, trim_zeros, unwrap, unique, vectorize
)
+from numpy.core.numeric import normalize_axis_tuple
def get_mat(n):
@@ -3331,6 +3332,32 @@ class TestPercentile:
assert_equal(np.percentile(d, [1, 7], axis=(0, 3),
keepdims=True).shape, (2, 1, 5, 7, 1))
+ @pytest.mark.parametrize('q', [7, [1, 7]])
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1,),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, q, axis):
+ d = np.ones((3, 5, 7, 11))
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ shape_out = np.shape(q) + shape_out
+
+ out = np.empty(shape_out)
+ result = np.percentile(d, q, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_out(self):
o = np.zeros((4,))
d = np.ones((3, 4))
@@ -3843,6 +3870,29 @@ class TestMedian:
assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape,
(1, 1, 7, 1))
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1, ),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, axis):
+ d = np.ones((3, 5, 7, 11))
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ out = np.empty(shape_out)
+ result = np.median(d, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
class TestAdd_newdoc_ufunc:
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 733a077ea..64464edcc 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -3,6 +3,7 @@ import pytest
import inspect
import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
from numpy.lib.nanfunctions import _nan_mask, _replace_nan
from numpy.testing import (
assert_, assert_equal, assert_almost_equal, assert_raises,
@@ -807,6 +808,33 @@ class TestNanFunctions_Median:
res = np.nanmedian(d, axis=(0, 1, 3), keepdims=True)
assert_equal(res.shape, (1, 1, 7, 1))
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1, ),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, axis):
+ d = np.ones((3, 5, 7, 11))
+ # Randomly set some elements to NaN:
+ w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
+ w = w.astype(np.intp)
+ d[tuple(w)] = np.nan
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ out = np.empty(shape_out)
+ result = np.nanmedian(d, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_out(self):
mat = np.random.rand(3, 3)
nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
@@ -982,6 +1010,36 @@ class TestNanFunctions_Percentile:
res = np.nanpercentile(d, 90, axis=(0, 1, 3), keepdims=True)
assert_equal(res.shape, (1, 1, 7, 1))
+ @pytest.mark.parametrize('q', [7, [1, 7]])
+ @pytest.mark.parametrize(
+ argnames='axis',
+ argvalues=[
+ None,
+ 1,
+ (1,),
+ (0, 1),
+ (-3, -1),
+ ]
+ )
+ def test_keepdims_out(self, q, axis):
+ d = np.ones((3, 5, 7, 11))
+ # Randomly set some elements to NaN:
+ w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
+ w = w.astype(np.intp)
+ d[tuple(w)] = np.nan
+ if axis is None:
+ shape_out = (1,) * d.ndim
+ else:
+ axis_norm = normalize_axis_tuple(axis, d.ndim)
+ shape_out = tuple(
+ 1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
+ shape_out = np.shape(q) + shape_out
+
+ out = np.empty(shape_out)
+ result = np.nanpercentile(d, q, axis=axis, keepdims=True, out=out)
+ assert result is out
+ assert_equal(result.shape, shape_out)
+
def test_out(self):
mat = np.random.rand(3, 3)
nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)