summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2008-08-12 06:56:11 +0000
committerStefan van der Walt <stefan@sun.ac.za>2008-08-12 06:56:11 +0000
commit04a0ee16539077c6693f4fb36e388a3aa77e9ac2 (patch)
treed00f5eea308d4684704d11d04121134bd636fdc2 /numpy/lib/function_base.py
parent73498929287e76f82fa398b29b55e14e5ea015b0 (diff)
downloadnumpy-04a0ee16539077c6693f4fb36e388a3aa77e9ac2.tar.gz
More consistent nan-operations.
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py57
1 files changed, 37 insertions, 20 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index f010536da..5ca797864 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1356,6 +1356,38 @@ def place(arr, mask, vals):
"""
return _insert(arr, mask, vals)
+def _nanop(op, fill, a, axis=None):
+ """
+ General operation on arrays with not-a-number values.
+
+ Parameters
+ ----------
+ op : callable
+ Operation to perform.
+ fill : float
+ NaN values are set to fill before doing the operation.
+ a : array-like
+ Input array.
+ axis : {int, None}, optional
+ Axis along which the operation is computed.
+ By default the input is flattened.
+
+ Returns
+ -------
+ y : {ndarray, scalar}
+ Processed data.
+
+ """
+ y = array(a,subok=True)
+ mask = isnan(a)
+ if mask.all():
+ return np.nan
+
+ if not issubclass(y.dtype.type, np.integer):
+ y[mask] = fill
+
+ return op(y, axis=axis)
+
def nansum(a, axis=None):
"""
Sum the array along the given axis, treating NaNs as zero.
@@ -1381,10 +1413,7 @@ def nansum(a, axis=None):
array([ 2., 1.])
"""
- y = array(a,subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = 0
- return y.sum(axis)
+ return _nanop(np.sum, 0, a, axis)
def nanmin(a, axis=None):
"""
@@ -1413,10 +1442,7 @@ def nanmin(a, axis=None):
array([ 1., 3.])
"""
- y = array(a,subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = _nx.inf
- return y.min(axis)
+ return _nanop(np.min, np.inf, a, axis)
def nanargmin(a, axis=None):
"""
@@ -1426,10 +1452,7 @@ def nanargmin(a, axis=None):
Refer to `numpy.nanargmax` for detailed documentation.
"""
- y = array(a, subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = _nx.inf
- return y.argmin(axis)
+ return _nanop(np.argmin, np.inf, a, axis)
def nanmax(a, axis=None):
"""
@@ -1458,10 +1481,7 @@ def nanmax(a, axis=None):
array([ 2., 3.])
"""
- y = array(a, subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = -_nx.inf
- return y.max(axis)
+ return _nanop(np.max, -np.inf, a, axis)
def nanargmax(a, axis=None):
"""
@@ -1497,10 +1517,7 @@ def nanargmax(a, axis=None):
array([1, 0])
"""
- y = array(a,subok=True)
- if not issubclass(y.dtype.type, _nx.integer):
- y[isnan(a)] = -_nx.inf
- return y.argmax(axis)
+ return _nanop(np.argmax, -np.inf, a, axis)
def disp(mesg, device=None, linefeed=True):
"""Display a message to the given device (default is sys.stdout)