diff options
author | pierregm <pierregm@localhost> | 2008-11-21 20:49:33 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2008-11-21 20:49:33 +0000 |
commit | 702538f9e45e63e7813034d552a1c8ea4d517513 (patch) | |
tree | d00b8ebabb6d804bc9ad9861eb4b8f5e4c97ce9d | |
parent | 3a1ffccbafc70bd5d45568fb4dd17de2accf92a4 (diff) | |
download | numpy-702538f9e45e63e7813034d552a1c8ea4d517513.tar.gz |
Rewrote allclose to allow comparison with a scalar
-rw-r--r-- | numpy/ma/core.py | 120 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 23 |
2 files changed, 110 insertions, 33 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 78c6fde81..a341af5e0 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -1105,10 +1105,10 @@ def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True): Whether to collapse a mask full of False to nomask """ - abs = umath.absolute + mabs = umath.absolute xnew = filled(x, value) if issubclass(xnew.dtype.type, np.floating): - condition = umath.less_equal(abs(xnew-value), atol+rtol*abs(value)) + condition = umath.less_equal(mabs(xnew-value), atol + rtol*mabs(value)) mask = getattr(x, '_mask', nomask) else: condition = umath.equal(xnew, value) @@ -1446,7 +1446,7 @@ class MaskedArray(ndarray): self.__dict__.update(_optinfo) return #........................ - def __array_finalize__(self,obj): + def __array_finalize__(self, obj): """Finalizes the masked array. """ # Get main attributes ......... @@ -1574,7 +1574,7 @@ class MaskedArray(ndarray): if self._mask is nomask: output._mask = nomask else: - output._mask = self._mask.astype([(n,bool) for n in names]) + output._mask = self._mask.astype([(n, bool) for n in names]) # Don't check _fill_value if it's None, that'll speed things up if self._fill_value is not None: output._fill_value = _check_fill_value(self._fill_value, newtype) @@ -1685,7 +1685,7 @@ class MaskedArray(ndarray): ndarray.__setitem__(_mask, indx, mval) elif hasattr(indx, 'dtype') and (indx.dtype==MaskType): indx = indx * umath.logical_not(_mask) - ndarray.__setitem__(_data,indx,dval) + ndarray.__setitem__(_data, indx, dval) else: if nbfields: err_msg = "Flexible 'hard' masks are not yet supported..." @@ -1716,7 +1716,7 @@ class MaskedArray(ndarray): those locations. """ - self.__setitem__(slice(i,j), value) + self.__setitem__(slice(i, j), value) #............................................ def __setmask__(self, mask, copy=False): """Set the mask. @@ -2220,12 +2220,14 @@ masked_%(name)s(data = %(data)s, return int(self.item()) #............................................ def get_imag(self): + "Returns the imaginary part." result = self._data.imag.view(type(self)) result.__setmask__(self._mask) return result imag = property(fget=get_imag, doc="Imaginary part.") def get_real(self): + "Returns the real part." result = self._data.real.view(type(self)) result.__setmask__(self._mask) return result @@ -2234,14 +2236,14 @@ masked_%(name)s(data = %(data)s, #............................................ def count(self, axis=None): - """Count the non-masked elements of the array along the given - axis. + """ + Count the non-masked elements of the array along the given axis. Parameters ---------- axis : int, optional - Axis along which to count the non-masked elements. If - not given, all the non masked elements are counted. + Axis along which to count the non-masked elements. If axis is None, + all the non masked elements are counted. Returns ------- @@ -3447,9 +3449,11 @@ masked_%(name)s(data = %(data)s, (self.__class__, self._baseclass, (0,), 'b', ), self.__getstate__()) # - def __deepcopy__(self, memo={}): + def __deepcopy__(self, memo=None): from copy import deepcopy copied = MaskedArray.__new__(type(self), self, copy=True) + if memo is None: + memo = {} memo[id(self)] = copied for (k,v) in self.__dict__.iteritems(): copied.__dict__[k] = deepcopy(v, memo) @@ -3687,16 +3691,16 @@ def power(a, b, third=None): fa = getdata(a) fb = getdata(b) # Get the type of the result (so that we preserve subclasses) - if isinstance(a,MaskedArray): + if isinstance(a, MaskedArray): basetype = type(a) else: basetype = MaskedArray # Get the result and view it as a (subclass of) MaskedArray - result = umath.power(fa,fb).view(basetype) + result = umath.power(fa, fb).view(basetype) # Find where we're in trouble w/ NaNs and Infs invalid = np.logical_not(np.isfinite(result.view(ndarray))) # Retrieve some extra attributes if needed - if isinstance(result,MaskedArray): + if isinstance(result, MaskedArray): result._update_from(a) # Add the initial mask if m is not nomask: @@ -3770,7 +3774,7 @@ def sort(a, axis=-1, kind='quicksort', order=None, endwith=True, fill_value=None filler = fill_value # return indx = np.indices(a.shape).tolist() - indx[axis] = filled(a,filler).argsort(axis=axis,kind=kind,order=order) + indx[axis] = filled(a, filler).argsort(axis=axis, kind=kind, order=order) return a[indx] sort.__doc__ = MaskedArray.sort.__doc__ @@ -3820,7 +3824,7 @@ def count(a, axis = None): count.__doc__ = MaskedArray.count.__doc__ -def expand_dims(x,axis): +def expand_dims(x, axis): """ Expand the shape of the array by including a new axis before the given one. @@ -4160,24 +4164,76 @@ def allequal (a, b, fill_value=True): else: return False -def allclose (a, b, fill_value=True, rtol=1.e-5, atol=1.e-8): - """ Return True if all elements of a and b are equal subject to - given tolerances. +def allclose (a, b, masked_equal=True, rtol=1.e-5, atol=1.e-8, fill_value=None): + """ + Returns True if two arrays are element-wise equal within a tolerance. + + The tolerance values are positive, typically very small numbers. The + relative difference (`rtol` * `b`) and the absolute difference (`atol`) + are added together to compare against the absolute difference between `a` + and `b`. + + Parameters + ---------- + a, b : array_like + Input arrays to compare. + fill_value : boolean, optional + Whether masked values in a or b are considered equal (True) or not + (False). + + rtol : Relative tolerance + The relative difference is equal to `rtol` * `b`. + atol : Absolute tolerance + The absolute difference is equal to `atol`. + + Returns + ------- + y : bool + Returns True if the two arrays are equal within the given + tolerance; False otherwise. If either array contains NaN, then + False is returned. - If fill_value is True, masked values are considered equal. - If fill_value is False, masked values considered unequal. - The relative error rtol should be positive and << 1.0 - The absolute error atol comes into play for those elements of b - that are very small or zero; it says how small `a` must be also. + See Also + -------- + all, any, alltrue, sometrue + + Notes + ----- + If the following equation is element-wise True, then allclose returns + True. + + absolute(`a` - `b`) <= (`atol` + `rtol` * absolute(`b`)) + + Return True if all elements of a and b are equal subject to + given tolerances. """ - m = mask_or(getmask(a), getmask(b)) - d1 = getdata(a) - d2 = getdata(b) - x = filled(array(d1, copy=0, mask=m), fill_value).astype(float) - y = filled(array(d2, copy=0, mask=m), 1).astype(float) - d = umath.less_equal(umath.absolute(x-y), atol + rtol * umath.absolute(y)) - return np.alltrue(np.ravel(d)) + if fill_value is not None: + warnings.warn("The use of fill_value is deprecated."\ + " Please use masked_equal instead.") + masked_equal = fill_value + # + x = masked_array(a, copy=False) + y = masked_array(b, copy=False) + m = mask_or(getmask(x), getmask(y)) + xinf = np.isinf(masked_array(x, copy=False, mask=m)).filled(False) + # If we have some infs, they should fall at the same place. + if not np.all(xinf == filled(np.isinf(y), False)): + return False + # No infs at all + if not np.any(xinf): + d = filled(umath.less_equal(umath.absolute(x-y), + atol + rtol * umath.absolute(y)), + masked_equal) + return np.all(d) + if not np.all(filled(x[xinf] == y[xinf], masked_equal)): + return False + x = x[~xinf] + y = y[~xinf] + d = filled(umath.less_equal(umath.absolute(x-y), + atol + rtol * umath.absolute(y)), + masked_equal) + return np.all(d) #.............................................................................. def asarray(a, dtype=None): @@ -4225,7 +4281,7 @@ def asanyarray(a, dtype=None): #####-------------------------------------------------------------------------- #---- --- Pickling --- #####-------------------------------------------------------------------------- -def dump(a,F): +def dump(a, F): """ Pickle the MaskedArray `a` to the file `F`. `F` can either be the handle of an exiting file, or a string representing a file diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 606b3c285..df601433c 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1298,6 +1298,27 @@ class TestMaskedArrayMethods(TestCase): assert_equal(m.transpose(), m._data.transpose()) + def test_allclose(self): + "Tests allclose on arrays" + a = np.random.rand(10) + b = a + np.random.rand(10) * 1e-8 + self.failUnless(allclose(a,b)) + # Test allclose w/ infs + a[0] = np.inf + self.failUnless(not allclose(a,b)) + b[0] = np.inf + self.failUnless(allclose(a,b)) + # Test all close w/ masked + a = masked_array(a) + a[-1] = masked + self.failUnless(allclose(a,b, masked_equal=True)) + self.failUnless(not allclose(a, b, masked_equal=False)) + # Test comparison w/ scalar + a *= 1e-8 + a[0] = 0 + self.failUnless(allclose(a, 0, masked_equal=True)) + + def test_allany(self): """Checks the any/all methods/functions.""" x = np.array([[ 0.13, 0.26, 0.90], @@ -1467,7 +1488,7 @@ class TestMaskedArrayMethods(TestCase): def test_empty(self): "Tests empty/like" - datatype = [('a',int_),('b',float),('c','|S8')] + datatype = [('a',int),('b',float),('c','|S8')] a = masked_array([(1,1.1,'1.1'),(2,2.2,'2.2'),(3,3.3,'3.3')], dtype=datatype) assert_equal(len(a.fill_value.item()), len(datatype)) |