summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/npysort/binsearch.cpp6
-rw-r--r--numpy/core/tests/test_multiarray.py37
2 files changed, 28 insertions, 15 deletions
diff --git a/numpy/core/src/npysort/binsearch.cpp b/numpy/core/src/npysort/binsearch.cpp
index 98d305910..ea07bbb0c 100644
--- a/numpy/core/src/npysort/binsearch.cpp
+++ b/numpy/core/src/npysort/binsearch.cpp
@@ -330,9 +330,9 @@ struct binsearch_t : binsearch_base<arg> {
npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag,
npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag,
npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag,
- npy::half_tag, npy::float_tag, npy::double_tag,
- npy::longdouble_tag, npy::cfloat_tag, npy::cdouble_tag,
- npy::clongdouble_tag, npy::datetime_tag, npy::timedelta_tag>;
+ npy::float_tag, npy::double_tag, npy::longdouble_tag,
+ npy::cfloat_tag, npy::cdouble_tag, npy::clongdouble_tag,
+ npy::datetime_tag, npy::timedelta_tag, npy::half_tag>;
static constexpr std::array<value_type, taglist::size> map =
make_binsearch_map(taglist());
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ef2fc7d2e..3aa2de10e 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2415,23 +2415,32 @@ class TestMethods:
assert_raises(ValueError, d.sort, kind=k)
assert_raises(ValueError, d.argsort, kind=k)
- def test_searchsorted(self):
- # test for floats and complex containing nans. The logic is the
- # same for all float types so only test double types for now.
- # The search sorted routines use the compare functions for the
- # array type, so this checks if that is consistent with the sort
- # order.
-
- # check double
- a = np.array([0, 1, np.nan])
- msg = "Test real searchsorted with nans, side='l'"
+ @pytest.mark.parametrize('a', [
+ np.array([0, 1, np.nan], dtype=np.float16),
+ np.array([0, 1, np.nan], dtype=np.float32),
+ np.array([0, 1, np.nan]),
+ ])
+ def test_searchsorted_floats(self, a):
+ # test for floats arrays containing nans. Explicitly test
+ # half, single, and double precision floats to verify that
+ # the NaN-handling is correct.
+ msg = "Test real (%s) searchsorted with nans, side='l'" % a.dtype
b = a.searchsorted(a, side='left')
assert_equal(b, np.arange(3), msg)
- msg = "Test real searchsorted with nans, side='r'"
+ msg = "Test real (%s) searchsorted with nans, side='r'" % a.dtype
b = a.searchsorted(a, side='right')
assert_equal(b, np.arange(1, 4), msg)
# check keyword arguments
a.searchsorted(v=1)
+ x = np.array([0, 1, np.nan], dtype='float32')
+ y = np.searchsorted(x, x[-1])
+ assert_equal(y, 2)
+
+ def test_searchsorted_complex(self):
+ # test for complex arrays containing nans.
+ # The search sorted routines use the compare functions for the
+ # array type, so this checks if that is consistent with the sort
+ # order.
# check double complex
a = np.zeros(9, dtype=np.complex128)
a.real += [0, 0, 1, 1, 0, 1, np.nan, np.nan, np.nan]
@@ -2450,7 +2459,8 @@ class TestMethods:
a = np.array([0, 128], dtype='>i4')
b = a.searchsorted(np.array(128, dtype='>i4'))
assert_equal(b, 1, msg)
-
+
+ def test_searchsorted_n_elements(self):
# Check 0 elements
a = np.ones(0)
b = a.searchsorted([0, 1, 2], 'left')
@@ -2470,6 +2480,7 @@ class TestMethods:
b = a.searchsorted([0, 1, 2], 'right')
assert_equal(b, [0, 2, 2])
+ def test_searchsorted_unaligned_array(self):
# Test searching unaligned array
a = np.arange(10)
aligned = np.empty(a.itemsize * a.size + 1, 'uint8')
@@ -2486,6 +2497,7 @@ class TestMethods:
b = a.searchsorted(unaligned, 'right')
assert_equal(b, a + 1)
+ def test_searchsorted_resetting(self):
# Test smart resetting of binsearch indices
a = np.arange(5)
b = a.searchsorted([6, 5, 4], 'left')
@@ -2493,6 +2505,7 @@ class TestMethods:
b = a.searchsorted([6, 5, 4], 'right')
assert_equal(b, [5, 5, 5])
+ def test_searchsorted_type_specific(self):
# Test all type specific binary search functions
types = ''.join((np.typecodes['AllInteger'], np.typecodes['AllFloat'],
np.typecodes['Datetime'], '?O'))