diff options
author | pierregm <pierregm@localhost> | 2008-05-12 15:48:27 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2008-05-12 15:48:27 +0000 |
commit | c2b77e022e8393f057b5b54dc595f35b01a26809 (patch) | |
tree | 23516b0dd2905dfa604ae23a49a3c1106d3227d8 /numpy | |
parent | d6cbaddf167d3c99d4a0aca91fab241e2ae67f90 (diff) | |
download | numpy-c2b77e022e8393f057b5b54dc595f35b01a26809.tar.gz |
core : power : use the quick-and-dirty approach: compute everything and mask afterwards
: MaskedArray._update_from(obj) : ensure that _baseclass is a ndarray if obj wasn't one already
: introduced clip in the namespace, just for convenience
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 51 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 6 |
2 files changed, 42 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 6a6f78f03..b1a06fc8d 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -26,8 +26,8 @@ __all__ = ['MAError', 'MaskType', 'MaskedArray', 'arctanh', 'argmax', 'argmin', 'argsort', 'around', 'array', 'asarray','asanyarray', 'bitwise_and', 'bitwise_or', 'bitwise_xor', - 'ceil', 'choose', 'common_fill_value', 'compress', 'compressed', - 'concatenate', 'conjugate', 'cos', 'cosh', 'count', + 'ceil', 'choose', 'clip', 'common_fill_value', 'compress', + 'compressed', 'concatenate', 'conjugate', 'cos', 'cosh', 'count', 'default_fill_value', 'diagonal', 'divide', 'dump', 'dumps', 'empty', 'empty_like', 'equal', 'exp', 'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid', @@ -1226,7 +1226,7 @@ class MaskedArray(numeric.ndarray): def _update_from(self, obj): """Copies some attributes of obj to self. """ - if obj is not None: + if obj is not None and isinstance(obj,ndarray): _baseclass = type(obj) else: _baseclass = ndarray @@ -2845,23 +2845,45 @@ def power(a, b, third=None): """ if third is not None: raise MAError, "3-argument power not supported." + # Get the masks ma = getmask(a) mb = getmask(b) m = mask_or(ma, mb) + # Get the rawdata fa = getdata(a) fb = getdata(b) - if fb.dtype.char in typecodes["Integer"]: - return masked_array(umath.power(fa, fb), m) - m = mask_or(m, (fa < 0) & (fb != fb.astype(int))) - if m is nomask: - return masked_array(umath.power(fa, fb)) + # Get the type of the result (so that we preserve subclasses) + if isinstance(a,MaskedArray): + basetype = type(a) else: - fa = fa.copy() - if m.all(): - fa.flat = 1 - else: - numpy.putmask(fa,m,1) - return masked_array(umath.power(fa, fb), m) + basetype = MaskedArray + # Get the result and view it as a (subclass of) MaskedArray + result = umath.power(fa,fb).view(basetype) + # Retrieve some extra attributes if needed + result._update_from(a) + # Find where we're in trouble w/ NaNs and Infs + invalid = numpy.logical_not(numpy.isfinite(result.view(ndarray))) + # Add the initial mask + if m is not nomask: + result._mask = m + # Fix the invalid parts + if invalid.any(): + result[invalid] = masked + result._data[invalid] = result.fill_value + return result + +# if fb.dtype.char in typecodes["Integer"]: +# return masked_array(umath.power(fa, fb), m) +# m = mask_or(m, (fa < 0) & (fb != fb.astype(int))) +# if m is nomask: +# return masked_array(umath.power(fa, fb)) +# else: +# fa = fa.copy() +# if m.all(): +# fa.flat = 1 +# else: +# numpy.putmask(fa,m,1) +# return masked_array(umath.power(fa, fb), m) #.............................................................................. def argsort(a, axis=None, kind='quicksort', order=None, fill_value=None): @@ -3373,6 +3395,7 @@ frombuffer = _convert2ma('frombuffer') fromfunction = _convert2ma('fromfunction') identity = _convert2ma('identity') indices = numpy.indices +clip = numpy.clip ############################################################################### diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index c1282a5ee..b04c1c2ab 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1571,7 +1571,11 @@ class TestMiscFunctions(NumpyTestCase): b = array([0.5,2.,0.5,2.,1.], mask=[0,0,0,0,1]) y = power(x,b) assert_almost_equal(y, [0, 1.21, 1.04880884817, 1.21, 0.] ) - assert_equal(y._mask, [1,0,0,0,1]) + assert_equal(y._mask, [1,0,0,0,1]) + b.mask = nomask + y = power(x,b) + assert_equal(y._mask, [1,0,0,0,1]) + ############################################################################### |