diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2015-04-01 21:13:43 -0400 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-04-01 21:13:43 -0400 |
commit | 8b6effadd7836f7e80f0f3e7dd9dd43d20ad1590 (patch) | |
tree | 0ddea7c067f4bf0630bb5736ab562942beea7e3b | |
parent | 799a4c7e66b5e6ab6e9d48b29e386c39c991955e (diff) | |
parent | bf8d3329d43bf534e45cb8182a6d712138566cdc (diff) | |
download | numpy-8b6effadd7836f7e80f0f3e7dd9dd43d20ad1590.tar.gz |
Merge pull request #5709 from abalkin/issue-5185
BUG: Implemented MaskedArray.dot
-rw-r--r-- | numpy/ma/core.py | 20 | ||||
-rw-r--r-- | numpy/ma/extras.py | 8 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 24 |
3 files changed, 45 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 79924351c..51e9f0f28 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -4522,6 +4522,26 @@ class MaskedArray(ndarray): return D.astype(dtype).filled(0).sum(axis=None, out=out) trace.__doc__ = ndarray.trace.__doc__ + def dot(self, other, out=None): + am = ~getmaskarray(self) + bm = ~getmaskarray(other) + if out is None: + d = np.dot(filled(self, 0), filled(other, 0)) + m = ~np.dot(am, bm) + if d.ndim == 0: + d = np.asarray(d) + r = d.view(get_masked_subclass(self, other)) + r.__setmask__(m) + return r + d = self.filled(0).dot(other.filled(0), out._data) + if out.mask.shape != d.shape: + out._mask = numpy.empty(d.shape, MaskType) + np.dot(am, bm, out._mask) + np.logical_not(out._mask, out._mask) + return out + dot.__doc__ = ndarray.dot.__doc__ + + def sum(self, axis=None, dtype=None, out=None): """ Return the sum of the array elements over the given axis. diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 6d812964d..d389099ae 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -1047,13 +1047,7 @@ def dot(a, b, strict=False): if strict and (a.ndim == 2) and (b.ndim == 2): a = mask_rows(a) b = mask_cols(b) - # - d = np.dot(filled(a, 0), filled(b, 0)) - # - am = (~getmaskarray(a)) - bm = (~getmaskarray(b)) - m = ~np.dot(am, bm) - return masked_array(d, mask=m) + return a.dot(b) #####-------------------------------------------------------------------------- #---- --- arraysetops --- diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 1d4462306..807fc0ba6 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -2982,6 +2982,30 @@ class TestMaskedArrayMathMethods(TestCase): X.trace() - sum(mXdiag.mask * X.diagonal(), axis=0)) + def test_dot(self): + # Tests dot on MaskedArrays. + (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d + fx = mx.filled(0) + r = mx.dot(mx) + assert_almost_equal(r.filled(0), fx.dot(fx)) + assert_(r.mask is nomask) + + fX = mX.filled(0) + r = mX.dot(mX) + assert_almost_equal(r.filled(0), fX.dot(fX)) + assert_(r.mask[1,3]) + r1 = empty_like(r) + mX.dot(mX, r1) + assert_almost_equal(r, r1) + + mYY = mXX.swapaxes(-1, -2) + fXX, fYY = mXX.filled(0), mYY.filled(0) + r = mXX.dot(mYY) + assert_almost_equal(r.filled(0), fXX.dot(fYY)) + r1 = empty_like(r) + mXX.dot(mYY, r1) + assert_almost_equal(r, r1) + def test_varstd(self): # Tests var & std on MaskedArrays. (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d |