From 54b68dda8b29e3745b6b0da5f1d6a2b53f29e291 Mon Sep 17 00:00:00 2001 From: Joseph Fox-Rabinovitz Date: Mon, 11 Jul 2016 10:39:23 -0400 Subject: BUG: Fixed masked array behavior for scalar inputs to np.ma.atleast_*d --- numpy/ma/extras.py | 142 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 99 insertions(+), 43 deletions(-) (limited to 'numpy/ma/extras.py') diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index cbf7b6cdb..aefdd4e49 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -45,9 +45,8 @@ def issequence(seq): Is seq a sequence (ndarray, list or tuple)? """ - if isinstance(seq, (ndarray, tuple, list)): - return True - return False + return isinstance(seq, (ndarray, tuple, list)) + def count_masked(arr, axis=None): """ @@ -103,6 +102,7 @@ def count_masked(arr, axis=None): m = getmaskarray(arr) return m.sum(axis) + def masked_all(shape, dtype=float): """ Empty masked array with all elements masked. @@ -154,6 +154,7 @@ def masked_all(shape, dtype=float): mask=np.ones(shape, make_mask_descr(dtype))) return a + def masked_all_like(arr): """ Empty masked array with the properties of an existing array. @@ -221,6 +222,9 @@ class _fromnxfunction: as the wrapped NumPy function. The docstring of `newfunc` is adapted from the wrapped function as well, see `getdoc`. + This class should not be used directly. Instead, one of its extensions that + provides support for a specific type of input should be used. + Parameters ---------- funcname : str @@ -261,48 +265,100 @@ class _fromnxfunction: return def __call__(self, *args, **params): + pass + + +class _fromnxfunction_single(_fromnxfunction): + """ + A version of `_fromnxfunction` that is called with a single array + argument followed by auxiliary args that are passed verbatim for + both the data and mask calls. + """ + def __call__(self, x, *args, **params): func = getattr(np, self.__name__) - if len(args) == 1: - x = args[0] - if isinstance(x, ndarray): - _d = func(x.__array__(), **params) - _m = func(getmaskarray(x), **params) - return masked_array(_d, mask=_m) - elif isinstance(x, tuple) or isinstance(x, list): - _d = func(tuple([np.asarray(a) for a in x]), **params) - _m = func(tuple([getmaskarray(a) for a in x]), **params) - return masked_array(_d, mask=_m) - else: - _d = func(np.asarray(x), **params) - _m = func(getmaskarray(x), **params) - return masked_array(_d, mask=_m) + if isinstance(x, ndarray): + _d = func(x.__array__(), *args, **params) + _m = func(getmaskarray(x), *args, **params) + return masked_array(_d, mask=_m) else: - arrays = [] - args = list(args) - while len(args) > 0 and issequence(args[0]): - arrays.append(args.pop(0)) - res = [] - for x in arrays: - _d = func(np.asarray(x), *args, **params) - _m = func(getmaskarray(x), *args, **params) - res.append(masked_array(_d, mask=_m)) - return res - -atleast_1d = _fromnxfunction('atleast_1d') -atleast_2d = _fromnxfunction('atleast_2d') -atleast_3d = _fromnxfunction('atleast_3d') -#atleast_1d = np.atleast_1d -#atleast_2d = np.atleast_2d -#atleast_3d = np.atleast_3d - -vstack = row_stack = _fromnxfunction('vstack') -hstack = _fromnxfunction('hstack') -column_stack = _fromnxfunction('column_stack') -dstack = _fromnxfunction('dstack') - -hsplit = _fromnxfunction('hsplit') - -diagflat = _fromnxfunction('diagflat') + _d = func(np.asarray(x), *args, **params) + _m = func(getmaskarray(x), *args, **params) + return masked_array(_d, mask=_m) + + +class _fromnxfunction_seq(_fromnxfunction): + """ + A version of `_fromnxfunction` that is called with a single sequence + of arrays followed by auxiliary args that are passed verbatim for + both the data and mask calls. + """ + def __call__(self, x, *args, **params): + func = getattr(np, self.__name__) + _d = func(tuple([np.asarray(a) for a in x]), *args, **params) + _m = func(tuple([getmaskarray(a) for a in x]), *args, **params) + return masked_array(_d, mask=_m) + + +class _fromnxfunction_args(_fromnxfunction): + """ + A version of `_fromnxfunction` that is called with multiple array + arguments. The first non-array-like input marks the beginning of the + arguments that are passed verbatim for both the data and mask calls. + Array arguments are processed independently and the results are + returned in a list. If only one array is found, the return value is + just the processed array instead of a list. + """ + def __call__(self, *args, **params): + func = getattr(np, self.__name__) + arrays = [] + args = list(args) + while len(args) > 0 and issequence(args[0]): + arrays.append(args.pop(0)) + res = [] + for x in arrays: + _d = func(np.asarray(x), *args, **params) + _m = func(getmaskarray(x), *args, **params) + res.append(masked_array(_d, mask=_m)) + if len(arrays) == 1: + return res[0] + return res + + +class _fromnxfunction_allargs(_fromnxfunction): + """ + A version of `_fromnxfunction` that is called with multiple array + arguments. Similar to `_fromnxfunction_args` except that all args + are converted to arrays even if they are not so already. This makes + it possible to process scalars as 1-D arrays. Only keyword arguments + are passed through verbatim for the data and mask calls. Arrays + arguments are processed independently and the results are returned + in a list. If only one arg is present, the return value is just the + processed array instead of a list. + """ + def __call__(self, *args, **params): + func = getattr(np, self.__name__) + res = [] + for x in args: + _d = func(np.asarray(x), **params) + _m = func(getmaskarray(x), **params) + res.append(masked_array(_d, mask=_m)) + if len(args) == 1: + return res[0] + return res + + +atleast_1d = _fromnxfunction_allargs('atleast_1d') +atleast_2d = _fromnxfunction_allargs('atleast_2d') +atleast_3d = _fromnxfunction_allargs('atleast_3d') + +vstack = row_stack = _fromnxfunction_seq('vstack') +hstack = _fromnxfunction_seq('hstack') +column_stack = _fromnxfunction_seq('column_stack') +dstack = _fromnxfunction_seq('dstack') + +hsplit = _fromnxfunction_single('hsplit') + +diagflat = _fromnxfunction_single('diagflat') #####-------------------------------------------------------------------------- -- cgit v1.2.1