summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-04-01 21:13:43 -0400
committerCharles Harris <charlesr.harris@gmail.com>2015-04-01 21:13:43 -0400
commit8b6effadd7836f7e80f0f3e7dd9dd43d20ad1590 (patch)
tree0ddea7c067f4bf0630bb5736ab562942beea7e3b
parent799a4c7e66b5e6ab6e9d48b29e386c39c991955e (diff)
parentbf8d3329d43bf534e45cb8182a6d712138566cdc (diff)
downloadnumpy-8b6effadd7836f7e80f0f3e7dd9dd43d20ad1590.tar.gz
Merge pull request #5709 from abalkin/issue-5185
BUG: Implemented MaskedArray.dot
-rw-r--r--numpy/ma/core.py20
-rw-r--r--numpy/ma/extras.py8
-rw-r--r--numpy/ma/tests/test_core.py24
3 files changed, 45 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 79924351c..51e9f0f28 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -4522,6 +4522,26 @@ class MaskedArray(ndarray):
return D.astype(dtype).filled(0).sum(axis=None, out=out)
trace.__doc__ = ndarray.trace.__doc__
+ def dot(self, other, out=None):
+ am = ~getmaskarray(self)
+ bm = ~getmaskarray(other)
+ if out is None:
+ d = np.dot(filled(self, 0), filled(other, 0))
+ m = ~np.dot(am, bm)
+ if d.ndim == 0:
+ d = np.asarray(d)
+ r = d.view(get_masked_subclass(self, other))
+ r.__setmask__(m)
+ return r
+ d = self.filled(0).dot(other.filled(0), out._data)
+ if out.mask.shape != d.shape:
+ out._mask = numpy.empty(d.shape, MaskType)
+ np.dot(am, bm, out._mask)
+ np.logical_not(out._mask, out._mask)
+ return out
+ dot.__doc__ = ndarray.dot.__doc__
+
+
def sum(self, axis=None, dtype=None, out=None):
"""
Return the sum of the array elements over the given axis.
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 6d812964d..d389099ae 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1047,13 +1047,7 @@ def dot(a, b, strict=False):
if strict and (a.ndim == 2) and (b.ndim == 2):
a = mask_rows(a)
b = mask_cols(b)
- #
- d = np.dot(filled(a, 0), filled(b, 0))
- #
- am = (~getmaskarray(a))
- bm = (~getmaskarray(b))
- m = ~np.dot(am, bm)
- return masked_array(d, mask=m)
+ return a.dot(b)
#####--------------------------------------------------------------------------
#---- --- arraysetops ---
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 1d4462306..807fc0ba6 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -2982,6 +2982,30 @@ class TestMaskedArrayMathMethods(TestCase):
X.trace() - sum(mXdiag.mask * X.diagonal(),
axis=0))
+ def test_dot(self):
+ # Tests dot on MaskedArrays.
+ (x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d
+ fx = mx.filled(0)
+ r = mx.dot(mx)
+ assert_almost_equal(r.filled(0), fx.dot(fx))
+ assert_(r.mask is nomask)
+
+ fX = mX.filled(0)
+ r = mX.dot(mX)
+ assert_almost_equal(r.filled(0), fX.dot(fX))
+ assert_(r.mask[1,3])
+ r1 = empty_like(r)
+ mX.dot(mX, r1)
+ assert_almost_equal(r, r1)
+
+ mYY = mXX.swapaxes(-1, -2)
+ fXX, fYY = mXX.filled(0), mYY.filled(0)
+ r = mXX.dot(mYY)
+ assert_almost_equal(r.filled(0), fXX.dot(fYY))
+ r1 = empty_like(r)
+ mXX.dot(mYY, r1)
+ assert_almost_equal(r, r1)
+
def test_varstd(self):
# Tests var & std on MaskedArrays.
(x, X, XX, m, mx, mX, mXX, m2x, m2X, m2XX) = self.d