summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/ma/core.py241
-rw-r--r--numpy/ma/tests/test_subclassing.py28
2 files changed, 156 insertions, 113 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 5ef5cd93d..5cf11ffb9 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -452,7 +452,7 @@ def getdata(a, subok=True):
Input ``MaskedArray``, alternatively a ndarray or a subclass thereof.
subok : bool
Whether to force the output to be a `pure` ndarray (False) or to
- return a subclass of ndarray if approriate (True - default).
+ return a subclass of ndarray if approriate (True, default).
See Also
--------
@@ -483,7 +483,10 @@ def getdata(a, subok=True):
[3, 4]])
"""
- data = getattr(a, '_data', np.array(a, subok=subok))
+ try:
+ data = a._data
+ except AttributeError:
+ data = np.array(a, copy=False, subok=subok)
if not subok:
return data.view(ndarray)
return data
@@ -549,7 +552,9 @@ class _DomainCheckInterval:
"Execute the call behavior."
return umath.logical_or(umath.greater (x, self.b),
umath.less(x, self.a))
-#............................
+
+
+
class _DomainTan:
"""Define a valid interval for the `tan` function, so that:
@@ -559,14 +564,18 @@ class _DomainTan:
def __init__(self, eps):
"domain_tan(eps) = true where abs(cos(x)) < eps)"
self.eps = eps
+
def __call__ (self, x):
"Executes the call behavior."
return umath.less(umath.absolute(umath.cos(x)), self.eps)
-#............................
+
+
+
class _DomainSafeDivide:
"""Define a domain for safe division."""
def __init__ (self, tolerance=None):
self.tolerance = tolerance
+
def __call__ (self, a, b):
# Delay the selection of the tolerance to here in order to reduce numpy
# import times. The calculation of these parameters is a substantial
@@ -574,7 +583,9 @@ class _DomainSafeDivide:
if self.tolerance is None:
self.tolerance = np.finfo(float).tiny
return umath.absolute(a) * self.tolerance >= umath.absolute(b)
-#............................
+
+
+
class _DomainGreater:
"DomainGreater(v)(x) = true where x <= v"
def __init__(self, critical_value):
@@ -584,7 +595,9 @@ class _DomainGreater:
def __call__ (self, x):
"Executes the call behavior."
return umath.less_equal(x, self.critical_value)
-#............................
+
+
+
class _DomainGreaterEqual:
"DomainGreaterEqual(v)(x) = true where x < v"
def __init__(self, critical_value):
@@ -597,8 +610,9 @@ class _DomainGreaterEqual:
#..............................................................................
class _MaskedUnaryOperation:
- """Defines masked version of unary operations, where invalid
- values are pre-masked.
+ """
+ Defines masked version of unary operations, where invalid values are
+ pre-masked.
Parameters
----------
@@ -625,41 +639,54 @@ class _MaskedUnaryOperation:
#
def __call__ (self, a, *args, **kwargs):
"Execute the call behavior."
- #
- m = getmask(a)
- d1 = getdata(a)
- #
+ d = getdata(a)
+ # Case 1.1. : Domained function
if self.domain is not None:
- dm = np.array(self.domain(d1), copy=False)
- m = np.logical_or(m, dm)
- # The following two lines control the domain filling methods.
- d1 = d1.copy()
- # We could use smart indexing : d1[dm] = self.fill ...
- # ... but np.putmask looks more efficient, despite the copy.
- np.putmask(d1, dm, self.fill)
- # Take care of the masked singletong first ...
- if (not m.ndim) and m:
- return masked
- elif m is nomask:
- result = self.f(d1, *args, **kwargs)
+ # Save the error status
+ err_status_ini = np.geterr()
+ np.seterr(divide='ignore', invalid='ignore')
+ # Get the result
+ result = self.f(d, *args, **kwargs)
+ # Reset the error status
+ np.seterr(**err_status_ini)
+ # Make a mask
+ m = ~umath.isfinite(result)
+ m |= getmask(a)
+ # Case 1.2. : Function without a domain
else:
- result = np.where(m, d1, self.f(d1, *args, **kwargs))
- # If result is not a scalar
- if result.ndim:
- # Get the result subclass:
- if isinstance(a, MaskedArray):
- subtype = type(a)
- else:
- subtype = MaskedArray
- result = result.view(subtype)
- result._mask = m
- result._update_from(a)
+ # Get the result and the mask
+ result = self.f(d, *args, **kwargs)
+ m = getmask(a)
+ # Case 2.1. : The result is scalarscalar
+ if not result.ndim:
+ if m:
+ return masked
+ return result
+ # Case 2.2. The result is an array
+ # We need to fill the invalid data back w/ the input
+ # Now, that's plain silly: in C, we would just skip the element and keep
+ # the original, but we do have to do it that way in Python
+ if m is not nomask:
+ # In case result has a lower dtype than the inputs (as in equal)
+ try:
+ np.putmask(result, m, d)
+ except TypeError:
+ pass
+ # Transform to
+ if isinstance(a, MaskedArray):
+ subtype = type(a)
+ else:
+ subtype = MaskedArray
+ result = result.view(subtype)
+ result._mask = m
+ result._update_from(a)
return result
#
def __str__ (self):
return "Masked version of %s. [Invalid values are masked]" % str(self.f)
-#..............................................................................
+
+
class _MaskedBinaryOperation:
"""Define masked version of binary operations, where invalid
values are pre-masked.
@@ -689,50 +716,44 @@ class _MaskedBinaryOperation:
def __call__ (self, a, b, *args, **kwargs):
"Execute the call behavior."
- m = mask_or(getmask(a), getmask(b), shrink=False)
- (da, db) = (getdata(a), getdata(b))
- # Easy case: there's no mask...
- if m is nomask:
- result = self.f(da, db, *args, **kwargs)
- # There are some masked elements: run only on the unmasked
+ # Get the data, as ndarray
+ (da, db) = (getdata(a, subok=False), getdata(b, subok=False))
+ # Get the mask
+ (ma, mb) = (getmask(a), getmask(b))
+ if ma is nomask:
+ if mb is nomask:
+ m = nomask
+ else:
+ m = umath.logical_or(getmaskarray(a), mb)
+ elif mb is nomask:
+ m = umath.logical_or(ma, getmaskarray(b))
else:
- result = np.where(m, da, self.f(da, db, *args, **kwargs))
- # Transforms to a (subclass of) MaskedArray if we don't have a scalar
- if result.shape:
- result = result.view(get_masked_subclass(a, b))
- # If we have a mask, make sure it's broadcasted properly
- if m.any():
- result._mask = mask_or(getmaskarray(a), getmaskarray(b))
- # If some initial masks where not shrunk, don't shrink the result
- elif m.shape:
- result._mask = make_mask_none(result.shape, result.dtype)
+ m = umath.logical_or(ma, mb)
+ # Get the result
+ result = self.f(da, db, *args, **kwargs)
+ # Case 1. : scalar
+ if not result.ndim:
+ if m:
+ return masked
+ return result
+ # Case 2. : array
+ # Revert result to da where masked
+ if m.any():
+ np.putmask(result, m, 0)
+ result += m*da
+ # Transforms to a (subclass of) MaskedArray
+ result = result.view(get_masked_subclass(a, b))
+ result._mask = m
+ # Update the optional info from the inputs
+ if isinstance(b, MaskedArray):
if isinstance(a, MaskedArray):
result._update_from(a)
- if isinstance(b, MaskedArray):
+ else:
result._update_from(b)
- # ... or return masked if we have a scalar and the common mask is True
- elif m:
- return masked
+ elif isinstance(a, MaskedArray):
+ result._update_from(a)
return result
-#
-# result = self.f(d1, d2, *args, **kwargs).view(get_masked_subclass(a, b))
-# if len(result.shape):
-# if m is not nomask:
-# result._mask = make_mask_none(result.shape)
-# result._mask.flat = m
-# #!!!!!
-# # Force m to be at least 1D
-# m.shape = m.shape or (1,)
-# print "Resetting data"
-# result.data[m].flat = d1.flat
-# #!!!!!
-# if isinstance(a, MaskedArray):
-# result._update_from(a)
-# if isinstance(b, MaskedArray):
-# result._update_from(b)
-# elif m:
-# return masked
-# return result
+
def reduce(self, target, axis=0, dtype=None):
"""Reduce `target` along the given `axis`."""
@@ -776,10 +797,9 @@ class _MaskedBinaryOperation:
if (not m.ndim) and m:
return masked
(da, db) = (getdata(a), getdata(b))
- if m is nomask:
- d = self.f.outer(da, db)
- else:
- d = np.where(m, da, self.f.outer(da, db))
+ d = self.f.outer(da, db)
+ if m is not nomask:
+ np.putmask(d, m, da)
if d.shape:
d = d.view(get_masked_subclass(a, b))
d._mask = m
@@ -800,7 +820,8 @@ class _MaskedBinaryOperation:
def __str__ (self):
return "Masked version of " + str(self.f)
-#..............................................................................
+
+
class _DomainedBinaryOperation:
"""
Define binary operations that have a domain, like divide.
@@ -830,38 +851,38 @@ class _DomainedBinaryOperation:
def __call__(self, a, b, *args, **kwargs):
"Execute the call behavior."
- ma = getmask(a)
- mb = getmaskarray(b)
- da = getdata(a)
- db = getdata(b)
- t = narray(self.domain(da, db), copy=False)
- if t.any(None):
- mb = mask_or(mb, t, shrink=False)
- # The following line controls the domain filling
- if t.size == db.size:
- db = np.where(t, self.filly, db)
+ # Get the data and the mask
+ (da, db) = (getdata(a, subok=False), getdata(b, subok=False))
+ (ma, mb) = (getmask(a), getmask(b))
+ # Save the current error status
+ err_status_ini = np.geterr()
+ np.seterr(divide='ignore', invalid='ignore')
+ # Get the result
+ result = self.f(da, db, *args, **kwargs)
+ # Reset the error status
+ np.seterr(**err_status_ini)
+ # Get the mask as a combination of ma, mb and invalid
+ m = ~umath.isfinite(result)
+ m |= ma
+ m |= mb
+ # Take care of the scalar case first
+ if (not m.ndim):
+ if m:
+ return masked
else:
- db = np.where(np.resize(t, db.shape), self.filly, db)
- # Shrink m if a.mask was nomask, otherwise don't.
- m = mask_or(ma, mb, shrink=(getattr(a, '_mask', nomask) is nomask))
- if (not m.ndim) and m:
- return masked
- elif (m is nomask):
- result = self.f(da, db, *args, **kwargs)
- else:
- result = np.where(m, da, self.f(da, db, *args, **kwargs))
- if result.shape:
- result = result.view(get_masked_subclass(a, b))
- # If we have a mask, make sure it's broadcasted properly
- if m.any():
- result._mask = mask_or(getmaskarray(a), mb)
- # If some initial masks where not shrunk, don't shrink the result
- elif m.shape:
- result._mask = make_mask_none(result.shape, result.dtype)
+ return result
+ # When the mask is True, put back da
+ np.putmask(result, m, 0)
+ result += m*da
+ result = result.view(get_masked_subclass(a, b))
+ result._mask = m
+ if isinstance(b, MaskedArray):
if isinstance(a, MaskedArray):
result._update_from(a)
- if isinstance(b, MaskedArray):
+ else:
result._update_from(b)
+ elif isinstance(a, MaskedArray):
+ result._update_from(a)
return result
def __str__ (self):
@@ -2637,8 +2658,8 @@ class MaskedArray(ndarray):
#........................................
# ndgetattr = ndarray.__getattribute__
_data = self._data
- _dtype = ndarray.__getattribute__(_data,'dtype')
- _mask = ndarray.__getattribute__(self,'_mask')
+ _dtype = ndarray.__getattribute__(_data, 'dtype')
+ _mask = ndarray.__getattribute__(self, '_mask')
nbfields = len(_dtype.names or ())
#........................................
if value is masked:
diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py
index 5943ad6c1..b732cf845 100644
--- a/numpy/ma/tests/test_subclassing.py
+++ b/numpy/ma/tests/test_subclassing.py
@@ -70,6 +70,11 @@ mmatrix = MMatrix
class TestSubclassing(TestCase):
"""Test suite for masked subclasses of ndarray."""
+ def setUp(self):
+ x = np.arange(5)
+ mx = mmatrix(x, mask=[0, 1, 0, 0, 0])
+ self.data = (x, mx)
+
def test_data_subclassing(self):
"Tests whether the subclass is kept."
x = np.arange(5)
@@ -82,19 +87,36 @@ class TestSubclassing(TestCase):
def test_maskedarray_subclassing(self):
"Tests subclassing MaskedArray"
- x = np.arange(5)
- mx = mmatrix(x,mask=[0,1,0,0,0])
+ (x, mx) = self.data
self.failUnless(isinstance(mx._data, np.matrix))
+
+ def test_masked_unary_operations(self):
"Tests masked_unary_operation"
+ (x, mx) = self.data
+ self.failUnless(isinstance(log(mx), mmatrix))
+ assert_equal(log(x), np.log(x))
+
+ def test_masked_binary_operations(self):
+ "Tests masked_binary_operation"
+ (x, mx) = self.data
+ # Result should be a mmatrix
self.failUnless(isinstance(add(mx,mx), mmatrix))
self.failUnless(isinstance(add(mx,x), mmatrix))
+ # Result should work
assert_equal(add(mx,x), mx+x)
self.failUnless(isinstance(add(mx,mx)._data, np.matrix))
self.failUnless(isinstance(add.outer(mx,mx), mmatrix))
- "Tests masked_binary_operation"
self.failUnless(isinstance(hypot(mx,mx), mmatrix))
self.failUnless(isinstance(hypot(mx,x), mmatrix))
+ def test_masked_binary_operations(self):
+ "Tests domained_masked_binary_operation"
+ (x, mx) = self.data
+ xmx = masked_array(mx.data.__array__(), mask=mx.mask)
+ self.failUnless(isinstance(divide(mx,mx), mmatrix))
+ self.failUnless(isinstance(divide(mx,x), mmatrix))
+ assert_equal(divide(mx, mx), divide(xmx, xmx))
+
def test_attributepropagation(self):
x = array(arange(5), mask=[0]+[1]*4)
my = masked_array(subarray(x))