summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/1.17.0-notes.rst6
m---------doc/sphinxext0
-rw-r--r--numpy/distutils/system_info.py60
-rw-r--r--numpy/lib/tests/test_type_check.py40
-rw-r--r--numpy/lib/type_check.py46
5 files changed, 123 insertions, 29 deletions
diff --git a/doc/release/1.17.0-notes.rst b/doc/release/1.17.0-notes.rst
index 7857400e8..1155449a7 100644
--- a/doc/release/1.17.0-notes.rst
+++ b/doc/release/1.17.0-notes.rst
@@ -197,6 +197,12 @@ The boolean and integer types are incapable of storing ``np.nan`` and ``np.inf``
which allows us to provide specialized ufuncs that are up to 250x faster than the current
approach.
+New keywords added to ``np.nan_to_num``
+---------------------------------------
+``np.nan_to_num`` now accepts keywords ``nan``, ``posinf`` and ``neginf`` allowing the
+user to define the value to replace the ``nan``, positive and negative ``np.inf`` values
+respectively.
+
Changes
=======
diff --git a/doc/sphinxext b/doc/sphinxext
-Subproject e47b9404963ad2a75a11d167416038275c50d1c
+Subproject a482f66913c1079d7439770f0119b55376bb1b8
diff --git a/numpy/distutils/system_info.py b/numpy/distutils/system_info.py
index 4d923ad26..8a42434ff 100644
--- a/numpy/distutils/system_info.py
+++ b/numpy/distutils/system_info.py
@@ -1689,23 +1689,46 @@ class blas_info(system_info):
else:
info['include_dirs'] = self.get_include_dirs()
if platform.system() == 'Windows':
- # The check for windows is needed because has_cblas uses the
+ # The check for windows is needed because get_cblas_libs uses the
# same compiler that was used to compile Python and msvc is
# often not installed when mingw is being used. This rough
# treatment is not desirable, but windows is tricky.
info['language'] = 'f77' # XXX: is it generally true?
else:
- lib = self.has_cblas(info)
+ lib = self.get_cblas_libs(info)
if lib is not None:
info['language'] = 'c'
- info['libraries'] = [lib]
+ info['libraries'] = lib
info['define_macros'] = [('HAVE_CBLAS', None)]
self.set_info(**info)
- def has_cblas(self, info):
+ def get_cblas_libs(self, info):
+ """ Check whether we can link with CBLAS interface
+
+ This method will search through several combinations of libraries
+ to check whether CBLAS is present:
+
+ 1. Libraries in ``info['libraries']``, as is
+ 2. As 1. but also explicitly adding ``'cblas'`` as a library
+ 3. As 1. but also explicitly adding ``'blas'`` as a library
+ 4. Check only library ``'cblas'``
+ 5. Check only library ``'blas'``
+
+ Parameters
+ ----------
+ info : dict
+ system information dictionary for compilation and linking
+
+ Returns
+ -------
+ libraries : list of str or None
+ a list of libraries that enables the use of CBLAS interface.
+ Returns None if not found or a compilation error occurs.
+
+ Since 1.17 returns a list.
+ """
# primitive cblas check by looking for the header and trying to link
# cblas or blas
- res = False
c = customized_ccompiler()
tmpdir = tempfile.mkdtemp()
s = """#include <cblas.h>
@@ -1724,29 +1747,26 @@ class blas_info(system_info):
# check we can compile (find headers)
obj = c.compile([src], output_dir=tmpdir,
include_dirs=self.get_include_dirs())
+ except (distutils.ccompiler.CompileError, distutils.ccompiler.LinkError):
+ return None
- # check we can link (find library)
- # some systems have separate cblas and blas libs. First
- # check for cblas lib, and if not present check for blas lib.
+ # check we can link (find library)
+ # some systems have separate cblas and blas libs.
+ for libs in [info['libraries'], ['cblas'] + info['libraries'],
+ ['blas'] + info['libraries'], ['cblas'], ['blas']]:
try:
c.link_executable(obj, os.path.join(tmpdir, "a.out"),
- libraries=["cblas"],
+ libraries=libs,
library_dirs=info['library_dirs'],
extra_postargs=info.get('extra_link_args', []))
- res = "cblas"
+ return libs
+ # This breaks the for loop
+ break
except distutils.ccompiler.LinkError:
- c.link_executable(obj, os.path.join(tmpdir, "a.out"),
- libraries=["blas"],
- library_dirs=info['library_dirs'],
- extra_postargs=info.get('extra_link_args', []))
- res = "blas"
- except distutils.ccompiler.CompileError:
- res = None
- except distutils.ccompiler.LinkError:
- res = None
+ pass
finally:
shutil.rmtree(tmpdir)
- return res
+ return None
class openblas_info(blas_info):
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py
index 2982ca31a..b3f114b92 100644
--- a/numpy/lib/tests/test_type_check.py
+++ b/numpy/lib/tests/test_type_check.py
@@ -360,6 +360,14 @@ class TestNanToNum(object):
assert_(vals[1] == 0)
assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
assert_equal(type(vals), np.ndarray)
+
+ # perform the same tests but with nan, posinf and neginf keywords
+ with np.errstate(divide='ignore', invalid='ignore'):
+ vals = nan_to_num(np.array((-1., 0, 1))/0.,
+ nan=10, posinf=20, neginf=30)
+ assert_equal(vals, [30, 10, 20])
+ assert_all(np.isfinite(vals[[0, 2]]))
+ assert_equal(type(vals), np.ndarray)
# perform the same test but in-place
with np.errstate(divide='ignore', invalid='ignore'):
@@ -371,26 +379,48 @@ class TestNanToNum(object):
assert_(vals[1] == 0)
assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
assert_equal(type(vals), np.ndarray)
+
+ # perform the same test but in-place
+ with np.errstate(divide='ignore', invalid='ignore'):
+ vals = np.array((-1., 0, 1))/0.
+ result = nan_to_num(vals, copy=False, nan=10, posinf=20, neginf=30)
+
+ assert_(result is vals)
+ assert_equal(vals, [30, 10, 20])
+ assert_all(np.isfinite(vals[[0, 2]]))
+ assert_equal(type(vals), np.ndarray)
def test_array(self):
vals = nan_to_num([1])
assert_array_equal(vals, np.array([1], int))
assert_equal(type(vals), np.ndarray)
+ vals = nan_to_num([1], nan=10, posinf=20, neginf=30)
+ assert_array_equal(vals, np.array([1], int))
+ assert_equal(type(vals), np.ndarray)
def test_integer(self):
vals = nan_to_num(1)
assert_all(vals == 1)
assert_equal(type(vals), np.int_)
+ vals = nan_to_num(1, nan=10, posinf=20, neginf=30)
+ assert_all(vals == 1)
+ assert_equal(type(vals), np.int_)
def test_float(self):
vals = nan_to_num(1.0)
assert_all(vals == 1.0)
assert_equal(type(vals), np.float_)
+ vals = nan_to_num(1.1, nan=10, posinf=20, neginf=30)
+ assert_all(vals == 1.1)
+ assert_equal(type(vals), np.float_)
def test_complex_good(self):
vals = nan_to_num(1+1j)
assert_all(vals == 1+1j)
assert_equal(type(vals), np.complex_)
+ vals = nan_to_num(1+1j, nan=10, posinf=20, neginf=30)
+ assert_all(vals == 1+1j)
+ assert_equal(type(vals), np.complex_)
def test_complex_bad(self):
with np.errstate(divide='ignore', invalid='ignore'):
@@ -414,6 +444,16 @@ class TestNanToNum(object):
# !! inf. Comment out for now, and see if it
# !! changes
#assert_all(vals.real < -1e10) and assert_all(np.isfinite(vals))
+
+ def test_do_not_rewrite_previous_keyword(self):
+ # This is done to test that when, for instance, nan=np.inf then these
+ # values are not rewritten by posinf keyword to the posinf value.
+ with np.errstate(divide='ignore', invalid='ignore'):
+ vals = nan_to_num(np.array((-1., 0, 1))/0., nan=np.inf, posinf=999)
+ assert_all(np.isfinite(vals[[0, 2]]))
+ assert_all(vals[0] < -1e10)
+ assert_equal(vals[[1, 2]], [np.inf, 999])
+ assert_equal(type(vals), np.ndarray)
class TestRealIfClose(object):
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index f55517732..2b254b6c0 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -363,18 +363,23 @@ def _getmaxmin(t):
return f.max, f.min
-def _nan_to_num_dispatcher(x, copy=None):
+def _nan_to_num_dispatcher(x, copy=None, nan=None, posinf=None, neginf=None):
return (x,)
@array_function_dispatch(_nan_to_num_dispatcher)
-def nan_to_num(x, copy=True):
+def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
"""
- Replace NaN with zero and infinity with large finite numbers.
+ Replace NaN with zero and infinity with large finite numbers (default
+ behaviour) or with the numbers defined by the user using the `nan`,
+ `posinf` and/or `neginf` keywords.
- If `x` is inexact, NaN is replaced by zero, and infinity and -infinity
- replaced by the respectively largest and most negative finite floating
- point values representable by ``x.dtype``.
+ If `x` is inexact, NaN is replaced by zero or by the user defined value in
+ `nan` keyword, infinity is replaced by the largest finite floating point
+ values representable by ``x.dtype`` or by the user defined value in
+ `posinf` keyword and -infinity is replaced by the most negative finite
+ floating point values representable by ``x.dtype`` or by the user defined
+ value in `neginf` keyword.
For complex dtypes, the above is applied to each of the real and
imaginary components of `x` separately.
@@ -390,6 +395,17 @@ def nan_to_num(x, copy=True):
in-place (False). The in-place operation only occurs if
casting to an array does not require a copy.
Default is True.
+ nan : int, float, optional
+ Value to be used to fill NaN values. If no value is passed
+ then NaN values will be replaced with 0.0.
+ posinf : int, float, optional
+ Value to be used to fill positive infinity values. If no value is
+ passed then positive infinity values will be replaced with a very
+ large number.
+ neginf : int, float, optional
+ Value to be used to fill negative infinity values. If no value is
+ passed then negative infinity values will be replaced with a very
+ small (or negative) number.
.. versionadded:: 1.13
@@ -424,6 +440,9 @@ def nan_to_num(x, copy=True):
>>> np.nan_to_num(x)
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
+ >>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333)
+ array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03,
+ -1.2800000e+02, 1.2800000e+02])
>>> y = np.array([complex(np.inf, np.nan), np.nan, complex(np.nan, np.inf)])
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
@@ -431,6 +450,8 @@ def nan_to_num(x, copy=True):
array([ 1.79769313e+308 +0.00000000e+000j, # may vary
0.00000000e+000 +0.00000000e+000j,
0.00000000e+000 +1.79769313e+308j])
+ >>> np.nan_to_num(y, nan=111111, posinf=222222)
+ array([222222.+111111.j, 111111. +0.j, 111111.+222222.j])
"""
x = _nx.array(x, subok=True, copy=copy)
xtype = x.dtype.type
@@ -444,10 +465,17 @@ def nan_to_num(x, copy=True):
dest = (x.real, x.imag) if iscomplex else (x,)
maxf, minf = _getmaxmin(x.real.dtype)
+ if posinf is not None:
+ maxf = posinf
+ if neginf is not None:
+ minf = neginf
for d in dest:
- _nx.copyto(d, 0.0, where=isnan(d))
- _nx.copyto(d, maxf, where=isposinf(d))
- _nx.copyto(d, minf, where=isneginf(d))
+ idx_nan = isnan(d)
+ idx_posinf = isposinf(d)
+ idx_neginf = isneginf(d)
+ _nx.copyto(d, nan, where=idx_nan)
+ _nx.copyto(d, maxf, where=idx_posinf)
+ _nx.copyto(d, minf, where=idx_neginf)
return x[()] if isscalar else x
#-----------------------------------------------------------------------------