summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexander Belopolsky <a@enlnt.com>2015-03-22 21:15:59 -0400
committerAlexander Belopolsky <a@enlnt.com>2015-03-22 21:15:59 -0400
commit3fdf1883369c5cd40ad7022ad46a629f2284a7a3 (patch)
tree131ebd2558bed692ff5e0f3eb759a8d12e6e1219
parente3101647ef7c262fa6b4ddc8fdf79453f8a1e05c (diff)
downloadnumpy-3fdf1883369c5cd40ad7022ad46a629f2284a7a3.tar.gz
BUG: Implemented MaskedArray.dot
MaskedArray used to inherit ndarray.dot which ignored masks in the operands. Fixes issue #5185.
-rw-r--r--numpy/ma/core.py16
-rw-r--r--numpy/ma/extras.py8
2 files changed, 17 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 79924351c..964636595 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -4522,6 +4522,22 @@ class MaskedArray(ndarray):
return D.astype(dtype).filled(0).sum(axis=None, out=out)
trace.__doc__ = ndarray.trace.__doc__
+ def dot(self, other, out=None):
+ am = ~getmaskarray(self)
+ bm = ~getmaskarray(other)
+ if out is None:
+ d = np.dot(filled(self, 0), filled(other, 0))
+ m = ~np.dot(am, bm)
+ return masked_array(d, mask=m)
+ d = self.filled(0).dot(other.filled(0), out)
+ if out.mask.shape != d.shape:
+ out._mask = numpy.empty(d.shape, MaskType)
+ np.dot(am, bm, out._mask)
+ np.logical_not(out._mask, out._mask)
+ return out
+ dot.__doc__ = ndarray.dot.__doc__
+
+
def sum(self, axis=None, dtype=None, out=None):
"""
Return the sum of the array elements over the given axis.
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 6d812964d..d389099ae 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1047,13 +1047,7 @@ def dot(a, b, strict=False):
if strict and (a.ndim == 2) and (b.ndim == 2):
a = mask_rows(a)
b = mask_cols(b)
- #
- d = np.dot(filled(a, 0), filled(b, 0))
- #
- am = (~getmaskarray(a))
- bm = (~getmaskarray(b))
- m = ~np.dot(am, bm)
- return masked_array(d, mask=m)
+ return a.dot(b)
#####--------------------------------------------------------------------------
#---- --- arraysetops ---