summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/tests/test_shape_base.py9
-rw-r--r--numpy/matrixlib/tests/test_defmatrix.py8
2 files changed, 17 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index c95894f94..6d24dd624 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -293,6 +293,15 @@ class TestExpandDims(object):
assert_warns(DeprecationWarning, expand_dims, a, -6)
assert_warns(DeprecationWarning, expand_dims, a, 5)
+ def test_subclasses(self):
+ a = np.arange(10).reshape((2, 5))
+ a = np.ma.array(a, mask=a%3 == 0)
+
+ expanded = np.expand_dims(a, axis=1)
+ assert_(isinstance(expanded, np.ma.MaskedArray))
+ assert_equal(expanded.shape, (2, 1, 5))
+ assert_equal(expanded.mask.shape, (2, 1, 5))
+
class TestArraySplit(object):
def test_integer_0_split(self):
diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py
index e74e83cdb..272cd8d52 100644
--- a/numpy/matrixlib/tests/test_defmatrix.py
+++ b/numpy/matrixlib/tests/test_defmatrix.py
@@ -466,3 +466,11 @@ class TestShape(object):
def test_matrix_memory_sharing(self):
assert_(np.may_share_memory(self.m, self.m.ravel()))
assert_(not np.may_share_memory(self.m, self.m.flatten()))
+
+ def test_expand_dims_matrix(self):
+ # matrices are always 2d - so expand_dims only makes sense when the
+ # type is changed away from matrix.
+ a = np.arange(10).reshape((2, 5)).view(np.matrix)
+ expanded = np.expand_dims(a, axis=1)
+ assert_equal(expanded.ndim, 3)
+ assert_(not isinstance(expanded, np.matrix))