summaryrefslogtreecommitdiff
path: root/numpy/ma/tests/test_extras.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-11-07 12:58:06 -0700
committerCharles Harris <charlesr.harris@gmail.com>2015-11-10 17:45:50 -0700
commit3e82108f701b0ce6cbb9e16f5d7fd4c3cb27a97c (patch)
tree5fb9ba02477cad389e8d6ec74e4be48b9aa687c5 /numpy/ma/tests/test_extras.py
parentcf9f1907b99d06291ab16ad4d2105a871f56f7d9 (diff)
downloadnumpy-3e82108f701b0ce6cbb9e16f5d7fd4c3cb27a97c.tar.gz
TST: Add tests for ma.dot.
Test that ma.dot always returns a masked array. Test basic that the new out parameter in ma.dot works.
Diffstat (limited to 'numpy/ma/tests/test_extras.py')
-rw-r--r--numpy/ma/tests/test_extras.py54
1 files changed, 35 insertions, 19 deletions
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index c41c629fc..6138d0573 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -538,26 +538,26 @@ class TestCompressFunctions(TestCase):
m = [1, 0, 0, 0, 0, 0]
a = masked_array(n, mask=m).reshape(2, 3)
b = masked_array(n, mask=m).reshape(3, 2)
- c = dot(a, b, True)
+ c = dot(a, b, strict=True)
assert_equal(c.mask, [[1, 1], [1, 0]])
- c = dot(b, a, True)
+ c = dot(b, a, strict=True)
assert_equal(c.mask, [[1, 1, 1], [1, 0, 0], [1, 0, 0]])
- c = dot(a, b, False)
+ c = dot(a, b, strict=False)
assert_equal(c, np.dot(a.filled(0), b.filled(0)))
- c = dot(b, a, False)
+ c = dot(b, a, strict=False)
assert_equal(c, np.dot(b.filled(0), a.filled(0)))
#
m = [0, 0, 0, 0, 0, 1]
a = masked_array(n, mask=m).reshape(2, 3)
b = masked_array(n, mask=m).reshape(3, 2)
- c = dot(a, b, True)
+ c = dot(a, b, strict=True)
assert_equal(c.mask, [[0, 1], [1, 1]])
- c = dot(b, a, True)
+ c = dot(b, a, strict=True)
assert_equal(c.mask, [[0, 0, 1], [0, 0, 1], [1, 1, 1]])
- c = dot(a, b, False)
+ c = dot(a, b, strict=False)
assert_equal(c, np.dot(a.filled(0), b.filled(0)))
assert_equal(c, dot(a, b))
- c = dot(b, a, False)
+ c = dot(b, a, strict=False)
assert_equal(c, np.dot(b.filled(0), a.filled(0)))
#
m = [0, 0, 0, 0, 0, 0]
@@ -570,37 +570,53 @@ class TestCompressFunctions(TestCase):
#
a = masked_array(n, mask=[1, 0, 0, 0, 0, 0]).reshape(2, 3)
b = masked_array(n, mask=[0, 0, 0, 0, 0, 0]).reshape(3, 2)
- c = dot(a, b, True)
+ c = dot(a, b, strict=True)
assert_equal(c.mask, [[1, 1], [0, 0]])
- c = dot(a, b, False)
+ c = dot(a, b, strict=False)
assert_equal(c, np.dot(a.filled(0), b.filled(0)))
- c = dot(b, a, True)
+ c = dot(b, a, strict=True)
assert_equal(c.mask, [[1, 0, 0], [1, 0, 0], [1, 0, 0]])
- c = dot(b, a, False)
+ c = dot(b, a, strict=False)
assert_equal(c, np.dot(b.filled(0), a.filled(0)))
#
a = masked_array(n, mask=[0, 0, 0, 0, 0, 1]).reshape(2, 3)
b = masked_array(n, mask=[0, 0, 0, 0, 0, 0]).reshape(3, 2)
- c = dot(a, b, True)
+ c = dot(a, b, strict=True)
assert_equal(c.mask, [[0, 0], [1, 1]])
c = dot(a, b)
assert_equal(c, np.dot(a.filled(0), b.filled(0)))
- c = dot(b, a, True)
+ c = dot(b, a, strict=True)
assert_equal(c.mask, [[0, 0, 1], [0, 0, 1], [0, 0, 1]])
- c = dot(b, a, False)
+ c = dot(b, a, strict=False)
assert_equal(c, np.dot(b.filled(0), a.filled(0)))
#
a = masked_array(n, mask=[0, 0, 0, 0, 0, 1]).reshape(2, 3)
b = masked_array(n, mask=[0, 0, 1, 0, 0, 0]).reshape(3, 2)
- c = dot(a, b, True)
+ c = dot(a, b, strict=True)
assert_equal(c.mask, [[1, 0], [1, 1]])
- c = dot(a, b, False)
+ c = dot(a, b, strict=False)
assert_equal(c, np.dot(a.filled(0), b.filled(0)))
- c = dot(b, a, True)
+ c = dot(b, a, strict=True)
assert_equal(c.mask, [[0, 0, 1], [1, 1, 1], [0, 0, 1]])
- c = dot(b, a, False)
+ c = dot(b, a, strict=False)
assert_equal(c, np.dot(b.filled(0), a.filled(0)))
+ def test_dot_returns_maskedarray(self):
+ # See gh-6611
+ a = np.eye(3)
+ b = array(a)
+ assert_(type(dot(a, a)) is MaskedArray)
+ assert_(type(dot(a, b)) is MaskedArray)
+ assert_(type(dot(b, a)) is MaskedArray)
+ assert_(type(dot(b, b)) is MaskedArray)
+
+ def test_dot_out(self):
+ a = array(np.eye(3))
+ out = array(np.zeros((3, 3)))
+ res = dot(a, a, out=out)
+ assert_(res is out)
+ assert_equal(a, res)
+
class TestApplyAlongAxis(TestCase):
# Tests 2D functions