diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/shape_base.py | 6 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 9 | ||||
-rw-r--r-- | numpy/ma/core.py | 52 | ||||
-rw-r--r-- | numpy/matrixlib/tests/test_defmatrix.py | 8 |
4 files changed, 23 insertions, 52 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 65104115a..d31d8a939 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -536,7 +536,11 @@ def expand_dims(a, axis): True """ - a = asarray(a) + if isinstance(a, matrix): + a = asarray(a) + else: + a = asanyarray(a) + shape = a.shape if axis > a.ndim or axis < -a.ndim - 1: # 2017-05-17, 1.13.0 diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index c95894f94..6d24dd624 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -293,6 +293,15 @@ class TestExpandDims(object): assert_warns(DeprecationWarning, expand_dims, a, -6) assert_warns(DeprecationWarning, expand_dims, a, 5) + def test_subclasses(self): + a = np.arange(10).reshape((2, 5)) + a = np.ma.array(a, mask=a%3 == 0) + + expanded = np.expand_dims(a, axis=1) + assert_(isinstance(expanded, np.ma.MaskedArray)) + assert_equal(expanded.shape, (2, 1, 5)) + assert_equal(expanded.mask.shape, (2, 1, 5)) + class TestArraySplit(object): def test_integer_0_split(self): diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 5bfa51b12..74edeb274 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -43,7 +43,7 @@ from numpy.lib.function_base import angle from numpy.compat import ( getargspec, formatargspec, long, basestring, unicode, bytes ) -from numpy import expand_dims as n_expand_dims +from numpy import expand_dims from numpy.core.multiarray import normalize_axis_index from numpy.core.numeric import normalize_axis_tuple @@ -6795,56 +6795,6 @@ def diag(v, k=0): return output -def expand_dims(x, axis): - """ - Expand the shape of an array. - - Expands the shape of the array by including a new axis before the one - specified by the `axis` parameter. This function behaves the same as - `numpy.expand_dims` but preserves masked elements. - - See Also - -------- - numpy.expand_dims : Equivalent function in top-level NumPy module. - - Examples - -------- - >>> import numpy.ma as ma - >>> x = ma.array([1, 2, 4]) - >>> x[1] = ma.masked - >>> x - masked_array(data = [1 -- 4], - mask = [False True False], - fill_value = 999999) - >>> np.expand_dims(x, axis=0) - array([[1, 2, 4]]) - >>> ma.expand_dims(x, axis=0) - masked_array(data = - [[1 -- 4]], - mask = - [[False True False]], - fill_value = 999999) - - The same result can be achieved using slicing syntax with `np.newaxis`. - - >>> x[np.newaxis, :] - masked_array(data = - [[1 -- 4]], - mask = - [[False True False]], - fill_value = 999999) - - """ - result = n_expand_dims(x, axis) - if isinstance(x, MaskedArray): - new_shape = result.shape - result = x.view() - result.shape = new_shape - if result._mask is not nomask: - result._mask.shape = new_shape - return result - - def left_shift(a, n): """ Shift the bits of an integer to the left. diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py index e74e83cdb..272cd8d52 100644 --- a/numpy/matrixlib/tests/test_defmatrix.py +++ b/numpy/matrixlib/tests/test_defmatrix.py @@ -466,3 +466,11 @@ class TestShape(object): def test_matrix_memory_sharing(self): assert_(np.may_share_memory(self.m, self.m.ravel())) assert_(not np.may_share_memory(self.m, self.m.flatten())) + + def test_expand_dims_matrix(self): + # matrices are always 2d - so expand_dims only makes sense when the + # type is changed away from matrix. + a = np.arange(10).reshape((2, 5)).view(np.matrix) + expanded = np.expand_dims(a, axis=1) + assert_equal(expanded.ndim, 3) + assert_(not isinstance(expanded, np.matrix)) |