diff options
| -rw-r--r-- | numpy/core/src/npysort/binsearch.cpp | 6 | ||||
| -rw-r--r-- | numpy/core/tests/test_multiarray.py | 37 |
2 files changed, 38 insertions, 5 deletions
diff --git a/numpy/core/src/npysort/binsearch.cpp b/numpy/core/src/npysort/binsearch.cpp index 98d305910..ea07bbb0c 100644 --- a/numpy/core/src/npysort/binsearch.cpp +++ b/numpy/core/src/npysort/binsearch.cpp @@ -330,9 +330,9 @@ struct binsearch_t : binsearch_base<arg> { npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag, npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag, npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag, - npy::half_tag, npy::float_tag, npy::double_tag, - npy::longdouble_tag, npy::cfloat_tag, npy::cdouble_tag, - npy::clongdouble_tag, npy::datetime_tag, npy::timedelta_tag>; + npy::float_tag, npy::double_tag, npy::longdouble_tag, + npy::cfloat_tag, npy::cdouble_tag, npy::clongdouble_tag, + npy::datetime_tag, npy::timedelta_tag, npy::half_tag>; static constexpr std::array<value_type, taglist::size> map = make_binsearch_map(taglist()); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index ef2fc7d2e..bc1621599 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2416,12 +2416,41 @@ class TestMethods: assert_raises(ValueError, d.argsort, kind=k) def test_searchsorted(self): - # test for floats and complex containing nans. The logic is the - # same for all float types so only test double types for now. + # test for floats and complex containing nans. Explicitly test + # half, single, and double precision floats to verify that + # the NaN-handling is correct. # The search sorted routines use the compare functions for the # array type, so this checks if that is consistent with the sort # order. + # check half + a = np.array([0, 1, np.nan], dtype=np.float16) + msg = "Test real searchsorted with nans, side='l'" + b = a.searchsorted(a, side='left') + assert_equal(b, np.arange(3), msg) + msg = "Test real searchsorted with nans, side='r'" + b = a.searchsorted(a, side='right') + assert_equal(b, np.arange(1, 4), msg) + # check keyword arguments + a.searchsorted(v=1) + x = np.array([0, 1, np.nan], dtype='float16') + y = np.searchsorted(x, x[-1]) + assert_equal(y, 2) + + # check single + a = np.array([0, 1, np.nan], dtype=np.float32) + msg = "Test real searchsorted with nans, side='l'" + b = a.searchsorted(a, side='left') + assert_equal(b, np.arange(3), msg) + msg = "Test real searchsorted with nans, side='r'" + b = a.searchsorted(a, side='right') + assert_equal(b, np.arange(1, 4), msg) + # check keyword arguments + a.searchsorted(v=1) + x = np.array([0, 1, np.nan], dtype='float32') + y = np.searchsorted(x, x[-1]) + assert_equal(y, 2) + # check double a = np.array([0, 1, np.nan]) msg = "Test real searchsorted with nans, side='l'" @@ -2432,6 +2461,10 @@ class TestMethods: assert_equal(b, np.arange(1, 4), msg) # check keyword arguments a.searchsorted(v=1) + x = np.array([0, 1, np.nan]) + y = np.searchsorted(x, x[-1]) + assert_equal(y, 2) + # check double complex a = np.zeros(9, dtype=np.complex128) a.real += [0, 0, 1, 1, 0, 1, np.nan, np.nan, np.nan] |
