summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
authorJoseph Fox-Rabinovitz <joseph.r.fox-rabinovitz@nasa.gov>2016-07-11 10:39:23 -0400
committerJoseph Fox-Rabinovitz <joseph.r.fox-rabinovitz@nasa.gov>2016-08-24 12:39:56 -0400
commit54b68dda8b29e3745b6b0da5f1d6a2b53f29e291 (patch)
tree4560f76441f5156ca418752cf26459a4b31ca7a3 /numpy/ma/extras.py
parentd92c75b46fa316a7711049c278bfb4d4f11b2349 (diff)
downloadnumpy-54b68dda8b29e3745b6b0da5f1d6a2b53f29e291.tar.gz
BUG: Fixed masked array behavior for scalar inputs to np.ma.atleast_*d
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py142
1 files changed, 99 insertions, 43 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index cbf7b6cdb..aefdd4e49 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -45,9 +45,8 @@ def issequence(seq):
Is seq a sequence (ndarray, list or tuple)?
"""
- if isinstance(seq, (ndarray, tuple, list)):
- return True
- return False
+ return isinstance(seq, (ndarray, tuple, list))
+
def count_masked(arr, axis=None):
"""
@@ -103,6 +102,7 @@ def count_masked(arr, axis=None):
m = getmaskarray(arr)
return m.sum(axis)
+
def masked_all(shape, dtype=float):
"""
Empty masked array with all elements masked.
@@ -154,6 +154,7 @@ def masked_all(shape, dtype=float):
mask=np.ones(shape, make_mask_descr(dtype)))
return a
+
def masked_all_like(arr):
"""
Empty masked array with the properties of an existing array.
@@ -221,6 +222,9 @@ class _fromnxfunction:
as the wrapped NumPy function. The docstring of `newfunc` is adapted from
the wrapped function as well, see `getdoc`.
+ This class should not be used directly. Instead, one of its extensions that
+ provides support for a specific type of input should be used.
+
Parameters
----------
funcname : str
@@ -261,48 +265,100 @@ class _fromnxfunction:
return
def __call__(self, *args, **params):
+ pass
+
+
+class _fromnxfunction_single(_fromnxfunction):
+ """
+ A version of `_fromnxfunction` that is called with a single array
+ argument followed by auxiliary args that are passed verbatim for
+ both the data and mask calls.
+ """
+ def __call__(self, x, *args, **params):
func = getattr(np, self.__name__)
- if len(args) == 1:
- x = args[0]
- if isinstance(x, ndarray):
- _d = func(x.__array__(), **params)
- _m = func(getmaskarray(x), **params)
- return masked_array(_d, mask=_m)
- elif isinstance(x, tuple) or isinstance(x, list):
- _d = func(tuple([np.asarray(a) for a in x]), **params)
- _m = func(tuple([getmaskarray(a) for a in x]), **params)
- return masked_array(_d, mask=_m)
- else:
- _d = func(np.asarray(x), **params)
- _m = func(getmaskarray(x), **params)
- return masked_array(_d, mask=_m)
+ if isinstance(x, ndarray):
+ _d = func(x.__array__(), *args, **params)
+ _m = func(getmaskarray(x), *args, **params)
+ return masked_array(_d, mask=_m)
else:
- arrays = []
- args = list(args)
- while len(args) > 0 and issequence(args[0]):
- arrays.append(args.pop(0))
- res = []
- for x in arrays:
- _d = func(np.asarray(x), *args, **params)
- _m = func(getmaskarray(x), *args, **params)
- res.append(masked_array(_d, mask=_m))
- return res
-
-atleast_1d = _fromnxfunction('atleast_1d')
-atleast_2d = _fromnxfunction('atleast_2d')
-atleast_3d = _fromnxfunction('atleast_3d')
-#atleast_1d = np.atleast_1d
-#atleast_2d = np.atleast_2d
-#atleast_3d = np.atleast_3d
-
-vstack = row_stack = _fromnxfunction('vstack')
-hstack = _fromnxfunction('hstack')
-column_stack = _fromnxfunction('column_stack')
-dstack = _fromnxfunction('dstack')
-
-hsplit = _fromnxfunction('hsplit')
-
-diagflat = _fromnxfunction('diagflat')
+ _d = func(np.asarray(x), *args, **params)
+ _m = func(getmaskarray(x), *args, **params)
+ return masked_array(_d, mask=_m)
+
+
+class _fromnxfunction_seq(_fromnxfunction):
+ """
+ A version of `_fromnxfunction` that is called with a single sequence
+ of arrays followed by auxiliary args that are passed verbatim for
+ both the data and mask calls.
+ """
+ def __call__(self, x, *args, **params):
+ func = getattr(np, self.__name__)
+ _d = func(tuple([np.asarray(a) for a in x]), *args, **params)
+ _m = func(tuple([getmaskarray(a) for a in x]), *args, **params)
+ return masked_array(_d, mask=_m)
+
+
+class _fromnxfunction_args(_fromnxfunction):
+ """
+ A version of `_fromnxfunction` that is called with multiple array
+ arguments. The first non-array-like input marks the beginning of the
+ arguments that are passed verbatim for both the data and mask calls.
+ Array arguments are processed independently and the results are
+ returned in a list. If only one array is found, the return value is
+ just the processed array instead of a list.
+ """
+ def __call__(self, *args, **params):
+ func = getattr(np, self.__name__)
+ arrays = []
+ args = list(args)
+ while len(args) > 0 and issequence(args[0]):
+ arrays.append(args.pop(0))
+ res = []
+ for x in arrays:
+ _d = func(np.asarray(x), *args, **params)
+ _m = func(getmaskarray(x), *args, **params)
+ res.append(masked_array(_d, mask=_m))
+ if len(arrays) == 1:
+ return res[0]
+ return res
+
+
+class _fromnxfunction_allargs(_fromnxfunction):
+ """
+ A version of `_fromnxfunction` that is called with multiple array
+ arguments. Similar to `_fromnxfunction_args` except that all args
+ are converted to arrays even if they are not so already. This makes
+ it possible to process scalars as 1-D arrays. Only keyword arguments
+ are passed through verbatim for the data and mask calls. Arrays
+ arguments are processed independently and the results are returned
+ in a list. If only one arg is present, the return value is just the
+ processed array instead of a list.
+ """
+ def __call__(self, *args, **params):
+ func = getattr(np, self.__name__)
+ res = []
+ for x in args:
+ _d = func(np.asarray(x), **params)
+ _m = func(getmaskarray(x), **params)
+ res.append(masked_array(_d, mask=_m))
+ if len(args) == 1:
+ return res[0]
+ return res
+
+
+atleast_1d = _fromnxfunction_allargs('atleast_1d')
+atleast_2d = _fromnxfunction_allargs('atleast_2d')
+atleast_3d = _fromnxfunction_allargs('atleast_3d')
+
+vstack = row_stack = _fromnxfunction_seq('vstack')
+hstack = _fromnxfunction_seq('hstack')
+column_stack = _fromnxfunction_seq('column_stack')
+dstack = _fromnxfunction_seq('dstack')
+
+hsplit = _fromnxfunction_single('hsplit')
+
+diagflat = _fromnxfunction_single('diagflat')
#####--------------------------------------------------------------------------