diff options
-rw-r--r-- | numpy/core/src/multiarray/compiled_base.c | 37 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 40 |
2 files changed, 55 insertions, 22 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c index 8ffeedac2..b9db3bb8f 100644 --- a/numpy/core/src/multiarray/compiled_base.c +++ b/numpy/core/src/multiarray/compiled_base.c @@ -529,14 +529,15 @@ binary_search_with_guess(const npy_double key, const npy_double *arr, } /* - * It would seem that for the following code to work, 'len' should - * at least be 4. But because of the way 'guess' is normalized, it - * will always be set to 1 if len <= 4. Given that, and that keys - * outside of the 'arr' bounds have already been handled, and the - * order in which comparisons happen below, it should become obvious - * that it will work with any array of at least 2 items. + * If len <= 4 use linear search. + * From above we know key >= arr[0] when we start. */ - assert (len >= 2); + if (len <= 4) { + npy_intp i; + + for (i = 1; i < len && key >= arr[i]; ++i); + return i - 1; + } if (guess > len - 3) { guess = len - 3; @@ -546,36 +547,36 @@ binary_search_with_guess(const npy_double key, const npy_double *arr, } /* check most likely values: guess - 1, guess, guess + 1 */ - if (key <= arr[guess]) { - if (key <= arr[guess - 1]) { + if (key < arr[guess]) { + if (key < arr[guess - 1]) { imax = guess - 1; /* last attempt to restrict search to items in cache */ if (guess > LIKELY_IN_CACHE_SIZE && - key > arr[guess - LIKELY_IN_CACHE_SIZE]) { + key >= arr[guess - LIKELY_IN_CACHE_SIZE]) { imin = guess - LIKELY_IN_CACHE_SIZE; } } else { - /* key > arr[guess - 1] */ + /* key >= arr[guess - 1] */ return guess - 1; } } else { - /* key > arr[guess] */ - if (key <= arr[guess + 1]) { + /* key >= arr[guess] */ + if (key < arr[guess + 1]) { return guess; } else { - /* key > arr[guess + 1] */ - if (key <= arr[guess + 2]) { + /* key >= arr[guess + 1] */ + if (key < arr[guess + 2]) { return guess + 1; } else { - /* key > arr[guess + 2] */ + /* key >= arr[guess + 2] */ imin = guess + 2; /* last attempt to restrict search to items in cache */ if (guess < len - LIKELY_IN_CACHE_SIZE - 1 && - key <= arr[guess + LIKELY_IN_CACHE_SIZE]) { + key < arr[guess + LIKELY_IN_CACHE_SIZE]) { imax = guess + LIKELY_IN_CACHE_SIZE; } } @@ -673,7 +674,7 @@ arr_interp(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict) } } - /* binary_search_with_guess needs at least a 2 item long array */ + /* binary_search_with_guess needs at least a 3 item long array */ if (lenxp == 1) { const npy_double xp_val = dx[0]; const npy_double fp_val = dy[0]; diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 88c932692..a5ac78e33 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1974,10 +1974,42 @@ class TestInterp(TestCase): assert_almost_equal(np.interp(x0, x, y), x0) def test_right_left_behavior(self): - assert_equal(interp([-1, 0, 1], [0], [1]), [1, 1, 1]) - assert_equal(interp([-1, 0, 1], [0], [1], left=0), [0, 1, 1]) - assert_equal(interp([-1, 0, 1], [0], [1], right=0), [1, 1, 0]) - assert_equal(interp([-1, 0, 1], [0], [1], left=0, right=0), [0, 1, 0]) + # Needs range of sizes to test different code paths. + # size ==1 is special cased, 1 < size < 5 is linear search, and + # size >= 5 goes through local search and possibly binary search. + for size in range(1, 10): + xp = np.arange(size, dtype=np.double) + yp = np.ones(size, dtype=np.double) + incpts = np.array([-1, 0, size - 1, size], dtype=np.double) + decpts = incpts[::-1] + + incres = interp(incpts, xp, yp) + decres = interp(decpts, xp, yp) + inctgt = np.array([1, 1, 1, 1], dtype=np.float) + dectgt = inctgt[::-1] + assert_equal(incres, inctgt) + assert_equal(decres, dectgt) + + incres = interp(incpts, xp, yp, left=0) + decres = interp(decpts, xp, yp, left=0) + inctgt = np.array([0, 1, 1, 1], dtype=np.float) + dectgt = inctgt[::-1] + assert_equal(incres, inctgt) + assert_equal(decres, dectgt) + + incres = interp(incpts, xp, yp, right=2) + decres = interp(decpts, xp, yp, right=2) + inctgt = np.array([1, 1, 1, 2], dtype=np.float) + dectgt = inctgt[::-1] + assert_equal(incres, inctgt) + assert_equal(decres, dectgt) + + incres = interp(incpts, xp, yp, left=0, right=2) + decres = interp(decpts, xp, yp, left=0, right=2) + inctgt = np.array([0, 1, 1, 2], dtype=np.float) + dectgt = inctgt[::-1] + assert_equal(incres, inctgt) + assert_equal(decres, dectgt) def test_scalar_interpolation_point(self): x = np.linspace(0, 1, 5) |