diff options
| -rw-r--r-- | numpy/core/src/npysort/binsearch.cpp | 219 | ||||
| -rw-r--r-- | numpy/core/src/npysort/selection.cpp | 36 |
2 files changed, 141 insertions, 114 deletions
diff --git a/numpy/core/src/npysort/binsearch.cpp b/numpy/core/src/npysort/binsearch.cpp index cd5f03470..8733a8d4e 100644 --- a/numpy/core/src/npysort/binsearch.cpp +++ b/numpy/core/src/npysort/binsearch.cpp @@ -2,46 +2,65 @@ #define NPY_NO_DEPRECATED_API NPY_API_VERSION -#include "npy_sort.h" -#include "numpy_tag.h" -#include <numpy/npy_common.h> -#include <numpy/ndarraytypes.h> +#include "numpy/ndarraytypes.h" +#include "numpy/npy_common.h" #include "npy_binsearch.h" +#include "npy_sort.h" +#include "numpy_tag.h" #include <array> #include <functional> // for std::less and std::less_equal // Enumerators for the variant of binsearch -enum arg_t { noarg, arg}; -enum side_t { left, right}; +enum arg_t +{ + noarg, + arg +}; +enum side_t +{ + left, + right +}; // Mapping from enumerators to comparators -template<class Tag, side_t side> +template <class Tag, side_t side> struct side_to_cmp; -template<class Tag> -struct side_to_cmp<Tag, left> { static constexpr auto value = Tag::less; }; -template<class Tag> -struct side_to_cmp<Tag, right> { static constexpr auto value = Tag::less_equal; }; -template<side_t side> +template <class Tag> +struct side_to_cmp<Tag, left> { + static constexpr auto value = Tag::less; +}; + +template <class Tag> +struct side_to_cmp<Tag, right> { + static constexpr auto value = Tag::less_equal; +}; + +template <side_t side> struct side_to_generic_cmp; -template<> -struct side_to_generic_cmp<left> { using type = std::less<int>; }; -template<> -struct side_to_generic_cmp<right> { using type = std::less_equal<int>; }; + +template <> +struct side_to_generic_cmp<left> { + using type = std::less<int>; +}; + +template <> +struct side_to_generic_cmp<right> { + using type = std::less_equal<int>; +}; /* ***************************************************************************** ** NUMERIC SEARCHES ** ***************************************************************************** */ -template<class Tag, side_t side> +template <class Tag, side_t side> static void -binsearch(const char *arr, const char *key, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, npy_intp ret_str, - PyArrayObject*) +binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len, + npy_intp key_len, npy_intp arr_str, npy_intp key_str, + npy_intp ret_str, PyArrayObject *) { using T = typename Tag::type; auto cmp = side_to_cmp<Tag, side>::value; @@ -73,7 +92,7 @@ binsearch(const char *arr, const char *key, char *ret, while (min_idx < max_idx) { const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const T mid_val = *(const T *)(arr + mid_idx*arr_str); + const T mid_val = *(const T *)(arr + mid_idx * arr_str); if (cmp(mid_val, key_val)) { min_idx = mid_idx + 1; } @@ -85,13 +104,12 @@ binsearch(const char *arr, const char *key, char *ret, } } -template<class Tag, side_t side> +template <class Tag, side_t side> static int -argbinsearch(const char *arr, const char *key, - const char *sort, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, - npy_intp sort_str, npy_intp ret_str, PyArrayObject*) +argbinsearch(const char *arr, const char *key, const char *sort, char *ret, + npy_intp arr_len, npy_intp key_len, npy_intp arr_str, + npy_intp key_str, npy_intp sort_str, npy_intp ret_str, + PyArrayObject *) { using T = typename Tag::type; auto cmp = side_to_cmp<Tag, side>::value; @@ -123,14 +141,14 @@ argbinsearch(const char *arr, const char *key, while (min_idx < max_idx) { const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str); + const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx * sort_str); T mid_val; if (sort_idx < 0 || sort_idx >= arr_len) { return -1; } - mid_val = *(const T *)(arr + sort_idx*arr_str); + mid_val = *(const T *)(arr + sort_idx * arr_str); if (cmp(mid_val, key_val)) { min_idx = mid_idx + 1; @@ -150,12 +168,11 @@ argbinsearch(const char *arr, const char *key, ***************************************************************************** */ -template<side_t side> +template <side_t side> static void -npy_binsearch(const char *arr, const char *key, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, npy_intp ret_str, - PyArrayObject *cmp) +npy_binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len, + npy_intp key_len, npy_intp arr_str, npy_intp key_str, + npy_intp ret_str, PyArrayObject *cmp) { using Cmp = typename side_to_generic_cmp<side>::type; PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare; @@ -181,7 +198,7 @@ npy_binsearch(const char *arr, const char *key, char *ret, while (min_idx < max_idx) { const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const char *arr_ptr = arr + mid_idx*arr_str; + const char *arr_ptr = arr + mid_idx * arr_str; if (Cmp{}(compare(arr_ptr, key, cmp), 0)) { min_idx = mid_idx + 1; @@ -194,14 +211,12 @@ npy_binsearch(const char *arr, const char *key, char *ret, } } -template<side_t side> +template <side_t side> static int -npy_argbinsearch(const char *arr, const char *key, - const char *sort, char *ret, - npy_intp arr_len, npy_intp key_len, - npy_intp arr_str, npy_intp key_str, - npy_intp sort_str, npy_intp ret_str, - PyArrayObject *cmp) +npy_argbinsearch(const char *arr, const char *key, const char *sort, char *ret, + npy_intp arr_len, npy_intp key_len, npy_intp arr_str, + npy_intp key_str, npy_intp sort_str, npy_intp ret_str, + PyArrayObject *cmp) { using Cmp = typename side_to_generic_cmp<side>::type; PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare; @@ -227,14 +242,14 @@ npy_argbinsearch(const char *arr, const char *key, while (min_idx < max_idx) { const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1); - const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str); + const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx * sort_str); const char *arr_ptr; if (sort_idx < 0 || sort_idx >= arr_len) { return -1; } - arr_ptr = arr + sort_idx*arr_str; + arr_ptr = arr + sort_idx * arr_str; if (Cmp{}(compare(arr_ptr, key, cmp), 0)) { min_idx = mid_idx + 1; @@ -254,88 +269,87 @@ npy_argbinsearch(const char *arr, const char *key, ***************************************************************************** */ -template<arg_t arg> +template <arg_t arg> struct binsearch_base; -template<> +template <> struct binsearch_base<arg> { - using function_type = PyArray_ArgBinSearchFunc*; + using function_type = PyArray_ArgBinSearchFunc *; struct value_type { int typenum; function_type binsearch[NPY_NSEARCHSIDES]; }; - template<class... Tags> - static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) { + template <class... Tags> + static constexpr std::array<value_type, sizeof...(Tags)> + make_binsearch_map(npy::taglist<Tags...>) + { return std::array<value_type, sizeof...(Tags)>{ - value_type{ - Tags::type_value, - { - (function_type)&argbinsearch<Tags, left>, - (function_type)argbinsearch<Tags, right> - } - }... - }; + value_type{Tags::type_value, + {(function_type)&argbinsearch<Tags, left>, + (function_type)argbinsearch<Tags, right>}}...}; } static constexpr std::array<function_type, 2> npy_map = { - (function_type)&npy_argbinsearch<left>, - (function_type)&npy_argbinsearch<right> - }; + (function_type)&npy_argbinsearch<left>, + (function_type)&npy_argbinsearch<right>}; }; -constexpr std::array<binsearch_base<arg>::function_type, 2> binsearch_base<arg>::npy_map; +constexpr std::array<binsearch_base<arg>::function_type, 2> + binsearch_base<arg>::npy_map; -template<> +template <> struct binsearch_base<noarg> { - using function_type = PyArray_BinSearchFunc*; + using function_type = PyArray_BinSearchFunc *; struct value_type { int typenum; function_type binsearch[NPY_NSEARCHSIDES]; }; - template<class... Tags> - static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) { + template <class... Tags> + static constexpr std::array<value_type, sizeof...(Tags)> + make_binsearch_map(npy::taglist<Tags...>) + { return std::array<value_type, sizeof...(Tags)>{ - value_type{ - Tags::type_value, - { - (function_type)&binsearch<Tags, left>, - (function_type)binsearch<Tags, right> - } - }... - }; + value_type{Tags::type_value, + {(function_type)&binsearch<Tags, left>, + (function_type)binsearch<Tags, right>}}...}; } static constexpr std::array<function_type, 2> npy_map = { - (function_type)&npy_binsearch<left>, - (function_type)&npy_binsearch<right> - }; + (function_type)&npy_binsearch<left>, + (function_type)&npy_binsearch<right>}; }; -constexpr std::array<binsearch_base<noarg>::function_type, 2> binsearch_base<noarg>::npy_map; +constexpr std::array<binsearch_base<noarg>::function_type, 2> + binsearch_base<noarg>::npy_map; // Handle generation of all binsearch variants -template<arg_t arg> +template <arg_t arg> struct binsearch_t : binsearch_base<arg> { using binsearch_base<arg>::make_binsearch_map; using value_type = typename binsearch_base<arg>::value_type; using taglist = npy::taglist< - /* If adding new types, make sure to keep them ordered by type num */ - npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag, - npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag, - npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag, npy::half_tag, - npy::float_tag, npy::double_tag, npy::longdouble_tag, npy::cfloat_tag, - npy::cdouble_tag, npy::clongdouble_tag, npy::datetime_tag, - npy::timedelta_tag>; - - static constexpr std::array<value_type, taglist::size> map = make_binsearch_map(taglist()); + /* If adding new types, make sure to keep them ordered by type num + */ + npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag, + npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag, + npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag, + npy::half_tag, npy::float_tag, npy::double_tag, + npy::longdouble_tag, npy::cfloat_tag, npy::cdouble_tag, + npy::clongdouble_tag, npy::datetime_tag, npy::timedelta_tag>; + + static constexpr std::array<value_type, taglist::size> map = + make_binsearch_map(taglist()); }; -template<arg_t arg> -constexpr std::array<typename binsearch_t<arg>::value_type, binsearch_t<arg>::taglist::size> binsearch_t<arg>::map; +template <arg_t arg> +constexpr std::array<typename binsearch_t<arg>::value_type, + binsearch_t<arg>::taglist::size> + binsearch_t<arg>::map; -template<arg_t arg> +template <arg_t arg> static NPY_INLINE typename binsearch_t<arg>::function_type _get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) { using binsearch = binsearch_t<arg>; - npy_intp nfuncs = binsearch::map.size();; + npy_intp nfuncs = binsearch::map.size(); + ; npy_intp min_idx = 0; npy_intp max_idx = nfuncs; int type = dtype->type_num; @@ -359,8 +373,7 @@ _get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) } } - if (min_idx < nfuncs && - binsearch::map[min_idx].typenum == type) { + if (min_idx < nfuncs && binsearch::map[min_idx].typenum == type) { return binsearch::map[min_idx].binsearch[side]; } @@ -371,17 +384,21 @@ _get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) return NULL; } - /* ***************************************************************************** ** C INTERFACE ** ***************************************************************************** */ extern "C" { - NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) { - return _get_binsearch_func<noarg>(dtype, side); - } - NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) { - return _get_binsearch_func<arg>(dtype, side); - } +NPY_NO_EXPORT PyArray_BinSearchFunc * +get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) +{ + return _get_binsearch_func<noarg>(dtype, side); +} + +NPY_NO_EXPORT PyArray_ArgBinSearchFunc * +get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) +{ + return _get_binsearch_func<arg>(dtype, side); +} } diff --git a/numpy/core/src/npysort/selection.cpp b/numpy/core/src/npysort/selection.cpp index ea73bf5bc..7fd04a660 100644 --- a/numpy/core/src/npysort/selection.cpp +++ b/numpy/core/src/npysort/selection.cpp @@ -111,13 +111,16 @@ median3_swap_(type *v, npy_intp *tosort, npy_intp low, npy_intp mid, Idx<arg> idx(tosort); Sortee<type, arg> sortee(v, tosort); - if (Tag::less(v[idx(high)], v[idx(mid)])) + if (Tag::less(v[idx(high)], v[idx(mid)])) { std::swap(sortee(high), sortee(mid)); - if (Tag::less(v[idx(high)], v[idx(low)])) + } + if (Tag::less(v[idx(high)], v[idx(low)])) { std::swap(sortee(high), sortee(low)); + } /* move pivot to low */ - if (Tag::less(v[idx(low)], v[idx(mid)])) + if (Tag::less(v[idx(low)], v[idx(mid)])) { std::swap(sortee(low), sortee(mid)); + } /* move 3-lowest element to low + 1 */ std::swap(sortee(mid), sortee(low + 1)); } @@ -175,13 +178,16 @@ unguarded_partition_(type *v, npy_intp *tosort, const type pivot, npy_intp *ll, Sortee<type, arg> sortee(v, tosort); for (;;) { - do (*ll)++; - while (Tag::less(v[idx(*ll)], pivot)); - do (*hh)--; - while (Tag::less(pivot, v[idx(*hh)])); - - if (*hh < *ll) + do { + (*ll)++; + } while (Tag::less(v[idx(*ll)], pivot)); + do { + (*hh)--; + } while (Tag::less(pivot, v[idx(*hh)])); + + if (*hh < *ll) { break; + } std::swap(sortee(*ll), sortee(*hh)); } @@ -209,8 +215,9 @@ median_of_median5_(type *v, npy_intp *tosort, const npy_intp num, std::swap(sortee(subleft + m), sortee(i)); } - if (nmed > 2) + if (nmed > 2) { introselect_<Tag, arg>(v, tosort, nmed, nmed / 2, pivots, npiv); + } return nmed / 2; } @@ -268,8 +275,9 @@ introselect_(type *v, npy_intp *tosort, npy_intp num, npy_intp kth, npy_intp high = num - 1; int depth_limit; - if (npiv == NULL) + if (npiv == NULL) { pivots = NULL; + } while (pivots != NULL && *npiv > 0) { if (pivots[*npiv - 1] > kth) { @@ -361,10 +369,12 @@ introselect_(type *v, npy_intp *tosort, npy_intp num, npy_intp kth, store_pivot(hh, kth, pivots, npiv); } - if (hh >= kth) + if (hh >= kth) { high = hh - 1; - if (hh <= kth) + } + if (hh <= kth) { low = ll; + } } /* two elements */ |
