diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-02-28 17:33:52 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-05-25 22:55:59 -0700 |
commit | 7a3c50ab427cd9c4f3125a8e72d31f9141bd558a (patch) | |
tree | a1115f67bc353b3a3ef325c4f1e8c5db403c8d86 | |
parent | 905e906d55fdcb8cc215de8aa287ea9654d1c95c (diff) | |
download | numpy-7a3c50ab427cd9c4f3125a8e72d31f9141bd558a.tar.gz |
MAINT: rewrite np.ma.(median|sort) to use take_along_axis
-rw-r--r-- | numpy/ma/core.py | 10 | ||||
-rw-r--r-- | numpy/ma/extras.py | 36 |
2 files changed, 13 insertions, 33 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index fb28fa8e5..b96620215 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -5531,15 +5531,7 @@ class MaskedArray(ndarray): sidx = self.argsort(axis=axis, kind=kind, order=order, fill_value=fill_value, endwith=endwith) - # save memory for 1d arrays - if self.ndim == 1: - idx = sidx - else: - idx = list(np.ix_(*[np.arange(x) for x in self.shape])) - idx[axis] = sidx - idx = tuple(idx) - - self[...] = self[idx] + self[...] = np.take_along_axis(self, sidx, axis=axis) def min(self, axis=None, out=None, fill_value=None, keepdims=np._NoValue): """ diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index da35217d1..3be4d3625 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -747,19 +747,17 @@ def _median(a, axis=None, out=None, overwrite_input=False): return np.ma.minimum_fill_value(asorted) return s - counts = count(asorted, axis=axis) + counts = count(asorted, axis=axis, keepdims=True) h = counts // 2 - # create indexing mesh grid for all but reduced axis - axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape) - if i != axis] - ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij') + # duplicate high if odd number of elements so mean does nothing + odd = counts % 2 == 1 + l = np.where(odd, h, h-1) - # insert indices of low and high median - ind.insert(axis, h - 1) - low = asorted[tuple(ind)] - ind[axis] = np.minimum(h, asorted.shape[axis] - 1) - high = asorted[tuple(ind)] + lh = np.concatenate([l,h], axis=axis) + + # get low and high median + low_high = np.take_along_axis(asorted, lh, axis=axis) def replace_masked(s): # Replace masked entries with minimum_full_value unless it all values @@ -767,30 +765,20 @@ def _median(a, axis=None, out=None, overwrite_input=False): # larger than the fill value is undefined and a valid value placed # elsewhere, e.g. [4, --, inf]. if np.ma.is_masked(s): - rep = (~np.all(asorted.mask, axis=axis)) & s.mask + rep = (~np.all(asorted.mask, axis=axis, keepdims=True)) & s.mask s.data[rep] = np.ma.minimum_fill_value(asorted) s.mask[rep] = False - replace_masked(low) - replace_masked(high) - - # duplicate high if odd number of elements so mean does nothing - odd = counts % 2 == 1 - np.copyto(low, high, where=odd) - # not necessary for scalar True/False masks - try: - np.copyto(low.mask, high.mask, where=odd) - except Exception: - pass + replace_masked(low_high) if np.issubdtype(asorted.dtype, np.inexact): # avoid inf / x = masked - s = np.ma.sum([low, high], axis=0, out=out) + s = np.ma.sum(low_high, axis=axis, out=out) np.true_divide(s.data, 2., casting='unsafe', out=s.data) s = np.lib.utils._median_nancheck(asorted, s, axis, out) else: - s = np.ma.mean([low, high], axis=0, out=out) + s = np.ma.mean(low_high, axis=axis, out=out) return s |