summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBago Amirbekian <mrbago@gmail.com>2013-08-02 19:01:40 -0700
committerBago Amirbekian <mrbago@gmail.com>2013-08-17 09:27:06 -0700
commitbb925ed337d0346c414a43506537750a87c4771d (patch)
treee8c755976f5c9b7f1cf9b1803fca27ceb749f36e /numpy
parenta06b8005f042a46c0609a559f6bae21d43572d1e (diff)
downloadnumpy-bb925ed337d0346c414a43506537750a87c4771d.tar.gz
RF - have searchsorted copy haystack when the needle is bigger than the haystack
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/item_selection.c26
1 files changed, 14 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 7fca4fa9b..18eb26f97 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -2095,6 +2095,7 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
PyArrayObject *sorter = NULL;
PyArrayObject *ret = NULL;
PyArray_Descr *dtype;
+ int ap1_flags = NPY_ARRAY_NOTSWAPPED | NPY_ARRAY_ALIGNED;
NPY_BEGIN_THREADS_DEF;
/* Find common type */
@@ -2103,23 +2104,24 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
return NULL;
}
- /* need ap1 as contiguous array and of right type */
- Py_INCREF(dtype);
- ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype,
- 1, 1,
- NPY_ARRAY_ALIGNED | NPY_ARRAY_NOTSWAPPED,
- NULL);
- if (ap1 == NULL) {
- Py_DECREF(dtype);
- return NULL;
- }
-
/* need ap2 as contiguous array and of right type */
+ Py_INCREF(dtype);
ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype,
0, 0,
NPY_ARRAY_CARRAY_RO | NPY_ARRAY_NOTSWAPPED,
NULL);
if (ap2 == NULL) {
+ Py_DECREF(dtype);
+ return NULL;
+ }
+
+ if (PyArray_SIZE(ap2) > PyArray_SIZE(op1)) {
+ ap1_flags |= NPY_ARRAY_CARRAY_RO;
+ }
+
+ ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype,
+ 1, 1, ap1_flags, NULL);
+ if (ap1 == NULL) {
goto fail;
}
/* check that comparison function exists */
@@ -2148,7 +2150,7 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
/* convert to known integer size */
sorter = (PyArrayObject *)PyArray_FromArray(ap3,
PyArray_DescrFromType(NPY_INTP),
- NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED);
+ NPY_ARRAY_ALIGNED | NPY_ARRAY_NOTSWAPPED);
if (sorter == NULL) {
PyErr_SetString(PyExc_ValueError,
"could not parse sorter argument");