summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/shape_base.py6
-rw-r--r--numpy/ma/core.py52
2 files changed, 6 insertions, 52 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 65104115a..d31d8a939 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -536,7 +536,11 @@ def expand_dims(a, axis):
True
"""
- a = asarray(a)
+ if isinstance(a, matrix):
+ a = asarray(a)
+ else:
+ a = asanyarray(a)
+
shape = a.shape
if axis > a.ndim or axis < -a.ndim - 1:
# 2017-05-17, 1.13.0
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 5bfa51b12..74edeb274 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -43,7 +43,7 @@ from numpy.lib.function_base import angle
from numpy.compat import (
getargspec, formatargspec, long, basestring, unicode, bytes
)
-from numpy import expand_dims as n_expand_dims
+from numpy import expand_dims
from numpy.core.multiarray import normalize_axis_index
from numpy.core.numeric import normalize_axis_tuple
@@ -6795,56 +6795,6 @@ def diag(v, k=0):
return output
-def expand_dims(x, axis):
- """
- Expand the shape of an array.
-
- Expands the shape of the array by including a new axis before the one
- specified by the `axis` parameter. This function behaves the same as
- `numpy.expand_dims` but preserves masked elements.
-
- See Also
- --------
- numpy.expand_dims : Equivalent function in top-level NumPy module.
-
- Examples
- --------
- >>> import numpy.ma as ma
- >>> x = ma.array([1, 2, 4])
- >>> x[1] = ma.masked
- >>> x
- masked_array(data = [1 -- 4],
- mask = [False True False],
- fill_value = 999999)
- >>> np.expand_dims(x, axis=0)
- array([[1, 2, 4]])
- >>> ma.expand_dims(x, axis=0)
- masked_array(data =
- [[1 -- 4]],
- mask =
- [[False True False]],
- fill_value = 999999)
-
- The same result can be achieved using slicing syntax with `np.newaxis`.
-
- >>> x[np.newaxis, :]
- masked_array(data =
- [[1 -- 4]],
- mask =
- [[False True False]],
- fill_value = 999999)
-
- """
- result = n_expand_dims(x, axis)
- if isinstance(x, MaskedArray):
- new_shape = result.shape
- result = x.view()
- result.shape = new_shape
- if result._mask is not nomask:
- result._mask.shape = new_shape
- return result
-
-
def left_shift(a, n):
"""
Shift the bits of an integer to the left.