summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTobias Pitters <tobias.pitters@gmail.com>2020-05-27 19:50:49 +0200
committerTobias Pitters <tobias.pitters@gmail.com>2020-05-27 19:50:49 +0200
commit214e8302ff92c002853ab03427cae2448a812f7c (patch)
tree1c9f898ff2e2999f0fc9a15fac5c33629e2a3256
parent13975971df33dd511852cc4b9dcd7dceb43d9221 (diff)
downloadnumpy-214e8302ff92c002853ab03427cae2448a812f7c.tar.gz
fix lerp function and corresponding tests
-rw-r--r--numpy/lib/function_base.py14
-rw-r--r--numpy/lib/tests/test_function_base.py12
2 files changed, 19 insertions, 7 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index f56a8844e..8fb6ba6eb 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3875,10 +3875,18 @@ def _quantile_is_valid(q):
def _lerp(a, b, t, out=None):
""" Linearly interpolate from a to b by a factor of t """
- if t < 0.5:
- return add(a, subtract(b, a)*t, out=out)
+ #return add(a, subtract(b, a, out=out)*t, out=out)
+ diff_b_a = subtract(b, a)
+
+ if np.isscalar(a) and np.isscalar(b) and (np.isscalar(t) or np.ndim(t) == 0):
+ if t <= 0.5:
+ return add(a, diff_b_a * t, out=out)
+ else:
+ return subtract(b, diff_b_a * (1 - t), out=out)
else:
- return subtract(b, subtract(b, a)*(1-t), out=out)
+ lerp_interpolation = add(a, diff_b_a*t, out=out)
+ subtract(b, diff_b_a * (1 - t), out=lerp_interpolation, where=t>=0.5)
+ return lerp_interpolation
def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 96e3135ef..5468901b4 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -3140,8 +3140,10 @@ class TestLerp:
@hypothesis.given(t=st.floats(allow_nan=False, allow_infinity=False,
min_value=0, max_value=1),
- a=st.floats(allow_nan=False, allow_infinity=False),
- b=st.floats(allow_nan=False, allow_infinity=False))
+ a=st.floats(allow_nan=False, allow_infinity=False,
+ width=32),
+ b=st.floats(allow_nan=False, allow_infinity=False,
+ width=32))
def test_lerp_bounded(self, t, a, b):
if a <= b:
assert a <= np.lib.function_base._lerp(a, b, t) <= b
@@ -3150,8 +3152,10 @@ class TestLerp:
@hypothesis.given(t=st.floats(allow_nan=False, allow_infinity=False,
min_value=0, max_value=1),
- a=st.floats(allow_nan=False, allow_infinity=False),
- b=st.floats(allow_nan=False, allow_infinity=False))
+ a=st.floats(allow_nan=False, allow_infinity=False,
+ width=32),
+ b=st.floats(allow_nan=False, allow_infinity=False,
+ width=32))
def test_lerp_symmetric(self, t, a, b):
# double subtraction is needed to remove the extra precision that t < 0.5 has
assert np.lib.function_base._lerp(a, b, 1 - (1 - t)) == np.lib.function_base._lerp(b, a, 1 - t)