diff options
| -rw-r--r-- | doc/release/1.17.0-notes.rst | 6 | ||||
| m--------- | doc/sphinxext | 0 | ||||
| -rw-r--r-- | numpy/distutils/system_info.py | 60 | ||||
| -rw-r--r-- | numpy/lib/tests/test_type_check.py | 40 | ||||
| -rw-r--r-- | numpy/lib/type_check.py | 46 |
5 files changed, 123 insertions, 29 deletions
diff --git a/doc/release/1.17.0-notes.rst b/doc/release/1.17.0-notes.rst index 7857400e8..1155449a7 100644 --- a/doc/release/1.17.0-notes.rst +++ b/doc/release/1.17.0-notes.rst @@ -197,6 +197,12 @@ The boolean and integer types are incapable of storing ``np.nan`` and ``np.inf`` which allows us to provide specialized ufuncs that are up to 250x faster than the current approach. +New keywords added to ``np.nan_to_num`` +--------------------------------------- +``np.nan_to_num`` now accepts keywords ``nan``, ``posinf`` and ``neginf`` allowing the +user to define the value to replace the ``nan``, positive and negative ``np.inf`` values +respectively. + Changes ======= diff --git a/doc/sphinxext b/doc/sphinxext -Subproject e47b9404963ad2a75a11d167416038275c50d1c +Subproject a482f66913c1079d7439770f0119b55376bb1b8 diff --git a/numpy/distutils/system_info.py b/numpy/distutils/system_info.py index 4d923ad26..8a42434ff 100644 --- a/numpy/distutils/system_info.py +++ b/numpy/distutils/system_info.py @@ -1689,23 +1689,46 @@ class blas_info(system_info): else: info['include_dirs'] = self.get_include_dirs() if platform.system() == 'Windows': - # The check for windows is needed because has_cblas uses the + # The check for windows is needed because get_cblas_libs uses the # same compiler that was used to compile Python and msvc is # often not installed when mingw is being used. This rough # treatment is not desirable, but windows is tricky. info['language'] = 'f77' # XXX: is it generally true? else: - lib = self.has_cblas(info) + lib = self.get_cblas_libs(info) if lib is not None: info['language'] = 'c' - info['libraries'] = [lib] + info['libraries'] = lib info['define_macros'] = [('HAVE_CBLAS', None)] self.set_info(**info) - def has_cblas(self, info): + def get_cblas_libs(self, info): + """ Check whether we can link with CBLAS interface + + This method will search through several combinations of libraries + to check whether CBLAS is present: + + 1. Libraries in ``info['libraries']``, as is + 2. As 1. but also explicitly adding ``'cblas'`` as a library + 3. As 1. but also explicitly adding ``'blas'`` as a library + 4. Check only library ``'cblas'`` + 5. Check only library ``'blas'`` + + Parameters + ---------- + info : dict + system information dictionary for compilation and linking + + Returns + ------- + libraries : list of str or None + a list of libraries that enables the use of CBLAS interface. + Returns None if not found or a compilation error occurs. + + Since 1.17 returns a list. + """ # primitive cblas check by looking for the header and trying to link # cblas or blas - res = False c = customized_ccompiler() tmpdir = tempfile.mkdtemp() s = """#include <cblas.h> @@ -1724,29 +1747,26 @@ class blas_info(system_info): # check we can compile (find headers) obj = c.compile([src], output_dir=tmpdir, include_dirs=self.get_include_dirs()) + except (distutils.ccompiler.CompileError, distutils.ccompiler.LinkError): + return None - # check we can link (find library) - # some systems have separate cblas and blas libs. First - # check for cblas lib, and if not present check for blas lib. + # check we can link (find library) + # some systems have separate cblas and blas libs. + for libs in [info['libraries'], ['cblas'] + info['libraries'], + ['blas'] + info['libraries'], ['cblas'], ['blas']]: try: c.link_executable(obj, os.path.join(tmpdir, "a.out"), - libraries=["cblas"], + libraries=libs, library_dirs=info['library_dirs'], extra_postargs=info.get('extra_link_args', [])) - res = "cblas" + return libs + # This breaks the for loop + break except distutils.ccompiler.LinkError: - c.link_executable(obj, os.path.join(tmpdir, "a.out"), - libraries=["blas"], - library_dirs=info['library_dirs'], - extra_postargs=info.get('extra_link_args', [])) - res = "blas" - except distutils.ccompiler.CompileError: - res = None - except distutils.ccompiler.LinkError: - res = None + pass finally: shutil.rmtree(tmpdir) - return res + return None class openblas_info(blas_info): diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py index 2982ca31a..b3f114b92 100644 --- a/numpy/lib/tests/test_type_check.py +++ b/numpy/lib/tests/test_type_check.py @@ -360,6 +360,14 @@ class TestNanToNum(object): assert_(vals[1] == 0) assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) assert_equal(type(vals), np.ndarray) + + # perform the same tests but with nan, posinf and neginf keywords + with np.errstate(divide='ignore', invalid='ignore'): + vals = nan_to_num(np.array((-1., 0, 1))/0., + nan=10, posinf=20, neginf=30) + assert_equal(vals, [30, 10, 20]) + assert_all(np.isfinite(vals[[0, 2]])) + assert_equal(type(vals), np.ndarray) # perform the same test but in-place with np.errstate(divide='ignore', invalid='ignore'): @@ -371,26 +379,48 @@ class TestNanToNum(object): assert_(vals[1] == 0) assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2])) assert_equal(type(vals), np.ndarray) + + # perform the same test but in-place + with np.errstate(divide='ignore', invalid='ignore'): + vals = np.array((-1., 0, 1))/0. + result = nan_to_num(vals, copy=False, nan=10, posinf=20, neginf=30) + + assert_(result is vals) + assert_equal(vals, [30, 10, 20]) + assert_all(np.isfinite(vals[[0, 2]])) + assert_equal(type(vals), np.ndarray) def test_array(self): vals = nan_to_num([1]) assert_array_equal(vals, np.array([1], int)) assert_equal(type(vals), np.ndarray) + vals = nan_to_num([1], nan=10, posinf=20, neginf=30) + assert_array_equal(vals, np.array([1], int)) + assert_equal(type(vals), np.ndarray) def test_integer(self): vals = nan_to_num(1) assert_all(vals == 1) assert_equal(type(vals), np.int_) + vals = nan_to_num(1, nan=10, posinf=20, neginf=30) + assert_all(vals == 1) + assert_equal(type(vals), np.int_) def test_float(self): vals = nan_to_num(1.0) assert_all(vals == 1.0) assert_equal(type(vals), np.float_) + vals = nan_to_num(1.1, nan=10, posinf=20, neginf=30) + assert_all(vals == 1.1) + assert_equal(type(vals), np.float_) def test_complex_good(self): vals = nan_to_num(1+1j) assert_all(vals == 1+1j) assert_equal(type(vals), np.complex_) + vals = nan_to_num(1+1j, nan=10, posinf=20, neginf=30) + assert_all(vals == 1+1j) + assert_equal(type(vals), np.complex_) def test_complex_bad(self): with np.errstate(divide='ignore', invalid='ignore'): @@ -414,6 +444,16 @@ class TestNanToNum(object): # !! inf. Comment out for now, and see if it # !! changes #assert_all(vals.real < -1e10) and assert_all(np.isfinite(vals)) + + def test_do_not_rewrite_previous_keyword(self): + # This is done to test that when, for instance, nan=np.inf then these + # values are not rewritten by posinf keyword to the posinf value. + with np.errstate(divide='ignore', invalid='ignore'): + vals = nan_to_num(np.array((-1., 0, 1))/0., nan=np.inf, posinf=999) + assert_all(np.isfinite(vals[[0, 2]])) + assert_all(vals[0] < -1e10) + assert_equal(vals[[1, 2]], [np.inf, 999]) + assert_equal(type(vals), np.ndarray) class TestRealIfClose(object): diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py index f55517732..2b254b6c0 100644 --- a/numpy/lib/type_check.py +++ b/numpy/lib/type_check.py @@ -363,18 +363,23 @@ def _getmaxmin(t): return f.max, f.min -def _nan_to_num_dispatcher(x, copy=None): +def _nan_to_num_dispatcher(x, copy=None, nan=None, posinf=None, neginf=None): return (x,) @array_function_dispatch(_nan_to_num_dispatcher) -def nan_to_num(x, copy=True): +def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): """ - Replace NaN with zero and infinity with large finite numbers. + Replace NaN with zero and infinity with large finite numbers (default + behaviour) or with the numbers defined by the user using the `nan`, + `posinf` and/or `neginf` keywords. - If `x` is inexact, NaN is replaced by zero, and infinity and -infinity - replaced by the respectively largest and most negative finite floating - point values representable by ``x.dtype``. + If `x` is inexact, NaN is replaced by zero or by the user defined value in + `nan` keyword, infinity is replaced by the largest finite floating point + values representable by ``x.dtype`` or by the user defined value in + `posinf` keyword and -infinity is replaced by the most negative finite + floating point values representable by ``x.dtype`` or by the user defined + value in `neginf` keyword. For complex dtypes, the above is applied to each of the real and imaginary components of `x` separately. @@ -390,6 +395,17 @@ def nan_to_num(x, copy=True): in-place (False). The in-place operation only occurs if casting to an array does not require a copy. Default is True. + nan : int, float, optional + Value to be used to fill NaN values. If no value is passed + then NaN values will be replaced with 0.0. + posinf : int, float, optional + Value to be used to fill positive infinity values. If no value is + passed then positive infinity values will be replaced with a very + large number. + neginf : int, float, optional + Value to be used to fill negative infinity values. If no value is + passed then negative infinity values will be replaced with a very + small (or negative) number. .. versionadded:: 1.13 @@ -424,6 +440,9 @@ def nan_to_num(x, copy=True): >>> np.nan_to_num(x) array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary -1.28000000e+002, 1.28000000e+002]) + >>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333) + array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03, + -1.2800000e+02, 1.2800000e+02]) >>> y = np.array([complex(np.inf, np.nan), np.nan, complex(np.nan, np.inf)]) array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary -1.28000000e+002, 1.28000000e+002]) @@ -431,6 +450,8 @@ def nan_to_num(x, copy=True): array([ 1.79769313e+308 +0.00000000e+000j, # may vary 0.00000000e+000 +0.00000000e+000j, 0.00000000e+000 +1.79769313e+308j]) + >>> np.nan_to_num(y, nan=111111, posinf=222222) + array([222222.+111111.j, 111111. +0.j, 111111.+222222.j]) """ x = _nx.array(x, subok=True, copy=copy) xtype = x.dtype.type @@ -444,10 +465,17 @@ def nan_to_num(x, copy=True): dest = (x.real, x.imag) if iscomplex else (x,) maxf, minf = _getmaxmin(x.real.dtype) + if posinf is not None: + maxf = posinf + if neginf is not None: + minf = neginf for d in dest: - _nx.copyto(d, 0.0, where=isnan(d)) - _nx.copyto(d, maxf, where=isposinf(d)) - _nx.copyto(d, minf, where=isneginf(d)) + idx_nan = isnan(d) + idx_posinf = isposinf(d) + idx_neginf = isneginf(d) + _nx.copyto(d, nan, where=idx_nan) + _nx.copyto(d, maxf, where=idx_posinf) + _nx.copyto(d, minf, where=idx_neginf) return x[()] if isscalar else x #----------------------------------------------------------------------------- |
