diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 42 |
1 files changed, 26 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index fb37801b9..d9a539f18 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -4284,8 +4284,17 @@ def percentile(a, q, axis=None, out=None, >>> assert not np.all(a == b) """ - q = array(q, dtype=np.float64, copy=True) - r, k = _ureduce(a, func=_percentile, q=q, axis=axis, out=out, + q = np.true_divide(q, 100.0) # handles the asarray for us too + if not _quantile_is_valid(q): + raise ValueError("Percentiles must be in the range [0, 100]") + return _quantile_unchecked( + a, q, axis, out, overwrite_input, interpolation, keepdims) + + +def _quantile_unchecked(a, q, axis=None, out=None, overwrite_input=False, + interpolation='linear', keepdims=False): + """Assumes that q is in [0, 1], and is an ndarray""" + r, k = _ureduce(a, func=_quantile_ureduce_func, q=q, axis=axis, out=out, overwrite_input=overwrite_input, interpolation=interpolation) if keepdims: @@ -4294,8 +4303,21 @@ def percentile(a, q, axis=None, out=None, return r -def _percentile(a, q, axis=None, out=None, - overwrite_input=False, interpolation='linear', keepdims=False): +def _quantile_is_valid(q): + # avoid expensive reductions, relevant for arrays with < O(1000) elements + if q.ndim == 1 and q.size < 10: + for i in range(q.size): + if q[i] < 0.0 or q[i] > 1.0: + return False + else: + # faster than any() + if np.count_nonzero(q < 0.0) or np.count_nonzero(q > 1.0): + return False + return True + + +def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False, + interpolation='linear', keepdims=False): a = asarray(a) if q.ndim == 0: # Do not allow 0-d arrays because following code fails for scalar @@ -4304,18 +4326,6 @@ def _percentile(a, q, axis=None, out=None, else: zerod = False - # avoid expensive reductions, relevant for arrays with < O(1000) elements - if q.size < 10: - for i in range(q.size): - if q[i] < 0. or q[i] > 100.: - raise ValueError("Percentiles must be in the range [0,100]") - q[i] /= 100. - else: - # faster than any() - if np.count_nonzero(q < 0.) or np.count_nonzero(q > 100.): - raise ValueError("Percentiles must be in the range [0,100]") - q /= 100. - # prepare a for partioning if overwrite_input: if axis is None: |