diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-06-04 13:52:27 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-06-21 12:16:10 +0100 |
commit | bbe994f9a860ce41addfd785cf91918ce1562ba6 (patch) | |
tree | cf67d2609791cf5e70e6021a83cacec14c107ef0 /numpy/lib/nanfunctions.py | |
parent | 4d9f4d5d5e00076be64c6d43e691df8273f749ef (diff) | |
download | numpy-bbe994f9a860ce41addfd785cf91918ce1562ba6.tar.gz |
MAINT: Factor out code duplicated by nanmedian and nanpercentile
Diffstat (limited to 'numpy/lib/nanfunctions.py')
-rw-r--r-- | numpy/lib/nanfunctions.py | 96 |
1 files changed, 53 insertions, 43 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index 97bd33492..791218886 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -106,6 +106,46 @@ def _copyto(a, val, mask): return a +def _remove_nan_1d(arr1d, overwrite_input=False): + """ + Equivalent to arr1d[~arr1d.isnan()], but in a different order + + Presumably faster as it incurs fewer copies + + Parameters + ---------- + arr1d : ndarray + Array to remove nans from + overwrite_input : bool + True if `arr1d` can be modified in place + + Returns + ------- + res : ndarray + Array with nan elements removed + overwrite_input : bool + True if `res` can be modified in place, given the constraint on the + input + """ + + c = np.isnan(arr1d) + s = np.nonzero(c)[0] + if s.size == arr1d.size: + warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=4) + return arr1d[:0], True + elif s.size == 0: + return arr1d, overwrite_input + else: + if not overwrite_input: + arr1d = arr1d.copy() + # select non-nans at end of array + enonan = arr1d[-s.size:][~c[-s.size:]] + # fill nans in beginning of array with non-nans of end + arr1d[s[:enonan.size]] = enonan + + return arr1d[:-s.size], True + + def _divide_by_count(a, b, out=None): """ Compute a/b ignoring invalid results. If `a` is an array the division @@ -836,24 +876,12 @@ def _nanmedian1d(arr1d, overwrite_input=False): Private function for rank 1 arrays. Compute the median ignoring NaNs. See nanmedian for parameter usage """ - c = np.isnan(arr1d) - s = np.nonzero(c)[0] - if s.size == arr1d.size: - warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=3) + arr1d, overwrite_input = _remove_nan_1d(arr1d, + overwrite_input=overwrite_input) + if arr1d.size == 0: return np.nan - elif s.size == 0: - return np.median(arr1d, overwrite_input=overwrite_input) - else: - if overwrite_input: - x = arr1d - else: - x = arr1d.copy() - # select non-nans at end of array - enonan = arr1d[-s.size:][~c[-s.size:]] - # fill nans in beginning of array with non-nans of end - x[s[:enonan.size]] = enonan - # slice nans away - return np.median(x[:-s.size], overwrite_input=True) + + return np.median(arr1d, overwrite_input=overwrite_input) def _nanmedian(a, axis=None, out=None, overwrite_input=False): @@ -1155,34 +1183,16 @@ def _nanpercentile(a, q, axis=None, out=None, overwrite_input=False, def _nanpercentile1d(arr1d, q, overwrite_input=False, interpolation='linear'): """ - Private function for rank 1 arrays. Compute percentile ignoring - NaNs. - + Private function for rank 1 arrays. Compute percentile ignoring NaNs. See nanpercentile for parameter usage """ - c = np.isnan(arr1d) - s = np.nonzero(c)[0] - if s.size == arr1d.size: - warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=3) - if q.ndim == 0: - return np.nan - else: - return np.nan * np.ones((len(q),)) - elif s.size == 0: - return np.percentile(arr1d, q, overwrite_input=overwrite_input, - interpolation=interpolation) - else: - if overwrite_input: - x = arr1d - else: - x = arr1d.copy() - # select non-nans at end of array - enonan = arr1d[-s.size:][~c[-s.size:]] - # fill nans in beginning of array with non-nans of end - x[s[:enonan.size]] = enonan - # slice nans away - return np.percentile(x[:-s.size], q, overwrite_input=True, - interpolation=interpolation) + arr1d, overwrite_input = _remove_nan_1d(arr1d, + overwrite_input=overwrite_input) + if arr1d.size == 0: + return np.full(q.shape, np.nan)[()] # convert to scalar + + return np.percentile(arr1d, q, overwrite_input=overwrite_input, + interpolation=interpolation) def nanvar(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue): |