diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/compiled_base.c | 8 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 14 |
2 files changed, 22 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c index bcb44f6d1..8c140f5e2 100644 --- a/numpy/core/src/multiarray/compiled_base.c +++ b/numpy/core/src/multiarray/compiled_base.c @@ -654,6 +654,10 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict) else if (j == lenxp - 1) { dres[i] = dy[j]; } + else if (dx[j] == x_val) { + /* Avoid potential non-finite interpolation */ + dres[i] = dy[j]; + } else { const npy_double slope = (slopes != NULL) ? slopes[j] : (dy[j+1] - dy[j]) / (dx[j+1] - dx[j]); @@ -822,6 +826,10 @@ arr_interp_complex(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict) else if (j == lenxp - 1) { dres[i] = dy[j]; } + else if (dx[j] == x_val) { + /* Avoid potential non-finite interpolation */ + dres[i] = dy[j]; + } else { if (slopes!=NULL) { dres[i].real = slopes[j].real*(x_val - dx[j]) + dy[j].real; diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 4103a9eb3..d2a9181db 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -2237,6 +2237,14 @@ class TestInterp(object): x0 = np.nan assert_almost_equal(np.interp(x0, x, y), x0) + def test_non_finite_behavior(self): + x = [1, 2, 2.5, 3, 4] + xp = [1, 2, 3, 4] + fp = [1, 2, np.inf, 4] + assert_almost_equal(np.interp(x, xp, fp), [1, 2, np.inf, np.inf, 4]) + fp = [1, 2, np.nan, 4] + assert_almost_equal(np.interp(x, xp, fp), [1, 2, np.nan, np.nan, 4]) + def test_complex_interp(self): # test complex interpolation x = np.linspace(0, 1, 5) @@ -2251,6 +2259,12 @@ class TestInterp(object): x0 = 2.0 right = 2 + 3.0j assert_almost_equal(np.interp(x0, x, y, right=right), right) + # test complex non finite + x = [1, 2, 2.5, 3, 4] + xp = [1, 2, 3, 4] + fp = [1, 2+1j, np.inf, 4] + y = [1, 2+1j, np.inf+0.5j, np.inf, 4] + assert_almost_equal(np.interp(x, xp, fp), y) # test complex periodic x = [-180, -170, -185, 185, -10, -5, 0, 365] xp = [190, -190, 350, -350] |