diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2009-03-26 23:28:03 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2009-03-26 23:28:03 +0000 |
commit | da8312aed0218cb44a856dfeb73a06d224a29ff4 (patch) | |
tree | 231b41f0948ad590e2a920545158c70004d6a066 /numpy/lib/function_base.py | |
parent | 7b751f66c7feb71646f0c2540aca2e5e67cd5db5 (diff) | |
download | numpy-da8312aed0218cb44a856dfeb73a06d224a29ff4.tar.gz |
Fix nanmin, -max etc. to handle axis argument correctly.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index f76bf2a78..3d7428e69 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1435,15 +1435,25 @@ def _nanop(op, fill, a, axis=None): Processed data. """ - y = array(a,subok=True) + y = array(a, subok=True) mask = isnan(a) - if mask.all(): - return np.nan - if not issubclass(y.dtype.type, np.integer): - y[mask] = fill + # We only need to take care of NaN's in floating point arrays + if not np.issubdtype(y.dtype, int): + y[mask] = fill - return op(y, axis=axis) + res = op(y, axis=axis) + mask_all_along_axis = mask.all(axis=axis) + + # Along some axes, only nan's were encountered. As such, any values + # calculated along that axis should be set to nan. + if mask_all_along_axis.any(): + if np.isscalar(res): + res = np.nan + else: + res[mask_all_along_axis] = np.nan + + return res def nansum(a, axis=None): """ |