diff options
| -rw-r--r-- | numpy/core/setup.py | 4 | ||||
| -rw-r--r-- | numpy/core/src/common/npy_binsearch.h | 31 | ||||
| -rw-r--r-- | numpy/core/src/common/npy_binsearch.h.src | 144 | ||||
| -rw-r--r-- | numpy/core/src/common/npy_sort.h.src | 9 | ||||
| -rw-r--r-- | numpy/core/src/common/numpy_tag.h | 147 | ||||
| -rw-r--r-- | numpy/core/src/npysort/binsearch.c.src | 250 | ||||
| -rw-r--r-- | numpy/core/src/npysort/binsearch.cpp | 387 | ||||
| -rw-r--r-- | numpy/core/src/npysort/npysort_common.h | 18 |
8 files changed, 589 insertions, 401 deletions
diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 17fbd99af..9b6e74240 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -945,8 +945,8 @@ def configuration(parent_package='',top_path=None): join('src', 'npysort', 'radixsort.cpp'), join('src', 'common', 'npy_partition.h.src'), join('src', 'npysort', 'selection.c.src'), - join('src', 'common', 'npy_binsearch.h.src'), - join('src', 'npysort', 'binsearch.c.src'), + join('src', 'common', 'npy_binsearch.h'), + join('src', 'npysort', 'binsearch.cpp'), ] ####################################################################### diff --git a/numpy/core/src/common/npy_binsearch.h b/numpy/core/src/common/npy_binsearch.h new file mode 100644 index 000000000..8d2f0714d --- /dev/null +++ b/numpy/core/src/common/npy_binsearch.h @@ -0,0 +1,31 @@ +#ifndef __NPY_BINSEARCH_H__ +#define __NPY_BINSEARCH_H__ + +#include "npy_sort.h" +#include <numpy/npy_common.h> +#include <numpy/ndarraytypes.h> + + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void (PyArray_BinSearchFunc)(const char*, const char*, char*, + npy_intp, npy_intp, + npy_intp, npy_intp, npy_intp, + PyArrayObject*); + +typedef int (PyArray_ArgBinSearchFunc)(const char*, const char*, + const char*, char*, + npy_intp, npy_intp, npy_intp, + npy_intp, npy_intp, npy_intp, + PyArrayObject*); + +NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side); +NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/numpy/core/src/common/npy_binsearch.h.src b/numpy/core/src/common/npy_binsearch.h.src deleted file mode 100644 index 052c44482..000000000 --- a/numpy/core/src/common/npy_binsearch.h.src +++ /dev/null @@ -1,144 +0,0 @@ -#ifndef __NPY_BINSEARCH_H__ -#define __NPY_BINSEARCH_H__ - -#include "npy_sort.h" -#include <numpy/npy_common.h> -#include <numpy/ndarraytypes.h> - -#define ARRAY_SIZE(a) (sizeof(a)/sizeof(a[0])) - -typedef void (PyArray_BinSearchFunc)(const char*, const char*, char*, - npy_intp, npy_intp, - npy_intp, npy_intp, npy_intp, - PyArrayObject*); - -typedef int (PyArray_ArgBinSearchFunc)(const char*, const char*, - const char*, char*, - npy_intp, npy_intp, npy_intp, - npy_intp, npy_intp, npy_intp, - PyArrayObject*); - -typedef struct { - int typenum; - PyArray_BinSearchFunc *binsearch[NPY_NSEARCHSIDES]; -} binsearch_map; - -typedef struct { - int typenum; - PyArray_ArgBinSearchFunc *argbinsearch[NPY_NSEARCHSIDES]; -} argbinsearch_map; - -/**begin repeat - * - * #side = left, right# - */ - -/**begin repeat1 - * - * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong, - * longlong, ulonglong, half, float, double, longdouble, - * cfloat, cdouble, clongdouble, datetime, timedelta# - */ - -NPY_NO_EXPORT void -binsearch_@side@_@suff@(const char *arr, const char *key, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, npy_intp ret_str, - PyArrayObject *unused); -NPY_NO_EXPORT int -argbinsearch_@side@_@suff@(const char *arr, const char *key, - const char *sort, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, - npy_intp sort_str, npy_intp ret_str, - PyArrayObject *unused); -/**end repeat1**/ - -NPY_NO_EXPORT void -npy_binsearch_@side@(const char *arr, const char *key, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, - npy_intp ret_str, PyArrayObject *cmp); -NPY_NO_EXPORT int -npy_argbinsearch_@side@(const char *arr, const char *key, - const char *sort, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, - npy_intp sort_str, npy_intp ret_str, - PyArrayObject *cmp); -/**end repeat**/ - -/**begin repeat - * - * #arg = , arg# - * #Arg = , Arg# - */ - -static @arg@binsearch_map _@arg@binsearch_map[] = { - /* If adding new types, make sure to keep them ordered by type num */ - /**begin repeat1 - * - * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, - * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE, - * CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA, HALF# - * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong, - * longlong, ulonglong, float, double, longdouble, - * cfloat, cdouble, clongdouble, datetime, timedelta, half# - */ - {NPY_@TYPE@, - { - &@arg@binsearch_left_@suff@, - &@arg@binsearch_right_@suff@, - }, - }, - /**end repeat1**/ -}; - -static PyArray_@Arg@BinSearchFunc *gen@arg@binsearch_map[] = { - &npy_@arg@binsearch_left, - &npy_@arg@binsearch_right, -}; - -static NPY_INLINE PyArray_@Arg@BinSearchFunc* -get_@arg@binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) -{ - npy_intp nfuncs = ARRAY_SIZE(_@arg@binsearch_map); - npy_intp min_idx = 0; - npy_intp max_idx = nfuncs; - int type = dtype->type_num; - - if (side >= NPY_NSEARCHSIDES) { - return NULL; - } - - /* - * It seems only fair that a binary search function be searched for - * using a binary search... - */ - while (min_idx < max_idx) { - npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - - if (_@arg@binsearch_map[mid_idx].typenum < type) { - min_idx = mid_idx + 1; - } - else { - max_idx = mid_idx; - } - } - - if (min_idx < nfuncs && - _@arg@binsearch_map[min_idx].typenum == type) { - return _@arg@binsearch_map[min_idx].@arg@binsearch[side]; - } - - if (dtype->f->compare) { - return gen@arg@binsearch_map[side]; - } - - return NULL; -} -/**end repeat**/ - -#undef ARRAY_SIZE - -#endif diff --git a/numpy/core/src/common/npy_sort.h.src b/numpy/core/src/common/npy_sort.h.src index b4a1e9b0c..a3f556f56 100644 --- a/numpy/core/src/common/npy_sort.h.src +++ b/numpy/core/src/common/npy_sort.h.src @@ -18,6 +18,11 @@ static NPY_INLINE int npy_get_msb(npy_uintp unum) return depth_limit; } +#ifdef __cplusplus +extern "C" { +#endif + + /* ***************************************************************************** @@ -102,4 +107,8 @@ NPY_NO_EXPORT int npy_aheapsort(void *vec, npy_intp *ind, npy_intp cnt, void *ar NPY_NO_EXPORT int npy_amergesort(void *vec, npy_intp *ind, npy_intp cnt, void *arr); NPY_NO_EXPORT int npy_atimsort(void *vec, npy_intp *ind, npy_intp cnt, void *arr); +#ifdef __cplusplus +} +#endif + #endif diff --git a/numpy/core/src/common/numpy_tag.h b/numpy/core/src/common/numpy_tag.h index dc8d5286b..60e9b02cd 100644 --- a/numpy/core/src/common/numpy_tag.h +++ b/numpy/core/src/common/numpy_tag.h @@ -1,8 +1,15 @@ #ifndef _NPY_COMMON_TAG_H_ #define _NPY_COMMON_TAG_H_ +#include "../npysort/npysort_common.h" + namespace npy { +template<typename... tags> +struct taglist { + static constexpr unsigned size = sizeof...(tags); +}; + struct integral_tag { }; struct floating_point_tag { @@ -14,63 +21,203 @@ struct date_tag { struct bool_tag : integral_tag { using type = npy_bool; + static constexpr NPY_TYPES type_value = NPY_BOOL; + static int less(type const& a, type const& b) { + return BOOL_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct byte_tag : integral_tag { using type = npy_byte; + static constexpr NPY_TYPES type_value = NPY_BYTE; + static int less(type const& a, type const& b) { + return BYTE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct ubyte_tag : integral_tag { using type = npy_ubyte; + static constexpr NPY_TYPES type_value = NPY_UBYTE; + static int less(type const& a, type const& b) { + return UBYTE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct short_tag : integral_tag { using type = npy_short; + static constexpr NPY_TYPES type_value = NPY_SHORT; + static int less(type const& a, type const& b) { + return SHORT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct ushort_tag : integral_tag { using type = npy_ushort; + static constexpr NPY_TYPES type_value = NPY_USHORT; + static int less(type const& a, type const& b) { + return USHORT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct int_tag : integral_tag { using type = npy_int; + static constexpr NPY_TYPES type_value = NPY_INT; + static int less(type const& a, type const& b) { + return INT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct uint_tag : integral_tag { using type = npy_uint; + static constexpr NPY_TYPES type_value = NPY_UINT; + static int less(type const& a, type const& b) { + return UINT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct long_tag : integral_tag { using type = npy_long; + static constexpr NPY_TYPES type_value = NPY_LONG; + static int less(type const& a, type const& b) { + return LONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct ulong_tag : integral_tag { using type = npy_ulong; + static constexpr NPY_TYPES type_value = NPY_ULONG; + static int less(type const& a, type const& b) { + return ULONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct longlong_tag : integral_tag { using type = npy_longlong; + static constexpr NPY_TYPES type_value = NPY_LONGLONG; + static int less(type const& a, type const& b) { + return LONGLONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct ulonglong_tag : integral_tag { using type = npy_ulonglong; + static constexpr NPY_TYPES type_value = NPY_ULONGLONG; + static int less(type const& a, type const& b) { + return ULONGLONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct half_tag { using type = npy_half; + static constexpr NPY_TYPES type_value = NPY_HALF; + static int less(type const& a, type const& b) { + return HALF_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct float_tag : floating_point_tag { using type = npy_float; + static constexpr NPY_TYPES type_value = NPY_FLOAT; + static int less(type const& a, type const& b) { + return FLOAT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct double_tag : floating_point_tag { using type = npy_double; + static constexpr NPY_TYPES type_value = NPY_DOUBLE; + static int less(type const& a, type const& b) { + return DOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct longdouble_tag : floating_point_tag { using type = npy_longdouble; + static constexpr NPY_TYPES type_value = NPY_LONGDOUBLE; + static int less(type const& a, type const& b) { + return LONGDOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct cfloat_tag : complex_tag { using type = npy_cfloat; + static constexpr NPY_TYPES type_value = NPY_CFLOAT; + static int less(type const& a, type const& b) { + return CFLOAT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct cdouble_tag : complex_tag { using type = npy_cdouble; + static constexpr NPY_TYPES type_value = NPY_CDOUBLE; + static int less(type const& a, type const& b) { + return CDOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct clongdouble_tag : complex_tag { using type = npy_clongdouble; + static constexpr NPY_TYPES type_value = NPY_CLONGDOUBLE; + static int less(type const& a, type const& b) { + return CLONGDOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct datetime_tag : date_tag { using type = npy_datetime; + static constexpr NPY_TYPES type_value = NPY_DATETIME; + static int less(type const& a, type const& b) { + return DATETIME_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; struct timedelta_tag : date_tag { using type = npy_timedelta; + static constexpr NPY_TYPES type_value = NPY_TIMEDELTA; + static int less(type const& a, type const& b) { + return TIMEDELTA_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } }; } // namespace npy diff --git a/numpy/core/src/npysort/binsearch.c.src b/numpy/core/src/npysort/binsearch.c.src deleted file mode 100644 index 41165897b..000000000 --- a/numpy/core/src/npysort/binsearch.c.src +++ /dev/null @@ -1,250 +0,0 @@ -/* -*- c -*- */ -#define NPY_NO_DEPRECATED_API NPY_API_VERSION - -#include "npy_sort.h" -#include "npysort_common.h" -#include "npy_binsearch.h" - -#define NOT_USED NPY_UNUSED(unused) - -/* - ***************************************************************************** - ** NUMERIC SEARCHES ** - ***************************************************************************** - */ - -/**begin repeat - * - * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, - * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, - * CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA# - * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong, - * longlong, ulonglong, half, float, double, longdouble, - * cfloat, cdouble, clongdouble, datetime, timedelta# - * #type = npy_bool, npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, - * npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong, - * npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat, - * npy_cdouble, npy_clongdouble, npy_datetime, npy_timedelta# - */ - -#define @TYPE@_LTE(a, b) (!@TYPE@_LT((b), (a))) - -/**begin repeat1 - * - * #side = left, right# - * #CMP = LT, LTE# - */ - -NPY_NO_EXPORT void -binsearch_@side@_@suff@(const char *arr, const char *key, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, npy_intp ret_str, - PyArrayObject *NOT_USED) -{ - npy_intp min_idx = 0; - npy_intp max_idx = arr_len; - @type@ last_key_val; - - if (key_len == 0) { - return; - } - last_key_val = *(const @type@ *)key; - - for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { - const @type@ key_val = *(const @type@ *)key; - /* - * Updating only one of the indices based on the previous key - * gives the search a big boost when keys are sorted, but slightly - * slows down things for purely random ones. - */ - if (@TYPE@_LT(last_key_val, key_val)) { - max_idx = arr_len; - } - else { - min_idx = 0; - max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; - } - - last_key_val = key_val; - - while (min_idx < max_idx) { - const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const @type@ mid_val = *(const @type@ *)(arr + mid_idx*arr_str); - if (@TYPE@_@CMP@(mid_val, key_val)) { - min_idx = mid_idx + 1; - } - else { - max_idx = mid_idx; - } - } - *(npy_intp *)ret = min_idx; - } -} - -NPY_NO_EXPORT int -argbinsearch_@side@_@suff@(const char *arr, const char *key, - const char *sort, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, - npy_intp sort_str, npy_intp ret_str, - PyArrayObject *NOT_USED) -{ - npy_intp min_idx = 0; - npy_intp max_idx = arr_len; - @type@ last_key_val; - - if (key_len == 0) { - return 0; - } - last_key_val = *(const @type@ *)key; - - for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { - const @type@ key_val = *(const @type@ *)key; - /* - * Updating only one of the indices based on the previous key - * gives the search a big boost when keys are sorted, but slightly - * slows down things for purely random ones. - */ - if (@TYPE@_LT(last_key_val, key_val)) { - max_idx = arr_len; - } - else { - min_idx = 0; - max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; - } - - last_key_val = key_val; - - while (min_idx < max_idx) { - const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str); - @type@ mid_val; - - if (sort_idx < 0 || sort_idx >= arr_len) { - return -1; - } - - mid_val = *(const @type@ *)(arr + sort_idx*arr_str); - - if (@TYPE@_@CMP@(mid_val, key_val)) { - min_idx = mid_idx + 1; - } - else { - max_idx = mid_idx; - } - } - *(npy_intp *)ret = min_idx; - } - return 0; -} - -/**end repeat1**/ -/**end repeat**/ - -/* - ***************************************************************************** - ** GENERIC SEARCH ** - ***************************************************************************** - */ - - /**begin repeat - * - * #side = left, right# - * #CMP = <, <=# - */ - -NPY_NO_EXPORT void -npy_binsearch_@side@(const char *arr, const char *key, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, npy_intp ret_str, - PyArrayObject *cmp) -{ - PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare; - npy_intp min_idx = 0; - npy_intp max_idx = arr_len; - const char *last_key = key; - - for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { - /* - * Updating only one of the indices based on the previous key - * gives the search a big boost when keys are sorted, but slightly - * slows down things for purely random ones. - */ - if (compare(last_key, key, cmp) @CMP@ 0) { - max_idx = arr_len; - } - else { - min_idx = 0; - max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; - } - - last_key = key; - - while (min_idx < max_idx) { - const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const char *arr_ptr = arr + mid_idx*arr_str; - - if (compare(arr_ptr, key, cmp) @CMP@ 0) { - min_idx = mid_idx + 1; - } - else { - max_idx = mid_idx; - } - } - *(npy_intp *)ret = min_idx; - } -} - -NPY_NO_EXPORT int -npy_argbinsearch_@side@(const char *arr, const char *key, - const char *sort, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, - npy_intp sort_str, npy_intp ret_str, - PyArrayObject *cmp) -{ - PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare; - npy_intp min_idx = 0; - npy_intp max_idx = arr_len; - const char *last_key = key; - - for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { - /* - * Updating only one of the indices based on the previous key - * gives the search a big boost when keys are sorted, but slightly - * slows down things for purely random ones. - */ - if (compare(last_key, key, cmp) @CMP@ 0) { - max_idx = arr_len; - } - else { - min_idx = 0; - max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; - } - - last_key = key; - - while (min_idx < max_idx) { - const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str); - const char *arr_ptr; - - if (sort_idx < 0 || sort_idx >= arr_len) { - return -1; - } - - arr_ptr = arr + sort_idx*arr_str; - - if (compare(arr_ptr, key, cmp) @CMP@ 0) { - min_idx = mid_idx + 1; - } - else { - max_idx = mid_idx; - } - } - *(npy_intp *)ret = min_idx; - } - return 0; -} - -/**end repeat**/ diff --git a/numpy/core/src/npysort/binsearch.cpp b/numpy/core/src/npysort/binsearch.cpp new file mode 100644 index 000000000..cd5f03470 --- /dev/null +++ b/numpy/core/src/npysort/binsearch.cpp @@ -0,0 +1,387 @@ +/* -*- c -*- */ + +#define NPY_NO_DEPRECATED_API NPY_API_VERSION + +#include "npy_sort.h" +#include "numpy_tag.h" +#include <numpy/npy_common.h> +#include <numpy/ndarraytypes.h> + +#include "npy_binsearch.h" + +#include <array> +#include <functional> // for std::less and std::less_equal + +// Enumerators for the variant of binsearch +enum arg_t { noarg, arg}; +enum side_t { left, right}; + +// Mapping from enumerators to comparators +template<class Tag, side_t side> +struct side_to_cmp; +template<class Tag> +struct side_to_cmp<Tag, left> { static constexpr auto value = Tag::less; }; +template<class Tag> +struct side_to_cmp<Tag, right> { static constexpr auto value = Tag::less_equal; }; + +template<side_t side> +struct side_to_generic_cmp; +template<> +struct side_to_generic_cmp<left> { using type = std::less<int>; }; +template<> +struct side_to_generic_cmp<right> { using type = std::less_equal<int>; }; + +/* + ***************************************************************************** + ** NUMERIC SEARCHES ** + ***************************************************************************** + */ +template<class Tag, side_t side> +static void +binsearch(const char *arr, const char *key, char *ret, + npy_intp arr_len, npy_intp key_len, + npy_intp arr_str, npy_intp key_str, npy_intp ret_str, + PyArrayObject*) +{ + using T = typename Tag::type; + auto cmp = side_to_cmp<Tag, side>::value; + npy_intp min_idx = 0; + npy_intp max_idx = arr_len; + T last_key_val; + + if (key_len == 0) { + return; + } + last_key_val = *(const T *)key; + + for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { + const T key_val = *(const T *)key; + /* + * Updating only one of the indices based on the previous key + * gives the search a big boost when keys are sorted, but slightly + * slows down things for purely random ones. + */ + if (cmp(last_key_val, key_val)) { + max_idx = arr_len; + } + else { + min_idx = 0; + max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; + } + + last_key_val = key_val; + + while (min_idx < max_idx) { + const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); + const T mid_val = *(const T *)(arr + mid_idx*arr_str); + if (cmp(mid_val, key_val)) { + min_idx = mid_idx + 1; + } + else { + max_idx = mid_idx; + } + } + *(npy_intp *)ret = min_idx; + } +} + +template<class Tag, side_t side> +static int +argbinsearch(const char *arr, const char *key, + const char *sort, char *ret, + npy_intp arr_len, npy_intp key_len, + npy_intp arr_str, npy_intp key_str, + npy_intp sort_str, npy_intp ret_str, PyArrayObject*) +{ + using T = typename Tag::type; + auto cmp = side_to_cmp<Tag, side>::value; + npy_intp min_idx = 0; + npy_intp max_idx = arr_len; + T last_key_val; + + if (key_len == 0) { + return 0; + } + last_key_val = *(const T *)key; + + for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { + const T key_val = *(const T *)key; + /* + * Updating only one of the indices based on the previous key + * gives the search a big boost when keys are sorted, but slightly + * slows down things for purely random ones. + */ + if (cmp(last_key_val, key_val)) { + max_idx = arr_len; + } + else { + min_idx = 0; + max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; + } + + last_key_val = key_val; + + while (min_idx < max_idx) { + const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); + const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str); + T mid_val; + + if (sort_idx < 0 || sort_idx >= arr_len) { + return -1; + } + + mid_val = *(const T *)(arr + sort_idx*arr_str); + + if (cmp(mid_val, key_val)) { + min_idx = mid_idx + 1; + } + else { + max_idx = mid_idx; + } + } + *(npy_intp *)ret = min_idx; + } + return 0; +} + +/* + ***************************************************************************** + ** GENERIC SEARCH ** + ***************************************************************************** + */ + +template<side_t side> +static void +npy_binsearch(const char *arr, const char *key, char *ret, + npy_intp arr_len, npy_intp key_len, + npy_intp arr_str, npy_intp key_str, npy_intp ret_str, + PyArrayObject *cmp) +{ + using Cmp = typename side_to_generic_cmp<side>::type; + PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare; + npy_intp min_idx = 0; + npy_intp max_idx = arr_len; + const char *last_key = key; + + for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { + /* + * Updating only one of the indices based on the previous key + * gives the search a big boost when keys are sorted, but slightly + * slows down things for purely random ones. + */ + if (Cmp{}(compare(last_key, key, cmp), 0)) { + max_idx = arr_len; + } + else { + min_idx = 0; + max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; + } + + last_key = key; + + while (min_idx < max_idx) { + const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); + const char *arr_ptr = arr + mid_idx*arr_str; + + if (Cmp{}(compare(arr_ptr, key, cmp), 0)) { + min_idx = mid_idx + 1; + } + else { + max_idx = mid_idx; + } + } + *(npy_intp *)ret = min_idx; + } +} + +template<side_t side> +static int +npy_argbinsearch(const char *arr, const char *key, + const char *sort, char *ret, + npy_intp arr_len, npy_intp key_len, + npy_intp arr_str, npy_intp key_str, + npy_intp sort_str, npy_intp ret_str, + PyArrayObject *cmp) +{ + using Cmp = typename side_to_generic_cmp<side>::type; + PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare; + npy_intp min_idx = 0; + npy_intp max_idx = arr_len; + const char *last_key = key; + + for (; key_len > 0; key_len--, key += key_str, ret += ret_str) { + /* + * Updating only one of the indices based on the previous key + * gives the search a big boost when keys are sorted, but slightly + * slows down things for purely random ones. + */ + if (Cmp{}(compare(last_key, key, cmp), 0)) { + max_idx = arr_len; + } + else { + min_idx = 0; + max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len; + } + + last_key = key; + + while (min_idx < max_idx) { + const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); + const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str); + const char *arr_ptr; + + if (sort_idx < 0 || sort_idx >= arr_len) { + return -1; + } + + arr_ptr = arr + sort_idx*arr_str; + + if (Cmp{}(compare(arr_ptr, key, cmp), 0)) { + min_idx = mid_idx + 1; + } + else { + max_idx = mid_idx; + } + } + *(npy_intp *)ret = min_idx; + } + return 0; +} + +/* + ***************************************************************************** + ** GENERATOR ** + ***************************************************************************** + */ + +template<arg_t arg> +struct binsearch_base; + +template<> +struct binsearch_base<arg> { + using function_type = PyArray_ArgBinSearchFunc*; + struct value_type { + int typenum; + function_type binsearch[NPY_NSEARCHSIDES]; + }; + template<class... Tags> + static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) { + return std::array<value_type, sizeof...(Tags)>{ + value_type{ + Tags::type_value, + { + (function_type)&argbinsearch<Tags, left>, + (function_type)argbinsearch<Tags, right> + } + }... + }; + } + static constexpr std::array<function_type, 2> npy_map = { + (function_type)&npy_argbinsearch<left>, + (function_type)&npy_argbinsearch<right> + }; +}; +constexpr std::array<binsearch_base<arg>::function_type, 2> binsearch_base<arg>::npy_map; + +template<> +struct binsearch_base<noarg> { + using function_type = PyArray_BinSearchFunc*; + struct value_type { + int typenum; + function_type binsearch[NPY_NSEARCHSIDES]; + }; + template<class... Tags> + static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) { + return std::array<value_type, sizeof...(Tags)>{ + value_type{ + Tags::type_value, + { + (function_type)&binsearch<Tags, left>, + (function_type)binsearch<Tags, right> + } + }... + }; + } + static constexpr std::array<function_type, 2> npy_map = { + (function_type)&npy_binsearch<left>, + (function_type)&npy_binsearch<right> + }; +}; +constexpr std::array<binsearch_base<noarg>::function_type, 2> binsearch_base<noarg>::npy_map; + +// Handle generation of all binsearch variants +template<arg_t arg> +struct binsearch_t : binsearch_base<arg> { + using binsearch_base<arg>::make_binsearch_map; + using value_type = typename binsearch_base<arg>::value_type; + + using taglist = npy::taglist< + /* If adding new types, make sure to keep them ordered by type num */ + npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag, + npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag, + npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag, npy::half_tag, + npy::float_tag, npy::double_tag, npy::longdouble_tag, npy::cfloat_tag, + npy::cdouble_tag, npy::clongdouble_tag, npy::datetime_tag, + npy::timedelta_tag>; + + static constexpr std::array<value_type, taglist::size> map = make_binsearch_map(taglist()); +}; +template<arg_t arg> +constexpr std::array<typename binsearch_t<arg>::value_type, binsearch_t<arg>::taglist::size> binsearch_t<arg>::map; + + +template<arg_t arg> +static NPY_INLINE typename binsearch_t<arg>::function_type +_get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) +{ + using binsearch = binsearch_t<arg>; + npy_intp nfuncs = binsearch::map.size();; + npy_intp min_idx = 0; + npy_intp max_idx = nfuncs; + int type = dtype->type_num; + + if ((int)side >= (int)NPY_NSEARCHSIDES) { + return NULL; + } + + /* + * It seems only fair that a binary search function be searched for + * using a binary search... + */ + while (min_idx < max_idx) { + npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); + + if (binsearch::map[mid_idx].typenum < type) { + min_idx = mid_idx + 1; + } + else { + max_idx = mid_idx; + } + } + + if (min_idx < nfuncs && + binsearch::map[min_idx].typenum == type) { + return binsearch::map[min_idx].binsearch[side]; + } + + if (dtype->f->compare) { + return binsearch::npy_map[side]; + } + + return NULL; +} + + +/* + ***************************************************************************** + ** C INTERFACE ** + ***************************************************************************** + */ +extern "C" { + NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) { + return _get_binsearch_func<noarg>(dtype, side); + } + NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) { + return _get_binsearch_func<arg>(dtype, side); + } +} diff --git a/numpy/core/src/npysort/npysort_common.h b/numpy/core/src/npysort/npysort_common.h index 2a6e4d421..a537fcd08 100644 --- a/numpy/core/src/npysort/npysort_common.h +++ b/numpy/core/src/npysort/npysort_common.h @@ -4,6 +4,10 @@ #include <stdlib.h> #include <numpy/ndarraytypes.h> +#ifdef __cplusplus +extern "C" { +#endif + /* ***************************************************************************** ** SWAP MACROS ** @@ -139,14 +143,14 @@ LONGDOUBLE_LT(npy_longdouble a, npy_longdouble b) NPY_INLINE static int -npy_half_isnan(npy_half h) +_npy_half_isnan(npy_half h) { return ((h&0x7c00u) == 0x7c00u) && ((h&0x03ffu) != 0x0000u); } NPY_INLINE static int -npy_half_lt_nonan(npy_half h1, npy_half h2) +_npy_half_lt_nonan(npy_half h1, npy_half h2) { if (h1&0x8000u) { if (h2&0x8000u) { @@ -173,11 +177,11 @@ HALF_LT(npy_half a, npy_half b) { int ret; - if (npy_half_isnan(b)) { - ret = !npy_half_isnan(a); + if (_npy_half_isnan(b)) { + ret = !_npy_half_isnan(a); } else { - ret = !npy_half_isnan(a) && npy_half_lt_nonan(a, b); + ret = !_npy_half_isnan(a) && _npy_half_lt_nonan(a, b); } return ret; @@ -373,4 +377,8 @@ GENERIC_SWAP(char *a, char *b, size_t len) } } +#ifdef __cplusplus +} +#endif + #endif |
