summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2008-05-12 15:48:27 +0000
committerpierregm <pierregm@localhost>2008-05-12 15:48:27 +0000
commitc2b77e022e8393f057b5b54dc595f35b01a26809 (patch)
tree23516b0dd2905dfa604ae23a49a3c1106d3227d8 /numpy
parentd6cbaddf167d3c99d4a0aca91fab241e2ae67f90 (diff)
downloadnumpy-c2b77e022e8393f057b5b54dc595f35b01a26809.tar.gz
core : power : use the quick-and-dirty approach: compute everything and mask afterwards
: MaskedArray._update_from(obj) : ensure that _baseclass is a ndarray if obj wasn't one already : introduced clip in the namespace, just for convenience
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py51
-rw-r--r--numpy/ma/tests/test_core.py6
2 files changed, 42 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 6a6f78f03..b1a06fc8d 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -26,8 +26,8 @@ __all__ = ['MAError', 'MaskType', 'MaskedArray',
'arctanh', 'argmax', 'argmin', 'argsort', 'around',
'array', 'asarray','asanyarray',
'bitwise_and', 'bitwise_or', 'bitwise_xor',
- 'ceil', 'choose', 'common_fill_value', 'compress', 'compressed',
- 'concatenate', 'conjugate', 'cos', 'cosh', 'count',
+ 'ceil', 'choose', 'clip', 'common_fill_value', 'compress',
+ 'compressed', 'concatenate', 'conjugate', 'cos', 'cosh', 'count',
'default_fill_value', 'diagonal', 'divide', 'dump', 'dumps',
'empty', 'empty_like', 'equal', 'exp',
'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid',
@@ -1226,7 +1226,7 @@ class MaskedArray(numeric.ndarray):
def _update_from(self, obj):
"""Copies some attributes of obj to self.
"""
- if obj is not None:
+ if obj is not None and isinstance(obj,ndarray):
_baseclass = type(obj)
else:
_baseclass = ndarray
@@ -2845,23 +2845,45 @@ def power(a, b, third=None):
"""
if third is not None:
raise MAError, "3-argument power not supported."
+ # Get the masks
ma = getmask(a)
mb = getmask(b)
m = mask_or(ma, mb)
+ # Get the rawdata
fa = getdata(a)
fb = getdata(b)
- if fb.dtype.char in typecodes["Integer"]:
- return masked_array(umath.power(fa, fb), m)
- m = mask_or(m, (fa < 0) & (fb != fb.astype(int)))
- if m is nomask:
- return masked_array(umath.power(fa, fb))
+ # Get the type of the result (so that we preserve subclasses)
+ if isinstance(a,MaskedArray):
+ basetype = type(a)
else:
- fa = fa.copy()
- if m.all():
- fa.flat = 1
- else:
- numpy.putmask(fa,m,1)
- return masked_array(umath.power(fa, fb), m)
+ basetype = MaskedArray
+ # Get the result and view it as a (subclass of) MaskedArray
+ result = umath.power(fa,fb).view(basetype)
+ # Retrieve some extra attributes if needed
+ result._update_from(a)
+ # Find where we're in trouble w/ NaNs and Infs
+ invalid = numpy.logical_not(numpy.isfinite(result.view(ndarray)))
+ # Add the initial mask
+ if m is not nomask:
+ result._mask = m
+ # Fix the invalid parts
+ if invalid.any():
+ result[invalid] = masked
+ result._data[invalid] = result.fill_value
+ return result
+
+# if fb.dtype.char in typecodes["Integer"]:
+# return masked_array(umath.power(fa, fb), m)
+# m = mask_or(m, (fa < 0) & (fb != fb.astype(int)))
+# if m is nomask:
+# return masked_array(umath.power(fa, fb))
+# else:
+# fa = fa.copy()
+# if m.all():
+# fa.flat = 1
+# else:
+# numpy.putmask(fa,m,1)
+# return masked_array(umath.power(fa, fb), m)
#..............................................................................
def argsort(a, axis=None, kind='quicksort', order=None, fill_value=None):
@@ -3373,6 +3395,7 @@ frombuffer = _convert2ma('frombuffer')
fromfunction = _convert2ma('fromfunction')
identity = _convert2ma('identity')
indices = numpy.indices
+clip = numpy.clip
###############################################################################
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index c1282a5ee..b04c1c2ab 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1571,7 +1571,11 @@ class TestMiscFunctions(NumpyTestCase):
b = array([0.5,2.,0.5,2.,1.], mask=[0,0,0,0,1])
y = power(x,b)
assert_almost_equal(y, [0, 1.21, 1.04880884817, 1.21, 0.] )
- assert_equal(y._mask, [1,0,0,0,1])
+ assert_equal(y._mask, [1,0,0,0,1])
+ b.mask = nomask
+ y = power(x,b)
+ assert_equal(y._mask, [1,0,0,0,1])
+
###############################################################################