diff options
| -rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 48 |
1 files changed, 8 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index e7fbb88cd..78f547de2 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -971,8 +971,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, PyArrayMultiIterObject *mit; int val; - /* Cast arrays to a common type */ - if (PyArray_TYPE(self) != PyArray_DESCR(other)->type_num) { + if (PyArray_TYPE(self) != PyArray_TYPE(other)) { /* * Comparison between Bytes and Unicode is not defined in Py3K; * we follow. @@ -981,53 +980,22 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op, return Py_NotImplemented; } if (PyArray_ISNOTSWAPPED(self) != PyArray_ISNOTSWAPPED(other)) { - PyObject *new; - if (PyArray_TYPE(self) == NPY_STRING && - PyArray_DESCR(other)->type_num == NPY_UNICODE) { - PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(other)); - unicode->elsize = PyArray_DESCR(self)->elsize << 2; - new = PyArray_FromAny((PyObject *)self, unicode, - 0, 0, 0, NULL); - if (new == NULL) { - return NULL; - } - Py_INCREF(other); - self = (PyArrayObject *)new; - } - else if ((PyArray_TYPE(self) == NPY_UNICODE) && - ((PyArray_DESCR(other)->type_num == NPY_STRING) || - (PyArray_ISNOTSWAPPED(self) != PyArray_ISNOTSWAPPED(other)))) { - PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(self)); - - if (PyArray_DESCR(other)->type_num == NPY_STRING) { - unicode->elsize = PyArray_DESCR(other)->elsize << 2; - } - else { - unicode->elsize = PyArray_DESCR(other)->elsize; - } - new = PyArray_FromAny((PyObject *)other, unicode, - 0, 0, 0, NULL); - if (new == NULL) { - return NULL; - } - Py_INCREF(self); - other = (PyArrayObject *)new; - } - else { - PyErr_SetString(PyExc_TypeError, - "invalid string data-types " - "in comparison"); + /* Cast `other` to the same byte order as `self` (both unicode here) */ + PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(self)); + unicode->elsize = PyArray_DESCR(other)->elsize; + PyObject *new = PyArray_FromAny((PyObject *)other, + unicode, 0, 0, 0, NULL); + if (new == NULL) { return NULL; } + other = (PyArrayObject *)new; } else { - Py_INCREF(self); Py_INCREF(other); } /* Broad-cast the arrays to a common shape */ mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other); - Py_DECREF(self); Py_DECREF(other); if (mit == NULL) { return NULL; |
