summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2014-04-03 16:24:58 -0600
committerCharles Harris <charlesr.harris@gmail.com>2014-04-03 16:24:58 -0600
commit3998fdf2b99d976528e3de399fb536666d588bce (patch)
tree7c82b7a02d4f2cf47657c67e7e46e91c05e6eda2 /numpy/ma/extras.py
parente4082eb30e65276c6acac0b46b9e2d28b3b99948 (diff)
parent76c9bb336f77aa406c62ab5f3077517dec133d30 (diff)
downloadnumpy-3998fdf2b99d976528e3de399fb536666d588bce.tar.gz
Merge pull request #4463 from abalkin/issue-4461
BUG: Masked arrays and apply_over_axes
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py43
1 files changed, 35 insertions, 8 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 7182bf4ae..f53d9c7e5 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -416,26 +416,53 @@ def apply_over_axes(func, a, axes):
"""
(This docstring will be overwritten)
"""
- val = np.asarray(a)
- msk = getmaskarray(a)
+ val = asarray(a)
N = a.ndim
if array(axes).ndim == 0:
axes = (axes,)
for axis in axes:
if axis < 0: axis = N + axis
args = (val, axis)
- res = ma.array(func(*(val, axis)), mask=func(*(msk, axis)))
+ res = func(*args)
if res.ndim == val.ndim:
- (val, msk) = (res._data, res._mask)
+ val = res
else:
res = ma.expand_dims(res, axis)
if res.ndim == val.ndim:
- (val, msk) = (res._data, res._mask)
+ val = res
else:
- raise ValueError("Function is not returning"\
- " an array of correct shape")
+ raise ValueError("function is not returning "
+ "an array of the correct shape")
return val
-apply_over_axes.__doc__ = np.apply_over_axes.__doc__
+apply_over_axes.__doc__ = np.apply_over_axes.__doc__[
+ :np.apply_over_axes.__doc__.find('Notes')].rstrip() + \
+ """
+
+ Examples
+ --------
+ >>> a = ma.arange(24).reshape(2,3,4)
+ >>> a[:,0,1] = ma.masked
+ >>> a[:,1,:] = ma.masked
+ >>> print a
+ [[[0 -- 2 3]
+ [-- -- -- --]
+ [8 9 10 11]]
+
+ [[12 -- 14 15]
+ [-- -- -- --]
+ [20 21 22 23]]]
+ >>> print ma.apply_over_axes(ma.sum, a, [0,2])
+ [[[46]
+ [--]
+ [124]]]
+
+ Tuple axis arguments to ufuncs are equivalent:
+
+ >>> print ma.sum(a, axis=(0,2)).reshape((1,-1,1))
+ [[[46]
+ [--]
+ [124]]]
+"""
def average(a, axis=None, weights=None, returned=False):