diff options
Diffstat (limited to 'numpy/core/_methods.py')
-rw-r--r-- | numpy/core/_methods.py | 80 |
1 files changed, 51 insertions, 29 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 86ddf4d17..75fd32ec8 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -50,20 +50,32 @@ def _prod(a, axis=None, dtype=None, out=None, keepdims=False, initial=_NoValue, where=True): return umr_prod(a, axis, dtype, out, keepdims, initial, where) -def _any(a, axis=None, dtype=None, out=None, keepdims=False): - return umr_any(a, axis, dtype, out, keepdims) - -def _all(a, axis=None, dtype=None, out=None, keepdims=False): - return umr_all(a, axis, dtype, out, keepdims) - -def _count_reduce_items(arr, axis): - if axis is None: - axis = tuple(range(arr.ndim)) - if not isinstance(axis, tuple): - axis = (axis,) - items = 1 - for ax in axis: - items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)] +def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + return umr_any(a, axis, dtype, out, keepdims, where=where) + +def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + return umr_all(a, axis, dtype, out, keepdims, where=where) + +def _count_reduce_items(arr, axis, keepdims=False, where=True): + # fast-path for the default case + if where is True: + # no boolean mask given, calculate items according to axis + if axis is None: + axis = tuple(range(arr.ndim)) + elif not isinstance(axis, tuple): + axis = (axis,) + items = nt.intp(1) + for ax in axis: + items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)] + else: + # TODO: Optimize case when `where` is broadcast along a non-reduction + # axis and full sum is more excessive than needed. + + # guarded to protect circular imports + from numpy.lib.stride_tricks import broadcast_to + # count True values in (potentially broadcasted) boolean mask + items = umr_sum(broadcast_to(where, arr.shape), axis, nt.intp, None, + keepdims) return items # Numpy 1.17.0, 2019-02-24 @@ -140,13 +152,13 @@ def _clip(a, min=None, max=None, out=None, *, casting=None, **kwargs): return _clip_dep_invoke_with_casting( um.clip, a, min, max, out=out, casting=casting, **kwargs) -def _mean(a, axis=None, dtype=None, out=None, keepdims=False): +def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): arr = asanyarray(a) is_float16_result = False - rcount = _count_reduce_items(arr, axis) - # Make this warning show up first - if rcount == 0: + + rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) + if umr_any(rcount == 0, axis=None): warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) # Cast bool, unsigned int, and int to float64 by default @@ -157,7 +169,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False): dtype = mu.dtype('f4') is_float16_result = True - ret = umr_sum(arr, axis, dtype, out, keepdims) + ret = umr_sum(arr, axis, dtype, out, keepdims, where=where) if isinstance(ret, mu.ndarray): ret = um.true_divide( ret, rcount, out=ret, casting='unsafe', subok=False) @@ -173,12 +185,13 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False): return ret -def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): +def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, + where=True): arr = asanyarray(a) - rcount = _count_reduce_items(arr, axis) + rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) # Make this warning show up on top. - if ddof >= rcount: + if umr_any(ddof >= rcount, axis=None): warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) @@ -189,10 +202,18 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # Compute the mean. # Note that if dtype is not of inexact type then arraymean will # not be either. - arrmean = umr_sum(arr, axis, dtype, keepdims=True) + arrmean = umr_sum(arr, axis, dtype, keepdims=True, where=where) + # The shape of rcount has to match arrmean to not change the shape of out + # in broadcasting. Otherwise, it cannot be stored back to arrmean. + if rcount.ndim == 0: + # fast-path for default case when where is True + div = rcount + else: + # matching rcount to arrmean when where is specified as array + div = rcount.reshape(arrmean.shape) if isinstance(arrmean, mu.ndarray): - arrmean = um.true_divide( - arrmean, rcount, out=arrmean, casting='unsafe', subok=False) + arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe', + subok=False) else: arrmean = arrmean.dtype.type(arrmean / rcount) @@ -213,10 +234,10 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): else: x = um.multiply(x, um.conjugate(x), out=x).real - ret = umr_sum(x, axis, dtype, out, keepdims) + ret = umr_sum(x, axis, dtype, out, keepdims=keepdims, where=where) # Compute degrees of freedom and make sure it is not negative. - rcount = max([rcount - ddof, 0]) + rcount = um.maximum(rcount - ddof, 0) # divide by degrees of freedom if isinstance(ret, mu.ndarray): @@ -229,9 +250,10 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): return ret -def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): +def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, + where=True): ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof, - keepdims=keepdims) + keepdims=keepdims, where=where) if isinstance(ret, mu.ndarray): ret = um.sqrt(ret, out=ret) |