summaryrefslogtreecommitdiff
path: root/numpy/ma/tests/test_extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/tests/test_extras.py')
-rw-r--r--numpy/ma/tests/test_extras.py35
1 files changed, 30 insertions, 5 deletions
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index bb59aad96..33c4b1922 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -1202,7 +1202,7 @@ class TestArraySetOps(TestCase):
class TestShapeBase(TestCase):
- def test_atleast2d(self):
+ def test_atleast_2d(self):
# Test atleast_2d
a = masked_array([0, 1, 2], mask=[0, 1, 0])
b = atleast_2d(a)
@@ -1210,21 +1210,46 @@ class TestShapeBase(TestCase):
assert_equal(b.mask.shape, b.data.shape)
assert_equal(a.shape, (3,))
assert_equal(a.mask.shape, a.data.shape)
+ assert_equal(b.mask.shape, b.data.shape)
def test_shape_scalar(self):
# the atleast and diagflat function should work with scalars
# GitHub issue #3367
+ # Additionally, the atleast functions should accept multiple scalars
+ # correctly
b = atleast_1d(1.0)
- assert_equal(b.shape, (1, ))
- assert_equal(b.mask.shape, b.data.shape)
+ assert_equal(b.shape, (1,))
+ assert_equal(b.mask.shape, b.shape)
+ assert_equal(b.data.shape, b.shape)
+
+ b = atleast_1d(1.0, 2.0)
+ for a in b:
+ assert_equal(a.shape, (1,))
+ assert_equal(a.mask.shape, a.shape)
+ assert_equal(a.data.shape, a.shape)
b = atleast_2d(1.0)
assert_equal(b.shape, (1, 1))
- assert_equal(b.mask.shape, b.data.shape)
+ assert_equal(b.mask.shape, b.shape)
+ assert_equal(b.data.shape, b.shape)
+
+ b = atleast_2d(1.0, 2.0)
+ for a in b:
+ assert_equal(a.shape, (1, 1))
+ assert_equal(a.mask.shape, a.shape)
+ assert_equal(a.data.shape, a.shape)
b = atleast_3d(1.0)
assert_equal(b.shape, (1, 1, 1))
- assert_equal(b.mask.shape, b.data.shape)
+ assert_equal(b.mask.shape, b.shape)
+ assert_equal(b.data.shape, b.shape)
+
+ b = atleast_3d(1.0, 2.0)
+ for a in b:
+ assert_equal(a.shape, (1, 1, 1))
+ assert_equal(a.mask.shape, a.shape)
+ assert_equal(a.data.shape, a.shape)
+
b = diagflat(1.0)
assert_equal(b.shape, (1, 1))