summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2021-12-13 22:52:37 -0700
committerGitHub <noreply@github.com>2021-12-13 22:52:37 -0700
commit5cd043c774180a135b56afb9f25c8cb6bf1e66f6 (patch)
tree3f0c3c0c1f572fc1276d3ead18c27673a8b9b180
parent43c113dd7aa36ef833315035858ea98a7b4732c5 (diff)
parentb358604b6d0fced16a0162d0d00aa02653f2e4a3 (diff)
downloadnumpy-5cd043c774180a135b56afb9f25c8cb6bf1e66f6.tar.gz
Merge pull request #20489 from serge-sans-paille/feature/road-to-cxx-binsearch
Convert binsearch.c.src to binsearch.cpp
-rw-r--r--numpy/core/setup.py4
-rw-r--r--numpy/core/src/common/npy_binsearch.h31
-rw-r--r--numpy/core/src/common/npy_binsearch.h.src144
-rw-r--r--numpy/core/src/common/npy_sort.h.src9
-rw-r--r--numpy/core/src/common/numpy_tag.h147
-rw-r--r--numpy/core/src/npysort/binsearch.c.src250
-rw-r--r--numpy/core/src/npysort/binsearch.cpp387
-rw-r--r--numpy/core/src/npysort/npysort_common.h18
8 files changed, 589 insertions, 401 deletions
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index 17fbd99af..9b6e74240 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -945,8 +945,8 @@ def configuration(parent_package='',top_path=None):
join('src', 'npysort', 'radixsort.cpp'),
join('src', 'common', 'npy_partition.h.src'),
join('src', 'npysort', 'selection.c.src'),
- join('src', 'common', 'npy_binsearch.h.src'),
- join('src', 'npysort', 'binsearch.c.src'),
+ join('src', 'common', 'npy_binsearch.h'),
+ join('src', 'npysort', 'binsearch.cpp'),
]
#######################################################################
diff --git a/numpy/core/src/common/npy_binsearch.h b/numpy/core/src/common/npy_binsearch.h
new file mode 100644
index 000000000..8d2f0714d
--- /dev/null
+++ b/numpy/core/src/common/npy_binsearch.h
@@ -0,0 +1,31 @@
+#ifndef __NPY_BINSEARCH_H__
+#define __NPY_BINSEARCH_H__
+
+#include "npy_sort.h"
+#include <numpy/npy_common.h>
+#include <numpy/ndarraytypes.h>
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * Signature of a direct binary search routine.
+ *
+ * Argument order, as used by the implementations in binsearch.cpp:
+ * (arr, key, ret, arr_len, key_len, arr_str, key_str, ret_str, array).
+ * arr/key/ret are byte pointers advanced by the matching *_str stride; one
+ * npy_intp result is written to ret for each of the key_len keys.
+ * The trailing PyArrayObject* is only consulted by the generic
+ * (compare-function based) variants.
+ */
+typedef void (PyArray_BinSearchFunc)(const char*, const char*, char*,
+ npy_intp, npy_intp,
+ npy_intp, npy_intp, npy_intp,
+ PyArrayObject*);
+
+/*
+ * Signature of a binary search routine that goes through a sorter array.
+ *
+ * Argument order, as used by the implementations in binsearch.cpp:
+ * (arr, key, sort, ret, arr_len, key_len, arr_str, key_str, sort_str,
+ * ret_str, array). `sort` holds npy_intp indices into arr.
+ * Returns 0 on success and -1 when an index read from `sort` lies outside
+ * [0, arr_len).
+ */
+typedef int (PyArray_ArgBinSearchFunc)(const char*, const char*,
+ const char*, char*,
+ npy_intp, npy_intp, npy_intp,
+ npy_intp, npy_intp, npy_intp,
+ PyArrayObject*);
+
+/*
+ * Look up the search routine for `dtype` and `side`. Both return NULL when
+ * `side` is not a valid search side, or when dtype has neither a type-specific
+ * routine nor a `compare` function to fall back on.
+ */
+NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side);
+NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/numpy/core/src/common/npy_binsearch.h.src b/numpy/core/src/common/npy_binsearch.h.src
deleted file mode 100644
index 052c44482..000000000
--- a/numpy/core/src/common/npy_binsearch.h.src
+++ /dev/null
@@ -1,144 +0,0 @@
-#ifndef __NPY_BINSEARCH_H__
-#define __NPY_BINSEARCH_H__
-
-#include "npy_sort.h"
-#include <numpy/npy_common.h>
-#include <numpy/ndarraytypes.h>
-
-#define ARRAY_SIZE(a) (sizeof(a)/sizeof(a[0]))
-
-typedef void (PyArray_BinSearchFunc)(const char*, const char*, char*,
- npy_intp, npy_intp,
- npy_intp, npy_intp, npy_intp,
- PyArrayObject*);
-
-typedef int (PyArray_ArgBinSearchFunc)(const char*, const char*,
- const char*, char*,
- npy_intp, npy_intp, npy_intp,
- npy_intp, npy_intp, npy_intp,
- PyArrayObject*);
-
-typedef struct {
- int typenum;
- PyArray_BinSearchFunc *binsearch[NPY_NSEARCHSIDES];
-} binsearch_map;
-
-typedef struct {
- int typenum;
- PyArray_ArgBinSearchFunc *argbinsearch[NPY_NSEARCHSIDES];
-} argbinsearch_map;
-
-/**begin repeat
- *
- * #side = left, right#
- */
-
-/**begin repeat1
- *
- * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
- * longlong, ulonglong, half, float, double, longdouble,
- * cfloat, cdouble, clongdouble, datetime, timedelta#
- */
-
-NPY_NO_EXPORT void
-binsearch_@side@_@suff@(const char *arr, const char *key, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
- PyArrayObject *unused);
-NPY_NO_EXPORT int
-argbinsearch_@side@_@suff@(const char *arr, const char *key,
- const char *sort, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str,
- npy_intp sort_str, npy_intp ret_str,
- PyArrayObject *unused);
-/**end repeat1**/
-
-NPY_NO_EXPORT void
-npy_binsearch_@side@(const char *arr, const char *key, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str,
- npy_intp ret_str, PyArrayObject *cmp);
-NPY_NO_EXPORT int
-npy_argbinsearch_@side@(const char *arr, const char *key,
- const char *sort, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str,
- npy_intp sort_str, npy_intp ret_str,
- PyArrayObject *cmp);
-/**end repeat**/
-
-/**begin repeat
- *
- * #arg = , arg#
- * #Arg = , Arg#
- */
-
-static @arg@binsearch_map _@arg@binsearch_map[] = {
- /* If adding new types, make sure to keep them ordered by type num */
- /**begin repeat1
- *
- * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
- * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
- * CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA, HALF#
- * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
- * longlong, ulonglong, float, double, longdouble,
- * cfloat, cdouble, clongdouble, datetime, timedelta, half#
- */
- {NPY_@TYPE@,
- {
- &@arg@binsearch_left_@suff@,
- &@arg@binsearch_right_@suff@,
- },
- },
- /**end repeat1**/
-};
-
-static PyArray_@Arg@BinSearchFunc *gen@arg@binsearch_map[] = {
- &npy_@arg@binsearch_left,
- &npy_@arg@binsearch_right,
-};
-
-static NPY_INLINE PyArray_@Arg@BinSearchFunc*
-get_@arg@binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
-{
- npy_intp nfuncs = ARRAY_SIZE(_@arg@binsearch_map);
- npy_intp min_idx = 0;
- npy_intp max_idx = nfuncs;
- int type = dtype->type_num;
-
- if (side >= NPY_NSEARCHSIDES) {
- return NULL;
- }
-
- /*
- * It seems only fair that a binary search function be searched for
- * using a binary search...
- */
- while (min_idx < max_idx) {
- npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
-
- if (_@arg@binsearch_map[mid_idx].typenum < type) {
- min_idx = mid_idx + 1;
- }
- else {
- max_idx = mid_idx;
- }
- }
-
- if (min_idx < nfuncs &&
- _@arg@binsearch_map[min_idx].typenum == type) {
- return _@arg@binsearch_map[min_idx].@arg@binsearch[side];
- }
-
- if (dtype->f->compare) {
- return gen@arg@binsearch_map[side];
- }
-
- return NULL;
-}
-/**end repeat**/
-
-#undef ARRAY_SIZE
-
-#endif
diff --git a/numpy/core/src/common/npy_sort.h.src b/numpy/core/src/common/npy_sort.h.src
index b4a1e9b0c..a3f556f56 100644
--- a/numpy/core/src/common/npy_sort.h.src
+++ b/numpy/core/src/common/npy_sort.h.src
@@ -18,6 +18,11 @@ static NPY_INLINE int npy_get_msb(npy_uintp unum)
return depth_limit;
}
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
/*
*****************************************************************************
@@ -102,4 +107,8 @@ NPY_NO_EXPORT int npy_aheapsort(void *vec, npy_intp *ind, npy_intp cnt, void *ar
NPY_NO_EXPORT int npy_amergesort(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
NPY_NO_EXPORT int npy_atimsort(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
+#ifdef __cplusplus
+}
+#endif
+
#endif
diff --git a/numpy/core/src/common/numpy_tag.h b/numpy/core/src/common/numpy_tag.h
index dc8d5286b..60e9b02cd 100644
--- a/numpy/core/src/common/numpy_tag.h
+++ b/numpy/core/src/common/numpy_tag.h
@@ -1,8 +1,15 @@
#ifndef _NPY_COMMON_TAG_H_
#define _NPY_COMMON_TAG_H_
+#include "../npysort/npysort_common.h"
+
namespace npy {
+// Compile-time list of tag types; carries no data, only the pack itself and
+// the number of tags in it (`size`).
+template<typename... tags>
+struct taglist {
+ static constexpr unsigned size = sizeof...(tags);
+};
+
struct integral_tag {
};
struct floating_point_tag {
@@ -14,63 +21,203 @@ struct date_tag {
+// Tag for NPY_BOOL (npy_bool): less() wraps BOOL_LT and less_equal(a, b) is
+// derived from it as !less(b, a).
struct bool_tag : integral_tag {
using type = npy_bool;
+ static constexpr NPY_TYPES type_value = NPY_BOOL;
+ static int less(type const& a, type const& b) {
+ return BOOL_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_BYTE (npy_byte): less() wraps BYTE_LT and less_equal(a, b) is
+// derived from it as !less(b, a).
struct byte_tag : integral_tag {
using type = npy_byte;
+ static constexpr NPY_TYPES type_value = NPY_BYTE;
+ static int less(type const& a, type const& b) {
+ return BYTE_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_UBYTE (npy_ubyte): less() wraps UBYTE_LT and less_equal(a, b)
+// is derived from it as !less(b, a).
struct ubyte_tag : integral_tag {
using type = npy_ubyte;
+ static constexpr NPY_TYPES type_value = NPY_UBYTE;
+ static int less(type const& a, type const& b) {
+ return UBYTE_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_SHORT (npy_short): less() wraps SHORT_LT and less_equal(a, b)
+// is derived from it as !less(b, a).
struct short_tag : integral_tag {
using type = npy_short;
+ static constexpr NPY_TYPES type_value = NPY_SHORT;
+ static int less(type const& a, type const& b) {
+ return SHORT_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_USHORT (npy_ushort): less() wraps USHORT_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct ushort_tag : integral_tag {
using type = npy_ushort;
+ static constexpr NPY_TYPES type_value = NPY_USHORT;
+ static int less(type const& a, type const& b) {
+ return USHORT_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_INT (npy_int): less() wraps INT_LT and less_equal(a, b) is
+// derived from it as !less(b, a).
struct int_tag : integral_tag {
using type = npy_int;
+ static constexpr NPY_TYPES type_value = NPY_INT;
+ static int less(type const& a, type const& b) {
+ return INT_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_UINT (npy_uint): less() wraps UINT_LT and less_equal(a, b) is
+// derived from it as !less(b, a).
struct uint_tag : integral_tag {
using type = npy_uint;
+ static constexpr NPY_TYPES type_value = NPY_UINT;
+ static int less(type const& a, type const& b) {
+ return UINT_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_LONG (npy_long): less() wraps LONG_LT and less_equal(a, b) is
+// derived from it as !less(b, a).
struct long_tag : integral_tag {
using type = npy_long;
+ static constexpr NPY_TYPES type_value = NPY_LONG;
+ static int less(type const& a, type const& b) {
+ return LONG_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_ULONG (npy_ulong): less() wraps ULONG_LT and less_equal(a, b)
+// is derived from it as !less(b, a).
struct ulong_tag : integral_tag {
using type = npy_ulong;
+ static constexpr NPY_TYPES type_value = NPY_ULONG;
+ static int less(type const& a, type const& b) {
+ return ULONG_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_LONGLONG (npy_longlong): less() wraps LONGLONG_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct longlong_tag : integral_tag {
using type = npy_longlong;
+ static constexpr NPY_TYPES type_value = NPY_LONGLONG;
+ static int less(type const& a, type const& b) {
+ return LONGLONG_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_ULONGLONG (npy_ulonglong): less() wraps ULONGLONG_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct ulonglong_tag : integral_tag {
using type = npy_ulonglong;
+ static constexpr NPY_TYPES type_value = NPY_ULONGLONG;
+ static int less(type const& a, type const& b) {
+ return ULONGLONG_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_HALF (npy_half): less() wraps HALF_LT and less_equal(a, b) is
+// derived from it as !less(b, a). Unlike the other tags, this one does not
+// derive from a category tag.
struct half_tag {
using type = npy_half;
+ static constexpr NPY_TYPES type_value = NPY_HALF;
+ static int less(type const& a, type const& b) {
+ return HALF_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_FLOAT (npy_float): less() wraps FLOAT_LT and less_equal(a, b)
+// is derived from it as !less(b, a).
struct float_tag : floating_point_tag {
using type = npy_float;
+ static constexpr NPY_TYPES type_value = NPY_FLOAT;
+ static int less(type const& a, type const& b) {
+ return FLOAT_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_DOUBLE (npy_double): less() wraps DOUBLE_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct double_tag : floating_point_tag {
using type = npy_double;
+ static constexpr NPY_TYPES type_value = NPY_DOUBLE;
+ static int less(type const& a, type const& b) {
+ return DOUBLE_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_LONGDOUBLE (npy_longdouble): less() wraps LONGDOUBLE_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct longdouble_tag : floating_point_tag {
using type = npy_longdouble;
+ static constexpr NPY_TYPES type_value = NPY_LONGDOUBLE;
+ static int less(type const& a, type const& b) {
+ return LONGDOUBLE_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_CFLOAT (npy_cfloat): less() wraps CFLOAT_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct cfloat_tag : complex_tag {
using type = npy_cfloat;
+ static constexpr NPY_TYPES type_value = NPY_CFLOAT;
+ static int less(type const& a, type const& b) {
+ return CFLOAT_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_CDOUBLE (npy_cdouble): less() wraps CDOUBLE_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct cdouble_tag : complex_tag {
using type = npy_cdouble;
+ static constexpr NPY_TYPES type_value = NPY_CDOUBLE;
+ static int less(type const& a, type const& b) {
+ return CDOUBLE_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_CLONGDOUBLE (npy_clongdouble): less() wraps CLONGDOUBLE_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct clongdouble_tag : complex_tag {
using type = npy_clongdouble;
+ static constexpr NPY_TYPES type_value = NPY_CLONGDOUBLE;
+ static int less(type const& a, type const& b) {
+ return CLONGDOUBLE_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_DATETIME (npy_datetime): less() wraps DATETIME_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct datetime_tag : date_tag {
using type = npy_datetime;
+ static constexpr NPY_TYPES type_value = NPY_DATETIME;
+ static int less(type const& a, type const& b) {
+ return DATETIME_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
+// Tag for NPY_TIMEDELTA (npy_timedelta): less() wraps TIMEDELTA_LT and
+// less_equal(a, b) is derived from it as !less(b, a).
struct timedelta_tag : date_tag {
using type = npy_timedelta;
+ static constexpr NPY_TYPES type_value = NPY_TIMEDELTA;
+ static int less(type const& a, type const& b) {
+ return TIMEDELTA_LT(a, b);
+ }
+ static int less_equal(type const& a, type const& b) {
+ return !less(b, a);
+ }
};
} // namespace npy
diff --git a/numpy/core/src/npysort/binsearch.c.src b/numpy/core/src/npysort/binsearch.c.src
deleted file mode 100644
index 41165897b..000000000
--- a/numpy/core/src/npysort/binsearch.c.src
+++ /dev/null
@@ -1,250 +0,0 @@
-/* -*- c -*- */
-#define NPY_NO_DEPRECATED_API NPY_API_VERSION
-
-#include "npy_sort.h"
-#include "npysort_common.h"
-#include "npy_binsearch.h"
-
-#define NOT_USED NPY_UNUSED(unused)
-
-/*
- *****************************************************************************
- ** NUMERIC SEARCHES **
- *****************************************************************************
- */
-
-/**begin repeat
- *
- * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
- * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
- * CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA#
- * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
- * longlong, ulonglong, half, float, double, longdouble,
- * cfloat, cdouble, clongdouble, datetime, timedelta#
- * #type = npy_bool, npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int,
- * npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong,
- * npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat,
- * npy_cdouble, npy_clongdouble, npy_datetime, npy_timedelta#
- */
-
-#define @TYPE@_LTE(a, b) (!@TYPE@_LT((b), (a)))
-
-/**begin repeat1
- *
- * #side = left, right#
- * #CMP = LT, LTE#
- */
-
-NPY_NO_EXPORT void
-binsearch_@side@_@suff@(const char *arr, const char *key, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
- PyArrayObject *NOT_USED)
-{
- npy_intp min_idx = 0;
- npy_intp max_idx = arr_len;
- @type@ last_key_val;
-
- if (key_len == 0) {
- return;
- }
- last_key_val = *(const @type@ *)key;
-
- for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
- const @type@ key_val = *(const @type@ *)key;
- /*
- * Updating only one of the indices based on the previous key
- * gives the search a big boost when keys are sorted, but slightly
- * slows down things for purely random ones.
- */
- if (@TYPE@_LT(last_key_val, key_val)) {
- max_idx = arr_len;
- }
- else {
- min_idx = 0;
- max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
- }
-
- last_key_val = key_val;
-
- while (min_idx < max_idx) {
- const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
- const @type@ mid_val = *(const @type@ *)(arr + mid_idx*arr_str);
- if (@TYPE@_@CMP@(mid_val, key_val)) {
- min_idx = mid_idx + 1;
- }
- else {
- max_idx = mid_idx;
- }
- }
- *(npy_intp *)ret = min_idx;
- }
-}
-
-NPY_NO_EXPORT int
-argbinsearch_@side@_@suff@(const char *arr, const char *key,
- const char *sort, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str,
- npy_intp sort_str, npy_intp ret_str,
- PyArrayObject *NOT_USED)
-{
- npy_intp min_idx = 0;
- npy_intp max_idx = arr_len;
- @type@ last_key_val;
-
- if (key_len == 0) {
- return 0;
- }
- last_key_val = *(const @type@ *)key;
-
- for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
- const @type@ key_val = *(const @type@ *)key;
- /*
- * Updating only one of the indices based on the previous key
- * gives the search a big boost when keys are sorted, but slightly
- * slows down things for purely random ones.
- */
- if (@TYPE@_LT(last_key_val, key_val)) {
- max_idx = arr_len;
- }
- else {
- min_idx = 0;
- max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
- }
-
- last_key_val = key_val;
-
- while (min_idx < max_idx) {
- const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
- const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str);
- @type@ mid_val;
-
- if (sort_idx < 0 || sort_idx >= arr_len) {
- return -1;
- }
-
- mid_val = *(const @type@ *)(arr + sort_idx*arr_str);
-
- if (@TYPE@_@CMP@(mid_val, key_val)) {
- min_idx = mid_idx + 1;
- }
- else {
- max_idx = mid_idx;
- }
- }
- *(npy_intp *)ret = min_idx;
- }
- return 0;
-}
-
-/**end repeat1**/
-/**end repeat**/
-
-/*
- *****************************************************************************
- ** GENERIC SEARCH **
- *****************************************************************************
- */
-
- /**begin repeat
- *
- * #side = left, right#
- * #CMP = <, <=#
- */
-
-NPY_NO_EXPORT void
-npy_binsearch_@side@(const char *arr, const char *key, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
- PyArrayObject *cmp)
-{
- PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare;
- npy_intp min_idx = 0;
- npy_intp max_idx = arr_len;
- const char *last_key = key;
-
- for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
- /*
- * Updating only one of the indices based on the previous key
- * gives the search a big boost when keys are sorted, but slightly
- * slows down things for purely random ones.
- */
- if (compare(last_key, key, cmp) @CMP@ 0) {
- max_idx = arr_len;
- }
- else {
- min_idx = 0;
- max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
- }
-
- last_key = key;
-
- while (min_idx < max_idx) {
- const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
- const char *arr_ptr = arr + mid_idx*arr_str;
-
- if (compare(arr_ptr, key, cmp) @CMP@ 0) {
- min_idx = mid_idx + 1;
- }
- else {
- max_idx = mid_idx;
- }
- }
- *(npy_intp *)ret = min_idx;
- }
-}
-
-NPY_NO_EXPORT int
-npy_argbinsearch_@side@(const char *arr, const char *key,
- const char *sort, char *ret,
- npy_intp arr_len, npy_intp key_len,
- npy_intp arr_str, npy_intp key_str,
- npy_intp sort_str, npy_intp ret_str,
- PyArrayObject *cmp)
-{
- PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare;
- npy_intp min_idx = 0;
- npy_intp max_idx = arr_len;
- const char *last_key = key;
-
- for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
- /*
- * Updating only one of the indices based on the previous key
- * gives the search a big boost when keys are sorted, but slightly
- * slows down things for purely random ones.
- */
- if (compare(last_key, key, cmp) @CMP@ 0) {
- max_idx = arr_len;
- }
- else {
- min_idx = 0;
- max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
- }
-
- last_key = key;
-
- while (min_idx < max_idx) {
- const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
- const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str);
- const char *arr_ptr;
-
- if (sort_idx < 0 || sort_idx >= arr_len) {
- return -1;
- }
-
- arr_ptr = arr + sort_idx*arr_str;
-
- if (compare(arr_ptr, key, cmp) @CMP@ 0) {
- min_idx = mid_idx + 1;
- }
- else {
- max_idx = mid_idx;
- }
- }
- *(npy_intp *)ret = min_idx;
- }
- return 0;
-}
-
-/**end repeat**/
diff --git a/numpy/core/src/npysort/binsearch.cpp b/numpy/core/src/npysort/binsearch.cpp
new file mode 100644
index 000000000..cd5f03470
--- /dev/null
+++ b/numpy/core/src/npysort/binsearch.cpp
@@ -0,0 +1,387 @@
+/* -*- c++ -*- */
+
+#define NPY_NO_DEPRECATED_API NPY_API_VERSION
+
+#include "npy_sort.h"
+#include "numpy_tag.h"
+#include <numpy/npy_common.h>
+#include <numpy/ndarraytypes.h>
+
+#include "npy_binsearch.h"
+
+#include <array>
+#include <functional> // for std::less and std::less_equal
+
+// Enumerators for the variant of binsearch
+// arg_t: `noarg` selects the direct search, `arg` the search that goes
+// through a sorter index array. It is the template parameter of
+// binsearch_base, binsearch_t and _get_binsearch_func.
+enum arg_t { noarg, arg};
+// side_t: picks the comparator used by a search (see side_to_cmp and
+// side_to_generic_cmp): `left` -> "less", `right` -> "less or equal".
+enum side_t { left, right};
+
+// Mapping from enumerators to comparators
+// For the typed searches: `value` is the tag's static comparison function,
+// Tag::less for `left` and Tag::less_equal for `right`. The primary template
+// is only declared, so any other side_t value fails to compile.
+template<class Tag, side_t side>
+struct side_to_cmp;
+template<class Tag>
+struct side_to_cmp<Tag, left> { static constexpr auto value = Tag::less; };
+template<class Tag>
+struct side_to_cmp<Tag, right> { static constexpr auto value = Tag::less_equal; };
+
+// For the generic searches: `type` is the functor used to test the int
+// returned by the dtype's compare function against 0, i.e. `< 0` for `left`
+// and `<= 0` for `right`.
+template<side_t side>
+struct side_to_generic_cmp;
+template<>
+struct side_to_generic_cmp<left> { using type = std::less<int>; };
+template<>
+struct side_to_generic_cmp<right> { using type = std::less_equal<int>; };
+
+/*
+ *****************************************************************************
+ ** NUMERIC SEARCHES **
+ *****************************************************************************
+ */
+/*
+ * Typed binary search over `arr` for each of the `key_len` keys.
+ *
+ * arr, key and ret are byte pointers stepped by arr_str, key_str and ret_str;
+ * elements of arr and key are read as Tag::type and one npy_intp is written
+ * to ret per key. The value written is the first index at which
+ * cmp(arr[i], key) is false, where cmp is Tag::less (side == left) or
+ * Tag::less_equal (side == right). Being a binary search, the result is only
+ * meaningful when arr is ordered consistently with that comparator.
+ * The trailing PyArrayObject* argument is unused.
+ */
+template<class Tag, side_t side>
+static void
+binsearch(const char *arr, const char *key, char *ret,
+ npy_intp arr_len, npy_intp key_len,
+ npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
+ PyArrayObject*)
+{
+ using T = typename Tag::type;
+ auto cmp = side_to_cmp<Tag, side>::value;
+ npy_intp min_idx = 0;
+ npy_intp max_idx = arr_len;
+ T last_key_val;
+
+ if (key_len == 0) {
+ return;
+ }
+ /* Seed the "previous key" with the first key so the first iteration has
+ * something to compare against. */
+ last_key_val = *(const T *)key;
+
+ for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
+ const T key_val = *(const T *)key;
+ /*
+ * Updating only one of the indices based on the previous key
+ * gives the search a big boost when keys are sorted, but slightly
+ * slows down things for purely random ones.
+ */
+ if (cmp(last_key_val, key_val)) {
+ /* Keep min_idx (the previous result) as lower bound. */
+ max_idx = arr_len;
+ }
+ else {
+ /* Use the previous result, plus one and clamped to arr_len, as
+ * upper bound and restart the lower bound. */
+ min_idx = 0;
+ max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
+ }
+
+ last_key_val = key_val;
+
+ while (min_idx < max_idx) {
+ const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
+ const T mid_val = *(const T *)(arr + mid_idx*arr_str);
+ if (cmp(mid_val, key_val)) {
+ min_idx = mid_idx + 1;
+ }
+ else {
+ max_idx = mid_idx;
+ }
+ }
+ /* The loop ends with min_idx == max_idx: that is the answer. */
+ *(npy_intp *)ret = min_idx;
+ }
+}
+
+/*
+ * Typed binary search through a sorter: like binsearch<Tag, side>, but the
+ * element probed at position mid_idx is arr[sort[mid_idx]], where `sort` is
+ * read as npy_intp values stepped by sort_str.
+ *
+ * Writes one npy_intp per key to ret (an index into `sort`, not into arr).
+ * Returns 0 on success, or -1 as soon as an index read from `sort` falls
+ * outside [0, arr_len); results already written to ret are left in place.
+ * The trailing PyArrayObject* argument is unused.
+ */
+template<class Tag, side_t side>
+static int
+argbinsearch(const char *arr, const char *key,
+ const char *sort, char *ret,
+ npy_intp arr_len, npy_intp key_len,
+ npy_intp arr_str, npy_intp key_str,
+ npy_intp sort_str, npy_intp ret_str, PyArrayObject*)
+{
+ using T = typename Tag::type;
+ auto cmp = side_to_cmp<Tag, side>::value;
+ npy_intp min_idx = 0;
+ npy_intp max_idx = arr_len;
+ T last_key_val;
+
+ if (key_len == 0) {
+ return 0;
+ }
+ /* Seed the "previous key" with the first key. */
+ last_key_val = *(const T *)key;
+
+ for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
+ const T key_val = *(const T *)key;
+ /*
+ * Updating only one of the indices based on the previous key
+ * gives the search a big boost when keys are sorted, but slightly
+ * slows down things for purely random ones.
+ */
+ if (cmp(last_key_val, key_val)) {
+ /* Keep min_idx (the previous result) as lower bound. */
+ max_idx = arr_len;
+ }
+ else {
+ /* Previous result (+1, clamped) becomes the upper bound. */
+ min_idx = 0;
+ max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
+ }
+
+ last_key_val = key_val;
+
+ while (min_idx < max_idx) {
+ const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
+ const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str);
+ T mid_val;
+
+ /* Reject sorter entries that would index outside arr. */
+ if (sort_idx < 0 || sort_idx >= arr_len) {
+ return -1;
+ }
+
+ mid_val = *(const T *)(arr + sort_idx*arr_str);
+
+ if (cmp(mid_val, key_val)) {
+ min_idx = mid_idx + 1;
+ }
+ else {
+ max_idx = mid_idx;
+ }
+ }
+ *(npy_intp *)ret = min_idx;
+ }
+ return 0;
+}
+
+/*
+ *****************************************************************************
+ ** GENERIC SEARCH **
+ *****************************************************************************
+ */
+
+/*
+ * Generic binary search: same algorithm as the typed binsearch, but elements
+ * are never loaded; they are compared in place through the `compare`
+ * function of the descriptor of `cmp` (PyArray_DESCR(cmp)->f->compare).
+ *
+ * An element counts as "before the key" when compare(elem, key, cmp) is
+ * < 0 (side == left) or <= 0 (side == right). One npy_intp is written to ret
+ * per key. `compare` is dereferenced unconditionally, so it must be non-NULL.
+ */
+template<side_t side>
+static void
+npy_binsearch(const char *arr, const char *key, char *ret,
+ npy_intp arr_len, npy_intp key_len,
+ npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
+ PyArrayObject *cmp)
+{
+ using Cmp = typename side_to_generic_cmp<side>::type;
+ PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare;
+ npy_intp min_idx = 0;
+ npy_intp max_idx = arr_len;
+ /* Pointer to the previous key; starts at the first key itself. */
+ const char *last_key = key;
+
+ for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
+ /*
+ * Updating only one of the indices based on the previous key
+ * gives the search a big boost when keys are sorted, but slightly
+ * slows down things for purely random ones.
+ */
+ if (Cmp{}(compare(last_key, key, cmp), 0)) {
+ /* Keep min_idx (the previous result) as lower bound. */
+ max_idx = arr_len;
+ }
+ else {
+ /* Previous result (+1, clamped) becomes the upper bound. */
+ min_idx = 0;
+ max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
+ }
+
+ last_key = key;
+
+ while (min_idx < max_idx) {
+ const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
+ const char *arr_ptr = arr + mid_idx*arr_str;
+
+ if (Cmp{}(compare(arr_ptr, key, cmp), 0)) {
+ min_idx = mid_idx + 1;
+ }
+ else {
+ max_idx = mid_idx;
+ }
+ }
+ *(npy_intp *)ret = min_idx;
+ }
+}
+
+/*
+ * Generic binary search through a sorter: combines npy_binsearch (elements
+ * compared via PyArray_DESCR(cmp)->f->compare) with the indirection of
+ * argbinsearch (position mid_idx probes arr[sort[mid_idx]]).
+ *
+ * Writes one npy_intp per key to ret. Returns 0 on success, or -1 as soon as
+ * an index read from `sort` falls outside [0, arr_len).
+ */
+template<side_t side>
+static int
+npy_argbinsearch(const char *arr, const char *key,
+ const char *sort, char *ret,
+ npy_intp arr_len, npy_intp key_len,
+ npy_intp arr_str, npy_intp key_str,
+ npy_intp sort_str, npy_intp ret_str,
+ PyArrayObject *cmp)
+{
+ using Cmp = typename side_to_generic_cmp<side>::type;
+ PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare;
+ npy_intp min_idx = 0;
+ npy_intp max_idx = arr_len;
+ /* Pointer to the previous key; starts at the first key itself. */
+ const char *last_key = key;
+
+ for (; key_len > 0; key_len--, key += key_str, ret += ret_str) {
+ /*
+ * Updating only one of the indices based on the previous key
+ * gives the search a big boost when keys are sorted, but slightly
+ * slows down things for purely random ones.
+ */
+ if (Cmp{}(compare(last_key, key, cmp), 0)) {
+ /* Keep min_idx (the previous result) as lower bound. */
+ max_idx = arr_len;
+ }
+ else {
+ /* Previous result (+1, clamped) becomes the upper bound. */
+ min_idx = 0;
+ max_idx = (max_idx < arr_len) ? (max_idx + 1) : arr_len;
+ }
+
+ last_key = key;
+
+ while (min_idx < max_idx) {
+ const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
+ const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str);
+ const char *arr_ptr;
+
+ /* Reject sorter entries that would index outside arr. */
+ if (sort_idx < 0 || sort_idx >= arr_len) {
+ return -1;
+ }
+
+ arr_ptr = arr + sort_idx*arr_str;
+
+ if (Cmp{}(compare(arr_ptr, key, cmp), 0)) {
+ min_idx = mid_idx + 1;
+ }
+ else {
+ max_idx = mid_idx;
+ }
+ }
+ *(npy_intp *)ret = min_idx;
+ }
+ return 0;
+}
+
+/*
+ *****************************************************************************
+ ** GENERATOR **
+ *****************************************************************************
+ */
+
+// Per-variant building blocks for the dispatch tables. The primary template
+// is only declared; the two specializations (arg / noarg) supply the function
+// pointer type, the table entry type and the table constructors.
+template<arg_t arg>
+struct binsearch_base;
+
+// Variant that searches through a sorter array.
+template<>
+struct binsearch_base<arg> {
+ using function_type = PyArray_ArgBinSearchFunc*;
+ // One table entry: a type number and its two routines, stored in the
+ // order { left, right }.
+ struct value_type {
+ int typenum;
+ function_type binsearch[NPY_NSEARCHSIDES];
+ };
+ // Expand a taglist into an array holding one entry per tag, in the same
+ // order as the tags appear in the list.
+ template<class... Tags>
+ static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) {
+ return std::array<value_type, sizeof...(Tags)>{
+ value_type{
+ Tags::type_value,
+ {
+ (function_type)&argbinsearch<Tags, left>,
+ (function_type)argbinsearch<Tags, right>
+ }
+ }...
+ };
+ }
+ // Generic (compare-function based) routines, in the order { left, right }.
+ static constexpr std::array<function_type, 2> npy_map = {
+ (function_type)&npy_argbinsearch<left>,
+ (function_type)&npy_argbinsearch<right>
+ };
+};
+// Namespace-scope definition of the static constexpr data member declared
+// above.
+constexpr std::array<binsearch_base<arg>::function_type, 2> binsearch_base<arg>::npy_map;
+
+// Variant that searches the array directly (no sorter array).
+template<>
+struct binsearch_base<noarg> {
+ using function_type = PyArray_BinSearchFunc*;
+ // One table entry: a type number and its two routines, stored in the
+ // order { left, right }.
+ struct value_type {
+ int typenum;
+ function_type binsearch[NPY_NSEARCHSIDES];
+ };
+ // Expand a taglist into an array holding one entry per tag, in the same
+ // order as the tags appear in the list.
+ template<class... Tags>
+ static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) {
+ return std::array<value_type, sizeof...(Tags)>{
+ value_type{
+ Tags::type_value,
+ {
+ (function_type)&binsearch<Tags, left>,
+ (function_type)binsearch<Tags, right>
+ }
+ }...
+ };
+ }
+ // Generic (compare-function based) routines, in the order { left, right }.
+ static constexpr std::array<function_type, 2> npy_map = {
+ (function_type)&npy_binsearch<left>,
+ (function_type)&npy_binsearch<right>
+ };
+};
+// Namespace-scope definition of the static constexpr data member declared
+// above.
+constexpr std::array<binsearch_base<noarg>::function_type, 2> binsearch_base<noarg>::npy_map;
+
+// Handle generation of all binsearch variants
+//
+// binsearch_t<arg>::map is the dispatch table from NumPy type number to the
+// pair of typed search routines (left, right) for the chosen variant. It is
+// built at compile time from `taglist`, one entry per tag, in list order.
+template<arg_t arg>
+struct binsearch_t : binsearch_base<arg> {
+    using binsearch_base<arg>::make_binsearch_map;
+    using value_type = typename binsearch_base<arg>::value_type;
+
+    using taglist = npy::taglist<
+            /*
+             * If adding new types, make sure to keep them ordered by type num:
+             * _get_binsearch_func binary-searches `map` on typenum, so an
+             * out-of-order entry makes lookups miss and silently fall back to
+             * the slower generic routines.
+             *
+             * NPY_HALF has a larger type number than NPY_TIMEDELTA, which is
+             * why half_tag comes last instead of sitting next to the other
+             * floating point tags.
+             */
+            npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag,
+            npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag,
+            npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag,
+            npy::float_tag, npy::double_tag, npy::longdouble_tag,
+            npy::cfloat_tag, npy::cdouble_tag, npy::clongdouble_tag,
+            npy::datetime_tag, npy::timedelta_tag, npy::half_tag>;
+
+    static constexpr std::array<value_type, taglist::size> map = make_binsearch_map(taglist());
+};
+// Namespace-scope definition of the static constexpr data member declared
+// above.
+template<arg_t arg>
+constexpr std::array<typename binsearch_t<arg>::value_type, binsearch_t<arg>::taglist::size> binsearch_t<arg>::map;
+
+
+/*
+ * Pick the search routine for `dtype` and `side`.
+ *
+ * Lookup order:
+ *   1. the typed routine registered for dtype->type_num in
+ *      binsearch_t<arg>::map (found by binary search, so that table must be
+ *      sorted by typenum);
+ *   2. otherwise the generic routine, if the dtype provides a `compare`
+ *      function;
+ *   3. otherwise NULL.
+ *
+ * NULL is also returned when `side` is not a valid search side.
+ */
+template<arg_t arg>
+static NPY_INLINE typename binsearch_t<arg>::function_type
+_get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
+{
+    using binsearch = binsearch_t<arg>;
+    npy_intp nfuncs = binsearch::map.size();
+    npy_intp min_idx = 0;
+    npy_intp max_idx = nfuncs;
+    int type = dtype->type_num;
+
+    /*
+     * `side` indexes the per-type and generic tables below. The comparison is
+     * done on ints, so a negative value has to be rejected explicitly.
+     */
+    if ((int)side < 0 || (int)side >= (int)NPY_NSEARCHSIDES) {
+        return NULL;
+    }
+
+    /*
+     * It seems only fair that a binary search function be searched for
+     * using a binary search...
+     */
+    while (min_idx < max_idx) {
+        npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
+
+        if (binsearch::map[mid_idx].typenum < type) {
+            min_idx = mid_idx + 1;
+        }
+        else {
+            max_idx = mid_idx;
+        }
+    }
+
+    /* min_idx is now the first entry whose typenum is >= type, if any. */
+    if (min_idx < nfuncs &&
+        binsearch::map[min_idx].typenum == type) {
+        return binsearch::map[min_idx].binsearch[side];
+    }
+
+    /* No typed routine: fall back to the compare-function based one. */
+    if (dtype->f->compare) {
+        return binsearch::npy_map[side];
+    }
+
+    return NULL;
+}
+
+
+/*
+ *****************************************************************************
+ ** C INTERFACE **
+ *****************************************************************************
+ */
+/*
+ * C-linkage entry points declared in npy_binsearch.h. Each one forwards to
+ * the matching instantiation of _get_binsearch_func: `noarg` for the direct
+ * search, `arg` for the search through a sorter array.
+ */
+extern "C" {
+ NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) {
+ return _get_binsearch_func<noarg>(dtype, side);
+ }
+ NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) {
+ return _get_binsearch_func<arg>(dtype, side);
+ }
+}
diff --git a/numpy/core/src/npysort/npysort_common.h b/numpy/core/src/npysort/npysort_common.h
index 2a6e4d421..a537fcd08 100644
--- a/numpy/core/src/npysort/npysort_common.h
+++ b/numpy/core/src/npysort/npysort_common.h
@@ -4,6 +4,10 @@
#include <stdlib.h>
#include <numpy/ndarraytypes.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
/*
*****************************************************************************
** SWAP MACROS **
@@ -139,14 +143,14 @@ LONGDOUBLE_LT(npy_longdouble a, npy_longdouble b)
NPY_INLINE static int
-npy_half_isnan(npy_half h)
+_npy_half_isnan(npy_half h)
{
return ((h&0x7c00u) == 0x7c00u) && ((h&0x03ffu) != 0x0000u);
}
NPY_INLINE static int
-npy_half_lt_nonan(npy_half h1, npy_half h2)
+_npy_half_lt_nonan(npy_half h1, npy_half h2)
{
if (h1&0x8000u) {
if (h2&0x8000u) {
@@ -173,11 +177,11 @@ HALF_LT(npy_half a, npy_half b)
{
int ret;
- if (npy_half_isnan(b)) {
- ret = !npy_half_isnan(a);
+ if (_npy_half_isnan(b)) {
+ ret = !_npy_half_isnan(a);
}
else {
- ret = !npy_half_isnan(a) && npy_half_lt_nonan(a, b);
+ ret = !_npy_half_isnan(a) && _npy_half_lt_nonan(a, b);
}
return ret;
@@ -373,4 +377,8 @@ GENERIC_SWAP(char *a, char *b, size_t len)
}
}
+#ifdef __cplusplus
+}
+#endif
+
#endif