summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/compiled_base.c37
-rw-r--r--numpy/lib/tests/test_function_base.py40
2 files changed, 55 insertions, 22 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c
index 8ffeedac2..b9db3bb8f 100644
--- a/numpy/core/src/multiarray/compiled_base.c
+++ b/numpy/core/src/multiarray/compiled_base.c
@@ -529,14 +529,15 @@ binary_search_with_guess(const npy_double key, const npy_double *arr,
}
/*
- * It would seem that for the following code to work, 'len' should
- * at least be 4. But because of the way 'guess' is normalized, it
- * will always be set to 1 if len <= 4. Given that, and that keys
- * outside of the 'arr' bounds have already been handled, and the
- * order in which comparisons happen below, it should become obvious
- * that it will work with any array of at least 2 items.
+ * If len <= 4 use linear search.
+ * From above we know key >= arr[0] when we start.
*/
- assert (len >= 2);
+ if (len <= 4) {
+ npy_intp i;
+
+ for (i = 1; i < len && key >= arr[i]; ++i);
+ return i - 1;
+ }
if (guess > len - 3) {
guess = len - 3;
@@ -546,36 +547,36 @@ binary_search_with_guess(const npy_double key, const npy_double *arr,
}
/* check most likely values: guess - 1, guess, guess + 1 */
- if (key <= arr[guess]) {
- if (key <= arr[guess - 1]) {
+ if (key < arr[guess]) {
+ if (key < arr[guess - 1]) {
imax = guess - 1;
/* last attempt to restrict search to items in cache */
if (guess > LIKELY_IN_CACHE_SIZE &&
- key > arr[guess - LIKELY_IN_CACHE_SIZE]) {
+ key >= arr[guess - LIKELY_IN_CACHE_SIZE]) {
imin = guess - LIKELY_IN_CACHE_SIZE;
}
}
else {
- /* key > arr[guess - 1] */
+ /* key >= arr[guess - 1] */
return guess - 1;
}
}
else {
- /* key > arr[guess] */
- if (key <= arr[guess + 1]) {
+ /* key >= arr[guess] */
+ if (key < arr[guess + 1]) {
return guess;
}
else {
- /* key > arr[guess + 1] */
- if (key <= arr[guess + 2]) {
+ /* key >= arr[guess + 1] */
+ if (key < arr[guess + 2]) {
return guess + 1;
}
else {
- /* key > arr[guess + 2] */
+ /* key >= arr[guess + 2] */
imin = guess + 2;
/* last attempt to restrict search to items in cache */
if (guess < len - LIKELY_IN_CACHE_SIZE - 1 &&
- key <= arr[guess + LIKELY_IN_CACHE_SIZE]) {
+ key < arr[guess + LIKELY_IN_CACHE_SIZE]) {
imax = guess + LIKELY_IN_CACHE_SIZE;
}
}
@@ -673,7 +674,7 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
}
}
- /* binary_search_with_guess needs at least a 2 item long array */
+ /* binary_search_with_guess needs at least a 3 item long array */
if (lenxp == 1) {
const npy_double xp_val = dx[0];
const npy_double fp_val = dy[0];
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 88c932692..a5ac78e33 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1974,10 +1974,42 @@ class TestInterp(TestCase):
assert_almost_equal(np.interp(x0, x, y), x0)
def test_right_left_behavior(self):
- assert_equal(interp([-1, 0, 1], [0], [1]), [1, 1, 1])
- assert_equal(interp([-1, 0, 1], [0], [1], left=0), [0, 1, 1])
- assert_equal(interp([-1, 0, 1], [0], [1], right=0), [1, 1, 0])
- assert_equal(interp([-1, 0, 1], [0], [1], left=0, right=0), [0, 1, 0])
+ # Needs range of sizes to test different code paths.
+ # size ==1 is special cased, 1 < size < 5 is linear search, and
+ # size >= 5 goes through local search and possibly binary search.
+ for size in range(1, 10):
+ xp = np.arange(size, dtype=np.double)
+ yp = np.ones(size, dtype=np.double)
+ incpts = np.array([-1, 0, size - 1, size], dtype=np.double)
+ decpts = incpts[::-1]
+
+ incres = interp(incpts, xp, yp)
+ decres = interp(decpts, xp, yp)
+ inctgt = np.array([1, 1, 1, 1], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
+
+ incres = interp(incpts, xp, yp, left=0)
+ decres = interp(decpts, xp, yp, left=0)
+ inctgt = np.array([0, 1, 1, 1], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
+
+ incres = interp(incpts, xp, yp, right=2)
+ decres = interp(decpts, xp, yp, right=2)
+ inctgt = np.array([1, 1, 1, 2], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
+
+ incres = interp(incpts, xp, yp, left=0, right=2)
+ decres = interp(decpts, xp, yp, left=0, right=2)
+ inctgt = np.array([0, 1, 1, 2], dtype=np.float)
+ dectgt = inctgt[::-1]
+ assert_equal(incres, inctgt)
+ assert_equal(decres, dectgt)
def test_scalar_interpolation_point(self):
x = np.linspace(0, 1, 5)