diff options
-rw-r--r-- | numpy/lib/function_base.py | 14 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 12 |
2 files changed, 19 insertions, 7 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index f56a8844e..8fb6ba6eb 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -3875,10 +3875,18 @@ def _quantile_is_valid(q): def _lerp(a, b, t, out=None): """ Linearly interpolate from a to b by a factor of t """ - if t < 0.5: - return add(a, subtract(b, a)*t, out=out) + #return add(a, subtract(b, a, out=out)*t, out=out) + diff_b_a = subtract(b, a) + + if np.isscalar(a) and np.isscalar(b) and (np.isscalar(t) or np.ndim(t) == 0): + if t <= 0.5: + return add(a, diff_b_a * t, out=out) + else: + return subtract(b, diff_b_a * (1 - t), out=out) else: - return subtract(b, subtract(b, a)*(1-t), out=out) + lerp_interpolation = add(a, diff_b_a*t, out=out) + subtract(b, diff_b_a * (1 - t), out=lerp_interpolation, where=t>=0.5) + return lerp_interpolation def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False, diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 96e3135ef..5468901b4 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -3140,8 +3140,10 @@ class TestLerp: @hypothesis.given(t=st.floats(allow_nan=False, allow_infinity=False, min_value=0, max_value=1), - a=st.floats(allow_nan=False, allow_infinity=False), - b=st.floats(allow_nan=False, allow_infinity=False)) + a=st.floats(allow_nan=False, allow_infinity=False, + width=32), + b=st.floats(allow_nan=False, allow_infinity=False, + width=32)) def test_lerp_bounded(self, t, a, b): if a <= b: assert a <= np.lib.function_base._lerp(a, b, t) <= b @@ -3150,8 +3152,10 @@ class TestLerp: @hypothesis.given(t=st.floats(allow_nan=False, allow_infinity=False, min_value=0, max_value=1), - a=st.floats(allow_nan=False, allow_infinity=False), - b=st.floats(allow_nan=False, allow_infinity=False)) + a=st.floats(allow_nan=False, allow_infinity=False, + width=32), + b=st.floats(allow_nan=False, allow_infinity=False, + width=32)) def test_lerp_symmetric(self, t, a, b): # double subtraction is needed to remove the extra precision that t < 0.5 has assert np.lib.function_base._lerp(a, b, 1 - (1 - t)) == np.lib.function_base._lerp(b, a, 1 - t) |