diff options
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index f76bf2a78..3d7428e69 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1435,15 +1435,25 @@ def _nanop(op, fill, a, axis=None): Processed data. """ - y = array(a,subok=True) + y = array(a, subok=True) mask = isnan(a) - if mask.all(): - return np.nan - if not issubclass(y.dtype.type, np.integer): - y[mask] = fill + # We only need to take care of NaN's in floating point arrays + if not np.issubdtype(y.dtype, int): + y[mask] = fill - return op(y, axis=axis) + res = op(y, axis=axis) + mask_all_along_axis = mask.all(axis=axis) + + # Along some axes, only nan's were encountered. As such, any values + # calculated along that axis should be set to nan. + if mask_all_along_axis.any(): + if np.isscalar(res): + res = np.nan + else: + res[mask_all_along_axis] = np.nan + + return res def nansum(a, axis=None): """ |