summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/src/_compiled_base.c24
-rw-r--r--numpy/lib/tests/test_function_base.py2
2 files changed, 21 insertions, 5 deletions
diff --git a/numpy/lib/src/_compiled_base.c b/numpy/lib/src/_compiled_base.c
index 0f238a12a..652268f24 100644
--- a/numpy/lib/src/_compiled_base.c
+++ b/numpy/lib/src/_compiled_base.c
@@ -681,8 +681,15 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
slopes[i] = (dy[i + 1] - dy[i])/(dx[i + 1] - dx[i]);
}
for (i = 0; i < lenx; i++) {
- npy_intp j = binary_search(dz[i], dx, lenxp);
+ const double x = dz[i];
+ npy_intp j;
+ if (npy_isnan(x)) {
+ dres[i] = x;
+ continue;
+ }
+
+ j = binary_search(x, dx, lenxp);
if (j == -1) {
dres[i] = lval;
}
@@ -693,7 +700,7 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
dres[i] = rval;
}
else {
- dres[i] = slopes[j]*(dz[i] - dx[j]) + dy[j];
+ dres[i] = slopes[j]*(x - dx[j]) + dy[j];
}
}
NPY_END_ALLOW_THREADS;
@@ -702,8 +709,15 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
else {
NPY_BEGIN_ALLOW_THREADS;
for (i = 0; i < lenx; i++) {
- npy_intp j = binary_search(dz[i], dx, lenxp);
+ const double x = dz[i];
+ npy_intp j;
+
+ if (npy_isnan(x)) {
+ dres[i] = x;
+ continue;
+ }
+ j = binary_search(x, dx, lenxp);
if (j == -1) {
dres[i] = lval;
}
@@ -714,8 +728,8 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
dres[i] = rval;
}
else {
- double slope = (dy[j + 1] - dy[j])/(dx[j + 1] - dx[j]);
- dres[i] = slope*(dz[i] - dx[j]) + dy[j];
+ const double slope = (dy[j + 1] - dy[j])/(dx[j + 1] - dx[j]);
+ dres[i] = slope*(x - dx[j]) + dy[j];
}
}
NPY_END_ALLOW_THREADS;
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 003d3e541..8035ac002 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1490,6 +1490,8 @@ class TestInterp(TestCase):
assert_almost_equal(np.interp(x0, x, y), x0)
x0 = np.float64(.3)
assert_almost_equal(np.interp(x0, x, y), x0)
+ x0 = np.nan
+ assert_almost_equal(np.interp(x0, x, y), x0)
def test_zero_dimensional_interpolation_point(self):
x = np.linspace(0, 1, 5)