summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-02-25 11:28:31 -0800
committerEric Wieser <wieser.eric@gmail.com>2018-02-25 11:28:31 -0800
commit8eccbf7df2f564e23c1bf7c9f5ee08e4b0dc6a36 (patch)
tree343b263cb1318c45013915e804a19102fa58d0c0 /numpy/lib
parent6d1687cbcdf32a2bde765d39394a7b2bb9838ae4 (diff)
downloadnumpy-8eccbf7df2f564e23c1bf7c9f5ee08e4b0dc6a36.tar.gz
BUG/MAINT: Remove special handling of 0d arrays and scalars in interp
These are now handled generically by the underlying C function This fixes the period argument for 0d arrays. Now never returns a pure-python scalar, which matches the behaviour of most of numpy. Rework of b66a200a4a1e98f1955c8a774e4ebfb4588dab5b
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py19
-rw-r--r--numpy/lib/tests/test_function_base.py13
2 files changed, 14 insertions, 18 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 391c47a06..504280cef 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1255,23 +1255,13 @@ def interp(x, xp, fp, left=None, right=None, period=None):
interp_func = compiled_interp
input_dtype = np.float64
- if period is None:
- if isinstance(x, (float, int, number)):
- return interp_func([x], xp, fp, left, right).item()
- elif isinstance(x, np.ndarray) and x.ndim == 0:
- return interp_func([x], xp, fp, left, right).item()
- else:
- return interp_func(x, xp, fp, left, right)
- else:
+ if period is not None:
if period == 0:
raise ValueError("period must be a non-zero value")
period = abs(period)
left = None
right = None
- return_array = True
- if isinstance(x, (float, int, number)):
- return_array = False
- x = [x]
+
x = np.asarray(x, dtype=np.float64)
xp = np.asarray(xp, dtype=np.float64)
fp = np.asarray(fp, dtype=input_dtype)
@@ -1289,10 +1279,7 @@ def interp(x, xp, fp, left=None, right=None, period=None):
xp = np.concatenate((xp[-1:]-period, xp, xp[0:1]+period))
fp = np.concatenate((fp[-1:], fp, fp[0:1]))
- if return_array:
- return interp_func(x, xp, fp, left, right)
- else:
- return interp_func(x, xp, fp, left, right).item()
+ return interp_func(x, xp, fp, left, right)
def angle(z, deg=0):
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index dc5fe3397..49b450175 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2252,8 +2252,17 @@ class TestInterp(object):
y = np.linspace(0, 1, 5)
x0 = np.array(.3)
assert_almost_equal(np.interp(x0, x, y), x0)
- x0 = np.array(.3, dtype=object)
- assert_almost_equal(np.interp(x0, x, y), .3)
+
+ xp = np.array([0, 2, 4])
+ fp = np.array([1, -1, 1])
+
+ actual = np.interp(np.array(1), xp, fp)
+ assert_equal(actual, 0)
+ assert_(isinstance(actual, np.float64))
+
+ actual = np.interp(np.array(4.5), xp, fp, period=4)
+ assert_equal(actual, 0.5)
+ assert_(isinstance(actual, np.float64))
def test_if_len_x_is_small(self):
xp = np.arange(0, 10, 0.0001)