summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-07-18 00:42:15 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-09-05 00:49:10 -0700
commit9aea5a42bd8b74ffd2472d1e6c7761daca868a4a (patch)
tree2d017f32dfce98d38014c1fa9f163c2578d1d9d4 /numpy/ma
parente1ccca947a297deb90301acc14b1e13e8dd0b816 (diff)
downloadnumpy-9aea5a42bd8b74ffd2472d1e6c7761daca868a4a.tar.gz
BUG: Make transpose and diagonal masks be views
Fixes gh-8506 and fixes gh-7483
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py3
-rw-r--r--numpy/ma/tests/test_core.py25
2 files changed, 24 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index deddacfdc..0035cd9f7 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -2558,7 +2558,8 @@ def _arraymethod(funcname, onmask=True):
if not onmask:
result.__setmask__(mask)
elif mask is not nomask:
- result.__setmask__(getattr(mask, funcname)(*args, **params))
+ # __setmask__ makes a copy, which we don't want
+ result._mask = getattr(mask, funcname)(*args, **params)
return result
methdoc = getattr(ndarray, funcname, None) or getattr(np, funcname, None)
if methdoc is not None:
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index e3c35cf81..b19a138ab 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -3205,9 +3205,7 @@ class TestMaskedArrayMethods(object):
assert_(m_arr_sq is not np.ma.masked)
assert_equal(m_arr_sq.mask, True)
m_arr_sq[...] = 2
- # TODO: mask isn't copied to/from views yet in maskedarray, so we can
- # only check the data
- assert_equal(m_arr.data[0,0], 2)
+ assert_equal(m_arr[0,0], 2)
def test_swapaxes(self):
# Tests swapaxes on MaskedArrays.
@@ -3396,6 +3394,27 @@ class TestMaskedArrayMethods(object):
assert_equal(x.T.mask, x.mask)
assert_equal(x.T.data, x.data)
+ def test_transpose_view(self):
+ x = np.ma.array([[1, 2, 3], [4, 5, 6]])
+ x[0,1] = np.ma.masked
+ xt = x.T
+
+ xt[1,0] = 10
+ xt[0,1] = np.ma.masked
+
+ assert_equal(x.data, xt.T.data)
+ assert_equal(x.mask, xt.T.mask)
+
+ def test_diagonal_view(self):
+ x = np.ma.zeros((3,3))
+ x[0,0] = 10
+ x[1,1] = np.ma.masked
+ x[2,2] = 20
+ xd = x.diagonal()
+ x[1,1] = 15
+ assert_equal(xd.mask, x.diagonal().mask)
+ assert_equal(xd.data, x.diagonal().data)
+
class TestMaskedArrayMathMethods(object):