summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py17
1 files changed, 6 insertions, 11 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index a3044f27d..85e75305d 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3876,17 +3876,12 @@ def _quantile_is_valid(q):
def _lerp(a, b, t, out=None):
""" Linearly interpolate from a to b by a factor of t """
diff_b_a = subtract(b, a)
-
- _scalar_or_0d = lambda x: np.isscalar(x) or np.ndim(x) == 0
- if _scalar_or_0d(a) and _scalar_or_0d(b) and _scalar_or_0d(t):
- if t <= 0.5:
- return add(a, diff_b_a * t, out=out)
- else:
- return subtract(b, diff_b_a * (1 - t), out=out)
- else:
- lerp_interpolation = add(a, diff_b_a*t, out=out)
- subtract(b, diff_b_a * (1 - t), out=lerp_interpolation, where=t>=0.5)
- return lerp_interpolation
+ # asanyarray is a stop-gap until gh-13105
+ lerp_interpolation = asanyarray(add(a, diff_b_a*t, out=out))
+ subtract(b, diff_b_a * (1 - t), out=lerp_interpolation, where=t>=0.5)
+ if lerp_interpolation.ndim == 0 and out is None:
+ lerp_interpolation = lerp_interpolation[()] # unpack 0d arrays
+ return lerp_interpolation
def _quantile_ureduce_func(a, q, axis=None, out=None, overwrite_input=False,