diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2008-08-12 06:56:11 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2008-08-12 06:56:11 +0000 |
commit | 04a0ee16539077c6693f4fb36e388a3aa77e9ac2 (patch) | |
tree | d00f5eea308d4684704d11d04121134bd636fdc2 /numpy/lib/function_base.py | |
parent | 73498929287e76f82fa398b29b55e14e5ea015b0 (diff) | |
download | numpy-04a0ee16539077c6693f4fb36e388a3aa77e9ac2.tar.gz |
More consistent nan-operations.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 57 |
1 files changed, 37 insertions, 20 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index f010536da..5ca797864 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1356,6 +1356,38 @@ def place(arr, mask, vals): """ return _insert(arr, mask, vals) +def _nanop(op, fill, a, axis=None): + """ + General operation on arrays with not-a-number values. + + Parameters + ---------- + op : callable + Operation to perform. + fill : float + NaN values are set to fill before doing the operation. + a : array-like + Input array. + axis : {int, None}, optional + Axis along which the operation is computed. + By default the input is flattened. + + Returns + ------- + y : {ndarray, scalar} + Processed data. + + """ + y = array(a,subok=True) + mask = isnan(a) + if mask.all(): + return np.nan + + if not issubclass(y.dtype.type, np.integer): + y[mask] = fill + + return op(y, axis=axis) + def nansum(a, axis=None): """ Sum the array along the given axis, treating NaNs as zero. @@ -1381,10 +1413,7 @@ def nansum(a, axis=None): array([ 2., 1.]) """ - y = array(a,subok=True) - if not issubclass(y.dtype.type, _nx.integer): - y[isnan(a)] = 0 - return y.sum(axis) + return _nanop(np.sum, 0, a, axis) def nanmin(a, axis=None): """ @@ -1413,10 +1442,7 @@ def nanmin(a, axis=None): array([ 1., 3.]) """ - y = array(a,subok=True) - if not issubclass(y.dtype.type, _nx.integer): - y[isnan(a)] = _nx.inf - return y.min(axis) + return _nanop(np.min, np.inf, a, axis) def nanargmin(a, axis=None): """ @@ -1426,10 +1452,7 @@ def nanargmin(a, axis=None): Refer to `numpy.nanargmax` for detailed documentation. """ - y = array(a, subok=True) - if not issubclass(y.dtype.type, _nx.integer): - y[isnan(a)] = _nx.inf - return y.argmin(axis) + return _nanop(np.argmin, np.inf, a, axis) def nanmax(a, axis=None): """ @@ -1458,10 +1481,7 @@ def nanmax(a, axis=None): array([ 2., 3.]) """ - y = array(a, subok=True) - if not issubclass(y.dtype.type, _nx.integer): - y[isnan(a)] = -_nx.inf - return y.max(axis) + return _nanop(np.max, -np.inf, a, axis) def nanargmax(a, axis=None): """ @@ -1497,10 +1517,7 @@ def nanargmax(a, axis=None): array([1, 0]) """ - y = array(a,subok=True) - if not issubclass(y.dtype.type, _nx.integer): - y[isnan(a)] = -_nx.inf - return y.argmax(axis) + return _nanop(np.argmax, -np.inf, a, axis) def disp(mesg, device=None, linefeed=True): """Display a message to the given device (default is sys.stdout) |