summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2010-09-13 15:43:27 +0000
committerpierregm <pierregm@localhost>2010-09-13 15:43:27 +0000
commit96afea0b32c77fa60112a1a78d66af77912e2523 (patch)
treee8dadd935c387ee5d07b200ae9adc7be4de50336 /numpy/ma
parent7213c5d804412b1ab6f23c6419ba836865af517a (diff)
downloadnumpy-96afea0b32c77fa60112a1a78d66af77912e2523.tar.gz
* ma.core._print_templates: switched the keys 'short' and 'long' to 'short_std' and 'long_std' respectively (bug #1586)
* Fixed incorrect broadcasting in ma.power (bug #1606)
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py13
-rw-r--r--numpy/ma/tests/test_core.py33
2 files changed, 39 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index de7485638..a945789df 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -2296,14 +2296,14 @@ def _recursive_printoption(result, mask, printopt):
np.putmask(curdata, curmask, printopt)
return
-_print_templates = dict(long="""\
+_print_templates = dict(long_std="""\
masked_%(name)s(data =
%(data)s,
%(nlen)s mask =
%(mask)s,
%(nlen)s fill_value = %(fill)s)
""",
- short="""\
+ short_std="""\
masked_%(name)s(data = %(data)s,
%(nlen)s mask = %(mask)s,
%(nlen)s fill_value = %(fill)s)
@@ -3574,8 +3574,8 @@ class MaskedArray(ndarray):
return _print_templates['short_flx'] % parameters
return _print_templates['long_flx'] % parameters
elif n <= 1:
- return _print_templates['short'] % parameters
- return _print_templates['long'] % parameters
+ return _print_templates['short_std'] % parameters
+ return _print_templates['long_std'] % parameters
def __eq__(self, other):
@@ -5972,7 +5972,7 @@ harden_mask = _frommethod('harden_mask')
ids = _frommethod('ids')
maximum = _maximum_operation()
mean = _frommethod('mean')
-minimum = _minimum_operation ()
+minimum = _minimum_operation()
nonzero = _frommethod('nonzero')
prod = _frommethod('prod')
product = _frommethod('prod')
@@ -6040,8 +6040,7 @@ def power(a, b, third=None):
if m is not nomask:
if not (result.ndim):
return masked
- m |= invalid
- result._mask = m
+ result._mask = np.logical_or(m, invalid)
# Fix the invalid parts
if invalid.any():
if not result.ndim:
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 908d7adc6..f907a3447 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -2975,6 +2975,39 @@ class TestMaskedArrayFunctions(TestCase):
assert_almost_equal(x, y)
assert_almost_equal(x._data, y._data)
+ def test_power_w_broadcasting(self):
+ "Test power w/ broadcasting"
+ a2 = np.array([[1., 2., 3.], [4., 5., 6.]])
+ a2m = array(a2, mask=[[1, 0, 0], [0, 0, 1]])
+ b1 = np.array([2, 4, 3])
+ b1m = array(b1, mask=[0, 1, 0])
+ b2 = np.array([b1, b1])
+ b2m = array(b2, mask=[[0, 1, 0], [0, 1, 0]])
+ #
+ ctrl = array([[1 ** 2, 2 ** 4, 3 ** 3], [4 ** 2, 5 ** 4, 6 ** 3]],
+ mask=[[1, 1, 0], [0, 1, 1]])
+ # No broadcasting, base & exp w/ mask
+ test = a2m ** b2m
+ assert_equal(test, ctrl)
+ assert_equal(test.mask, ctrl.mask)
+ # No broadcasting, base w/ mask, exp w/o mask
+ test = a2m ** b2
+ assert_equal(test, ctrl)
+ assert_equal(test.mask, a2m.mask)
+ # No broadcasting, base w/o mask, exp w/ mask
+ test = a2 ** b2m
+ assert_equal(test, ctrl)
+ assert_equal(test.mask, b2m.mask)
+ #
+ ctrl = array([[2 ** 2, 4 ** 4, 3 ** 3], [2 ** 2, 4 ** 4, 3 ** 3]],
+ mask=[[0, 1, 0], [0, 1, 0]])
+ test = b1 ** b2m
+ assert_equal(test, ctrl)
+ assert_equal(test.mask, ctrl.mask)
+ test = b2m ** b1
+ assert_equal(test, ctrl)
+ assert_equal(test.mask, ctrl.mask)
+
def test_where(self):
"Test the where function"