summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py96
1 files changed, 66 insertions, 30 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index d14812093..82a61a67c 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -416,26 +416,53 @@ def apply_over_axes(func, a, axes):
"""
(This docstring will be overwritten)
"""
- val = np.asarray(a)
- msk = getmaskarray(a)
+ val = asarray(a)
N = a.ndim
if array(axes).ndim == 0:
axes = (axes,)
for axis in axes:
if axis < 0: axis = N + axis
args = (val, axis)
- res = ma.array(func(*(val, axis)), mask=func(*(msk, axis)))
+ res = func(*args)
if res.ndim == val.ndim:
- (val, msk) = (res._data, res._mask)
+ val = res
else:
res = ma.expand_dims(res, axis)
if res.ndim == val.ndim:
- (val, msk) = (res._data, res._mask)
+ val = res
else:
- raise ValueError("Function is not returning"\
- " an array of correct shape")
+ raise ValueError("function is not returning "
+ "an array of the correct shape")
return val
-apply_over_axes.__doc__ = np.apply_over_axes.__doc__
+apply_over_axes.__doc__ = np.apply_over_axes.__doc__[
+ :np.apply_over_axes.__doc__.find('Notes')].rstrip() + \
+ """
+
+ Examples
+ --------
+ >>> a = ma.arange(24).reshape(2,3,4)
+ >>> a[:,0,1] = ma.masked
+ >>> a[:,1,:] = ma.masked
+ >>> print a
+ [[[0 -- 2 3]
+ [-- -- -- --]
+ [8 9 10 11]]
+
+ [[12 -- 14 15]
+ [-- -- -- --]
+ [20 21 22 23]]]
+ >>> print ma.apply_over_axes(ma.sum, a, [0,2])
+ [[[46]
+ [--]
+ [124]]]
+
+ Tuple axis arguments to ufuncs are equivalent:
+
+ >>> print ma.sum(a, axis=(0,2)).reshape((1,-1,1))
+ [[[46]
+ [--]
+ [124]]]
+"""
def average(a, axis=None, weights=None, returned=False):
@@ -448,8 +475,8 @@ def average(a, axis=None, weights=None, returned=False):
Data to be averaged.
Masked entries are not taken into account in the computation.
axis : int, optional
- Axis along which the variance is computed. The default is to compute
- the variance of the flattened array.
+ Axis along which the average is computed. The default is to compute
+ the average of the flattened array.
weights : array_like, optional
The importance that each element has in the computation of the average.
The weights array can either be 1-D (in which case its length must be
@@ -540,7 +567,7 @@ def average(a, axis=None, weights=None, returned=False):
else:
if weights is None:
n = add.reduce(a, axis)
- d = umath.add.reduce((-mask), axis=axis, dtype=float)
+ d = umath.add.reduce((~mask), axis=axis, dtype=float)
else:
w = filled(weights, 0.0)
wsh = w.shape
@@ -641,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False):
fill_value = 1e+20)
"""
- def _median1D(data):
- counts = filled(count(data), 0)
- (idx, rmd) = divmod(counts, 2)
- if rmd:
- choice = slice(idx, idx + 1)
- else:
- choice = slice(idx - 1, idx + 1)
- return data[choice].mean(0)
- #
+ if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0:
+ return masked_array(np.median(a, axis=axis, out=out,
+ overwrite_input=overwrite_input), copy=False)
if overwrite_input:
if axis is None:
asorted = a.ravel()
@@ -660,14 +681,29 @@ def median(a, axis=None, out=None, overwrite_input=False):
else:
asorted = sort(a, axis=axis)
if axis is None:
- result = _median1D(asorted)
+ axis = 0
+ elif axis < 0:
+ axis += a.ndim
+
+ counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
+ h = counts // 2
+ # create indexing mesh grid for all but reduced axis
+ axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
+ if i != axis]
+ ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
+ # insert indices of low and high median
+ ind.insert(axis, h - 1)
+ low = asorted[ind]
+ ind[axis] = h
+ high = asorted[ind]
+ # duplicate high if odd number of elements so mean does nothing
+ odd = counts % 2 == 1
+ if asorted.ndim == 1:
+ if odd:
+ low = high
else:
- result = apply_along_axis(_median1D, axis, asorted)
- if out is not None:
- out = result
- return result
-
-
+ low[odd] = high[odd]
+ return np.ma.mean([low, high], axis=0, out=out)
#..............................................................................
@@ -842,9 +878,9 @@ def mask_rowcols(a, axis=None):
fill_value=999999)
"""
- a = asarray(a)
+ a = array(a, subok=False)
if a.ndim != 2:
- raise NotImplementedError("compress2d works for 2D arrays only.")
+ raise NotImplementedError("mask_rowcols works for 2D arrays only.")
m = getmask(a)
# Nothing is masked: return a
if m is nomask or not m.any():
@@ -1735,7 +1771,7 @@ def _ezclump(mask):
#def clump_masked(a):
if mask.ndim > 1:
mask = mask.ravel()
- idx = (mask[1:] - mask[:-1]).nonzero()
+ idx = (mask[1:] ^ mask[:-1]).nonzero()
idx = idx[0] + 1
slices = [slice(left, right)
for (left, right) in zip(itertools.chain([0], idx),