summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/npysort/binsearch.cpp6
-rw-r--r--numpy/core/tests/test_multiarray.py37
2 files changed, 38 insertions, 5 deletions
diff --git a/numpy/core/src/npysort/binsearch.cpp b/numpy/core/src/npysort/binsearch.cpp
index 98d305910..ea07bbb0c 100644
--- a/numpy/core/src/npysort/binsearch.cpp
+++ b/numpy/core/src/npysort/binsearch.cpp
@@ -330,9 +330,9 @@ struct binsearch_t : binsearch_base<arg> {
npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag,
npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag,
npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag,
- npy::half_tag, npy::float_tag, npy::double_tag,
- npy::longdouble_tag, npy::cfloat_tag, npy::cdouble_tag,
- npy::clongdouble_tag, npy::datetime_tag, npy::timedelta_tag>;
+ npy::float_tag, npy::double_tag, npy::longdouble_tag,
+ npy::cfloat_tag, npy::cdouble_tag, npy::clongdouble_tag,
+ npy::datetime_tag, npy::timedelta_tag, npy::half_tag>;
static constexpr std::array<value_type, taglist::size> map =
make_binsearch_map(taglist());
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index ef2fc7d2e..bc1621599 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2416,12 +2416,41 @@ class TestMethods:
assert_raises(ValueError, d.argsort, kind=k)
def test_searchsorted(self):
- # test for floats and complex containing nans. The logic is the
- # same for all float types so only test double types for now.
+ # test for floats and complex containing nans. Explicitly test
+ # half, single, and double precision floats to verify that
+ # the NaN-handling is correct.
# The search sorted routines use the compare functions for the
# array type, so this checks if that is consistent with the sort
# order.
+ # check half
+ a = np.array([0, 1, np.nan], dtype=np.float16)
+ msg = "Test real searchsorted with nans, side='l'"
+ b = a.searchsorted(a, side='left')
+ assert_equal(b, np.arange(3), msg)
+ msg = "Test real searchsorted with nans, side='r'"
+ b = a.searchsorted(a, side='right')
+ assert_equal(b, np.arange(1, 4), msg)
+ # check keyword arguments
+ a.searchsorted(v=1)
+ x = np.array([0, 1, np.nan], dtype='float16')
+ y = np.searchsorted(x, x[-1])
+ assert_equal(y, 2)
+
+ # check single
+ a = np.array([0, 1, np.nan], dtype=np.float32)
+ msg = "Test real searchsorted with nans, side='l'"
+ b = a.searchsorted(a, side='left')
+ assert_equal(b, np.arange(3), msg)
+ msg = "Test real searchsorted with nans, side='r'"
+ b = a.searchsorted(a, side='right')
+ assert_equal(b, np.arange(1, 4), msg)
+ # check keyword arguments
+ a.searchsorted(v=1)
+ x = np.array([0, 1, np.nan], dtype='float32')
+ y = np.searchsorted(x, x[-1])
+ assert_equal(y, 2)
+
# check double
a = np.array([0, 1, np.nan])
msg = "Test real searchsorted with nans, side='l'"
@@ -2432,6 +2461,10 @@ class TestMethods:
assert_equal(b, np.arange(1, 4), msg)
# check keyword arguments
a.searchsorted(v=1)
+ x = np.array([0, 1, np.nan])
+ y = np.searchsorted(x, x[-1])
+ assert_equal(y, 2)
+
# check double complex
a = np.zeros(9, dtype=np.complex128)
a.real += [0, 0, 1, 1, 0, 1, np.nan, np.nan, np.nan]