summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/compiled_base.c8
-rw-r--r--numpy/lib/tests/test_function_base.py14
2 files changed, 22 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c
index bcb44f6d1..8c140f5e2 100644
--- a/numpy/core/src/multiarray/compiled_base.c
+++ b/numpy/core/src/multiarray/compiled_base.c
@@ -654,6 +654,10 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
else if (j == lenxp - 1) {
dres[i] = dy[j];
}
+ else if (dx[j] == x_val) {
+ /* Avoid potential non-finite interpolation */
+ dres[i] = dy[j];
+ }
else {
const npy_double slope = (slopes != NULL) ? slopes[j] :
(dy[j+1] - dy[j]) / (dx[j+1] - dx[j]);
@@ -822,6 +826,10 @@ arr_interp_complex(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
else if (j == lenxp - 1) {
dres[i] = dy[j];
}
+ else if (dx[j] == x_val) {
+ /* Avoid potential non-finite interpolation */
+ dres[i] = dy[j];
+ }
else {
if (slopes!=NULL) {
dres[i].real = slopes[j].real*(x_val - dx[j]) + dy[j].real;
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 4103a9eb3..d2a9181db 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2237,6 +2237,14 @@ class TestInterp(object):
x0 = np.nan
assert_almost_equal(np.interp(x0, x, y), x0)
+ def test_non_finite_behavior(self):
+ x = [1, 2, 2.5, 3, 4]
+ xp = [1, 2, 3, 4]
+ fp = [1, 2, np.inf, 4]
+ assert_almost_equal(np.interp(x, xp, fp), [1, 2, np.inf, np.inf, 4])
+ fp = [1, 2, np.nan, 4]
+ assert_almost_equal(np.interp(x, xp, fp), [1, 2, np.nan, np.nan, 4])
+
def test_complex_interp(self):
# test complex interpolation
x = np.linspace(0, 1, 5)
@@ -2251,6 +2259,12 @@ class TestInterp(object):
x0 = 2.0
right = 2 + 3.0j
assert_almost_equal(np.interp(x0, x, y, right=right), right)
+ # test complex non finite
+ x = [1, 2, 2.5, 3, 4]
+ xp = [1, 2, 3, 4]
+ fp = [1, 2+1j, np.inf, 4]
+ y = [1, 2+1j, np.inf+0.5j, np.inf, 4]
+ assert_almost_equal(np.interp(x, xp, fp), y)
# test complex periodic
x = [-180, -170, -185, 185, -10, -5, 0, 365]
xp = [190, -190, 350, -350]