summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--benchmarks/benchmarks/bench_strings.py45
-rw-r--r--numpy/core/include/numpy/experimental_dtype_api.h2
-rw-r--r--numpy/core/setup.py1
-rw-r--r--numpy/core/src/common/numpyos.h8
-rw-r--r--numpy/core/src/multiarray/array_method.h9
-rw-r--r--numpy/core/src/multiarray/arrayobject.c402
-rw-r--r--numpy/core/src/multiarray/common_dtype.h8
-rw-r--r--numpy/core/src/multiarray/convert_datatype.h12
-rw-r--r--numpy/core/src/multiarray/dtypemeta.h7
-rw-r--r--numpy/core/src/multiarray/experimental_public_dtype_api.c32
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c12
-rw-r--r--numpy/core/src/umath/dispatching.c32
-rw-r--r--numpy/core/src/umath/dispatching.h9
-rw-r--r--numpy/core/src/umath/string_ufuncs.cpp449
-rw-r--r--numpy/core/src/umath/string_ufuncs.h19
-rw-r--r--numpy/core/src/umath/umathmodule.c7
-rw-r--r--numpy/core/tests/test_deprecations.py2
-rw-r--r--numpy/core/tests/test_strings.py85
-rw-r--r--numpy/core/tests/test_unicode.py9
19 files changed, 716 insertions, 434 deletions
diff --git a/benchmarks/benchmarks/bench_strings.py b/benchmarks/benchmarks/bench_strings.py
new file mode 100644
index 000000000..e500d7f3f
--- /dev/null
+++ b/benchmarks/benchmarks/bench_strings.py
@@ -0,0 +1,45 @@
+from __future__ import absolute_import, division, print_function
+
+from .common import Benchmark
+
+import numpy as np
+import operator
+
+
+# Maps the benchmark's `operator` parameter string to the function from the
+# `operator` module that is applied to the two arrays being compared.
+_OPERATORS = {
+ '==': operator.eq,
+ '!=': operator.ne,
+ '<': operator.lt,
+ '<=': operator.le,
+ '>': operator.gt,
+ '>=': operator.ge,
+}
+
+
+class StringComparisons(Benchmark):
+ # Basic string comparison speed tests
+ params = [
+ [100, 10000, (1000, 20)],
+ ['U', 'S'],
+ [True, False],
+ ['==', '!=', '<', '<=', '>', '>=']]
+ param_names = ['shape', 'dtype', 'contig', 'operator']
+ int64 = np.dtype(np.int64)
+
+ def setup(self, shape, dtype, contig, operator):
+ self.arr = np.arange(np.prod(shape)).astype(dtype).reshape(shape)
+ self.arr_identical = self.arr.copy()
+ self.arr_different = self.arr[::-1].copy()
+
+ if not contig:
+ self.arr = self.arr[..., ::2]
+ self.arr_identical = self.arr_identical[..., ::2]
+ self.arr_different = self.arr_different[..., ::2]
+
+ self.operator = _OPERATORS[operator]
+
+ def time_compare_identical(self, shape, dtype, contig, operator):
+ self.operator(self.arr, self.arr_identical)
+
+ def time_compare_different(self, shape, dtype, contig, operator):
+ self.operator(self.arr, self.arr_different)
diff --git a/numpy/core/include/numpy/experimental_dtype_api.h b/numpy/core/include/numpy/experimental_dtype_api.h
index 1dd6215e6..23e9a8d21 100644
--- a/numpy/core/include/numpy/experimental_dtype_api.h
+++ b/numpy/core/include/numpy/experimental_dtype_api.h
@@ -214,7 +214,7 @@ typedef struct {
} PyArrayMethod_Spec;
-typedef PyObject *_ufunc_addloop_fromspec_func(
+typedef int _ufunc_addloop_fromspec_func(
PyObject *ufunc, PyArrayMethod_Spec *spec);
/*
* The main ufunc registration function. This adds a new implementation/loop
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index 7d072c15c..616a1b37f 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -1081,6 +1081,7 @@ def configuration(parent_package='',top_path=None):
join('src', 'umath', 'scalarmath.c.src'),
join('src', 'umath', 'ufunc_type_resolution.c'),
join('src', 'umath', 'override.c'),
+ join('src', 'umath', 'string_ufuncs.cpp'),
# For testing. Eventually, should use public API and be separate:
join('src', 'umath', '_scaled_float_dtype.c'),
]
diff --git a/numpy/core/src/common/numpyos.h b/numpy/core/src/common/numpyos.h
index ce49cbea7..6e526af17 100644
--- a/numpy/core/src/common/numpyos.h
+++ b/numpy/core/src/common/numpyos.h
@@ -1,6 +1,10 @@
#ifndef NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_
#define NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_
+#ifdef __cplusplus
+extern "C" {
+#endif
+
NPY_NO_EXPORT char*
NumPyOS_ascii_formatd(char *buffer, size_t buf_size,
const char *format,
@@ -39,4 +43,8 @@ NumPyOS_strtoll(const char *str, char **endptr, int base);
NPY_NO_EXPORT npy_ulonglong
NumPyOS_strtoull(const char *str, char **endptr, int base);
+#ifdef __cplusplus
+}
+#endif
+
#endif /* NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_ */
diff --git a/numpy/core/src/multiarray/array_method.h b/numpy/core/src/multiarray/array_method.h
index 30dd94a80..6e6a026bc 100644
--- a/numpy/core/src/multiarray/array_method.h
+++ b/numpy/core/src/multiarray/array_method.h
@@ -7,6 +7,9 @@
#include <Python.h>
#include <numpy/ndarraytypes.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
typedef enum {
/* Flag for whether the GIL is required */
@@ -249,6 +252,10 @@ PyArrayMethod_FromSpec(PyArrayMethod_Spec *spec);
* need better tests when a public version is exposed.
*/
NPY_NO_EXPORT PyBoundArrayMethodObject *
-PyArrayMethod_FromSpec_int(PyArrayMethod_Spec *spec, int private);
+PyArrayMethod_FromSpec_int(PyArrayMethod_Spec *spec, int priv);
+
+#ifdef __cplusplus
+}
+#endif
#endif /* NUMPY_CORE_SRC_MULTIARRAY_ARRAY_METHOD_H_ */
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index a1f0e2d5b..b164112ad 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -641,375 +641,11 @@ PyArray_FailUnlessWriteable(PyArrayObject *obj, const char *name)
return 0;
}
-/* This also handles possibly mis-aligned data */
-/* Compare s1 and s2 which are not necessarily NULL-terminated.
- s1 is of length len1
- s2 is of length len2
- If they are NULL terminated, then stop comparison.
-*/
-static int
-_myunincmp(npy_ucs4 const *s1, npy_ucs4 const *s2, int len1, int len2)
-{
- npy_ucs4 const *sptr;
- npy_ucs4 *s1t = NULL;
- npy_ucs4 *s2t = NULL;
- int val;
- npy_intp size;
- int diff;
-
- /* Replace `s1` and `s2` with aligned copies if needed */
- if ((npy_intp)s1 % sizeof(npy_ucs4) != 0) {
- size = len1*sizeof(npy_ucs4);
- s1t = malloc(size);
- memcpy(s1t, s1, size);
- s1 = s1t;
- }
- if ((npy_intp)s2 % sizeof(npy_ucs4) != 0) {
- size = len2*sizeof(npy_ucs4);
- s2t = malloc(size);
- memcpy(s2t, s2, size);
- s2 = s1t;
- }
-
- val = PyArray_CompareUCS4(s1, s2, PyArray_MIN(len1,len2));
- if ((val != 0) || (len1 == len2)) {
- goto finish;
- }
- if (len2 > len1) {
- sptr = s2+len1;
- val = -1;
- diff = len2-len1;
- }
- else {
- sptr = s1+len2;
- val = 1;
- diff=len1-len2;
- }
- while (diff--) {
- if (*sptr != 0) {
- goto finish;
- }
- sptr++;
- }
- val = 0;
-
- finish:
- /* Cleanup the aligned copies */
- if (s1t) {
- free(s1t);
- }
- if (s2t) {
- free(s2t);
- }
- return val;
-}
-
-
-
-
-/*
- * Compare s1 and s2 which are not necessarily NULL-terminated.
- * s1 is of length len1
- * s2 is of length len2
- * If they are NULL terminated, then stop comparison.
- */
-static int
-_mystrncmp(char const *s1, char const *s2, int len1, int len2)
-{
- char const *sptr;
- int val;
- int diff;
-
- val = memcmp(s1, s2, PyArray_MIN(len1, len2));
- if ((val != 0) || (len1 == len2)) {
- return val;
- }
- if (len2 > len1) {
- sptr = s2 + len1;
- val = -1;
- diff = len2 - len1;
- }
- else {
- sptr = s1 + len2;
- val = 1;
- diff = len1 - len2;
- }
- while (diff--) {
- if (*sptr != 0) {
- return val;
- }
- sptr++;
- }
- return 0; /* Only happens if NULLs are everywhere */
-}
-
-/* Borrowed from Numarray */
-
-#define SMALL_STRING 2048
-
-static void _rstripw(char *s, int n)
-{
- int i;
- for (i = n - 1; i >= 1; i--) { /* Never strip to length 0. */
- int c = s[i];
-
- if (!c || NumPyOS_ascii_isspace((int)c)) {
- s[i] = 0;
- }
- else {
- break;
- }
- }
-}
-
-static void _unistripw(npy_ucs4 *s, int n)
-{
- int i;
- for (i = n - 1; i >= 1; i--) { /* Never strip to length 0. */
- npy_ucs4 c = s[i];
- if (!c || NumPyOS_ascii_isspace((int)c)) {
- s[i] = 0;
- }
- else {
- break;
- }
- }
-}
-
-
-static char *
-_char_copy_n_strip(char const *original, char *temp, int nc)
-{
- if (nc > SMALL_STRING) {
- temp = malloc(nc);
- if (!temp) {
- PyErr_NoMemory();
- return NULL;
- }
- }
- memcpy(temp, original, nc);
- _rstripw(temp, nc);
- return temp;
-}
-
-static void
-_char_release(char *ptr, int nc)
-{
- if (nc > SMALL_STRING) {
- free(ptr);
- }
-}
-
-static char *
-_uni_copy_n_strip(char const *original, char *temp, int nc)
-{
- if (nc*sizeof(npy_ucs4) > SMALL_STRING) {
- temp = malloc(nc*sizeof(npy_ucs4));
- if (!temp) {
- PyErr_NoMemory();
- return NULL;
- }
- }
- memcpy(temp, original, nc*sizeof(npy_ucs4));
- _unistripw((npy_ucs4 *)temp, nc);
- return temp;
-}
-
-static void
-_uni_release(char *ptr, int nc)
-{
- if (nc*sizeof(npy_ucs4) > SMALL_STRING) {
- free(ptr);
- }
-}
-
-
-/* End borrowed from numarray */
-
-#define _rstrip_loop(CMP) { \
- void *aptr, *bptr; \
- char atemp[SMALL_STRING], btemp[SMALL_STRING]; \
- while(size--) { \
- aptr = stripfunc(iself->dataptr, atemp, N1); \
- if (!aptr) return -1; \
- bptr = stripfunc(iother->dataptr, btemp, N2); \
- if (!bptr) { \
- relfunc(aptr, N1); \
- return -1; \
- } \
- val = compfunc(aptr, bptr, N1, N2); \
- *dptr = (val CMP 0); \
- PyArray_ITER_NEXT(iself); \
- PyArray_ITER_NEXT(iother); \
- dptr += 1; \
- relfunc(aptr, N1); \
- relfunc(bptr, N2); \
- } \
- }
-
-#define _reg_loop(CMP) { \
- while(size--) { \
- val = compfunc((void *)iself->dataptr, \
- (void *)iother->dataptr, \
- N1, N2); \
- *dptr = (val CMP 0); \
- PyArray_ITER_NEXT(iself); \
- PyArray_ITER_NEXT(iother); \
- dptr += 1; \
- } \
- }
-
-static int
-_compare_strings(PyArrayObject *result, PyArrayMultiIterObject *multi,
- int cmp_op, void *func, int rstrip)
-{
- PyArrayIterObject *iself, *iother;
- npy_bool *dptr;
- npy_intp size;
- int val;
- int N1, N2;
- int (*compfunc)(void *, void *, int, int);
- void (*relfunc)(char *, int);
- char* (*stripfunc)(char const *, char *, int);
-
- compfunc = func;
- dptr = (npy_bool *)PyArray_DATA(result);
- iself = multi->iters[0];
- iother = multi->iters[1];
- size = multi->size;
- N1 = PyArray_DESCR(iself->ao)->elsize;
- N2 = PyArray_DESCR(iother->ao)->elsize;
- if ((void *)compfunc == (void *)_myunincmp) {
- N1 >>= 2;
- N2 >>= 2;
- stripfunc = _uni_copy_n_strip;
- relfunc = _uni_release;
- }
- else {
- stripfunc = _char_copy_n_strip;
- relfunc = _char_release;
- }
- switch (cmp_op) {
- case Py_EQ:
- if (rstrip) {
- _rstrip_loop(==);
- } else {
- _reg_loop(==);
- }
- break;
- case Py_NE:
- if (rstrip) {
- _rstrip_loop(!=);
- } else {
- _reg_loop(!=);
- }
- break;
- case Py_LT:
- if (rstrip) {
- _rstrip_loop(<);
- } else {
- _reg_loop(<);
- }
- break;
- case Py_LE:
- if (rstrip) {
- _rstrip_loop(<=);
- } else {
- _reg_loop(<=);
- }
- break;
- case Py_GT:
- if (rstrip) {
- _rstrip_loop(>);
- } else {
- _reg_loop(>);
- }
- break;
- case Py_GE:
- if (rstrip) {
- _rstrip_loop(>=);
- } else {
- _reg_loop(>=);
- }
- break;
- default:
- PyErr_SetString(PyExc_RuntimeError, "bad comparison operator");
- return -1;
- }
- return 0;
-}
-
-#undef _reg_loop
-#undef _rstrip_loop
-#undef SMALL_STRING
+/* From umath/string_ufuncs.cpp/h */
NPY_NO_EXPORT PyObject *
-_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
- int rstrip)
-{
- PyArrayObject *result;
- PyArrayMultiIterObject *mit;
- int val;
-
- if (PyArray_TYPE(self) != PyArray_TYPE(other)) {
- /*
- * Comparison between Bytes and Unicode is not defined in Py3K;
- * we follow.
- */
- Py_INCREF(Py_NotImplemented);
- return Py_NotImplemented;
- }
- if (PyArray_ISNOTSWAPPED(self) != PyArray_ISNOTSWAPPED(other)) {
- /* Cast `other` to the same byte order as `self` (both unicode here) */
- PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(self));
- if (unicode == NULL) {
- return NULL;
- }
- unicode->elsize = PyArray_DESCR(other)->elsize;
- PyObject *new = PyArray_FromAny((PyObject *)other,
- unicode, 0, 0, 0, NULL);
- if (new == NULL) {
- return NULL;
- }
- other = (PyArrayObject *)new;
- }
- else {
- Py_INCREF(other);
- }
-
- /* Broad-cast the arrays to a common shape */
- mit = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, self, other);
- Py_DECREF(other);
- if (mit == NULL) {
- return NULL;
- }
-
- result = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
- PyArray_DescrFromType(NPY_BOOL),
- mit->nd,
- mit->dimensions,
- NULL, NULL, 0,
- NULL);
- if (result == NULL) {
- goto finish;
- }
-
- if (PyArray_TYPE(self) == NPY_UNICODE) {
- val = _compare_strings(result, mit, cmp_op, _myunincmp, rstrip);
- }
- else {
- val = _compare_strings(result, mit, cmp_op, _mystrncmp, rstrip);
- }
-
- if (val < 0) {
- Py_DECREF(result);
- result = NULL;
- }
-
- finish:
- Py_DECREF(mit);
- return (PyObject *)result;
-}
+_umath_strings_richcompare(
+ PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip);
/*
* VOID-type arrays can only be compared equal and not-equal
@@ -1207,7 +843,7 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
return NULL;
}
/* compare as a string. Assumes self and other have same descr->type */
- return _strings_richcompare(self, other, cmp_op, 0);
+ return _umath_strings_richcompare(self, other, cmp_op, 0);
}
}
@@ -1341,36 +977,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
PyObject *obj_self = (PyObject *)self;
PyObject *result = NULL;
- /* Special case for string arrays (which don't and currently can't have
- * ufunc loops defined, so there's no point in trying).
- */
- if (PyArray_ISSTRING(self)) {
- array_other = (PyArrayObject *)PyArray_FromObject(other,
- NPY_NOTYPE, 0, 0);
- if (array_other == NULL) {
- PyErr_Clear();
- /* Never mind, carry on, see what happens */
- }
- else if (!PyArray_ISSTRING(array_other)) {
- Py_DECREF(array_other);
- /* Never mind, carry on, see what happens */
- }
- else {
- result = _strings_richcompare(self, array_other, cmp_op, 0);
- Py_DECREF(array_other);
- return result;
- }
- /* If we reach this point, it means that we are not comparing
- * string-to-string. It's possible that this will still work out,
- * e.g. if the other array is an object array, then both will be cast
- * to object or something? I don't know how that works actually, but
- * it does, b/c this works:
- * l = ["a", "b"]
- * assert np.array(l, dtype="S1") == np.array(l, dtype="O")
- * So we fall through and see what happens.
- */
- }
-
switch (cmp_op) {
case Py_LT:
RICHCMP_GIVE_UP_IF_NEEDED(obj_self, other);
diff --git a/numpy/core/src/multiarray/common_dtype.h b/numpy/core/src/multiarray/common_dtype.h
index 13d38ddf8..9f25fc14e 100644
--- a/numpy/core/src/multiarray/common_dtype.h
+++ b/numpy/core/src/multiarray/common_dtype.h
@@ -7,6 +7,10 @@
#include <numpy/ndarraytypes.h>
#include "dtypemeta.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+
NPY_NO_EXPORT PyArray_DTypeMeta *
PyArray_CommonDType(PyArray_DTypeMeta *dtype1, PyArray_DTypeMeta *dtype2);
@@ -14,4 +18,8 @@ NPY_NO_EXPORT PyArray_DTypeMeta *
PyArray_PromoteDTypeSequence(
npy_intp length, PyArray_DTypeMeta **dtypes_in);
+#ifdef __cplusplus
+}
+#endif
+
#endif /* NUMPY_CORE_SRC_MULTIARRAY_COMMON_DTYPE_H_ */
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index d1865d1c2..af6d790cf 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -3,6 +3,10 @@
#include "array_method.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+
extern NPY_NO_EXPORT npy_intp REQUIRED_STR_LEN[];
NPY_NO_EXPORT PyObject *
@@ -34,7 +38,7 @@ dtype_kind_to_ordering(char kind);
/* Used by PyArray_CanCastArrayTo and in the legacy ufunc type resolution */
NPY_NO_EXPORT npy_bool
can_cast_scalar_to(PyArray_Descr *scal_type, char *scal_data,
- PyArray_Descr *to, NPY_CASTING casting);
+ PyArray_Descr *to, NPY_CASTING casting);
NPY_NO_EXPORT int
should_use_min_scalar(npy_intp narrs, PyArrayObject **arr,
@@ -59,7 +63,7 @@ NPY_NO_EXPORT int
PyArray_AddCastingImplementation(PyBoundArrayMethodObject *meth);
NPY_NO_EXPORT int
-PyArray_AddCastingImplementation_FromSpec(PyArrayMethod_Spec *spec, int private);
+PyArray_AddCastingImplementation_FromSpec(PyArrayMethod_Spec *spec, int private_);
NPY_NO_EXPORT NPY_CASTING
PyArray_MinCastSafety(NPY_CASTING casting1, NPY_CASTING casting2);
@@ -99,4 +103,8 @@ simple_cast_resolve_descriptors(
NPY_NO_EXPORT int
PyArray_InitializeCasts(void);
+#ifdef __cplusplus
+}
+#endif
+
#endif /* NUMPY_CORE_SRC_MULTIARRAY_CONVERT_DATATYPE_H_ */
diff --git a/numpy/core/src/multiarray/dtypemeta.h b/numpy/core/src/multiarray/dtypemeta.h
index e7d5505d8..618491c98 100644
--- a/numpy/core/src/multiarray/dtypemeta.h
+++ b/numpy/core/src/multiarray/dtypemeta.h
@@ -1,6 +1,9 @@
#ifndef NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_
#define NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_
+#ifdef __cplusplus
+extern "C" {
+#endif
/* DType flags, currently private, since we may just expose functions */
#define NPY_DT_LEGACY 1 << 0
@@ -126,4 +129,8 @@ python_builtins_are_known_scalar_types(
NPY_NO_EXPORT int
dtypemeta_wrap_legacy_descriptor(PyArray_Descr *dtypem);
+#ifdef __cplusplus
+}
+#endif
+
#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_ */
diff --git a/numpy/core/src/multiarray/experimental_public_dtype_api.c b/numpy/core/src/multiarray/experimental_public_dtype_api.c
index cf5f152ab..441dbdc1f 100644
--- a/numpy/core/src/multiarray/experimental_public_dtype_api.c
+++ b/numpy/core/src/multiarray/experimental_public_dtype_api.c
@@ -300,37 +300,13 @@ PyArrayInitDTypeMeta_FromSpec(
}
-/* Function is defined in umath/dispatching.c (same/one compilation unit) */
+/* Functions defined in umath/dispatching.c (same/one compilation unit) */
NPY_NO_EXPORT int
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate);
-static int
-PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec)
-{
- if (!PyObject_TypeCheck(ufunc, &PyUFunc_Type)) {
- PyErr_SetString(PyExc_TypeError,
- "ufunc object passed is not a ufunc!");
- return -1;
- }
- PyBoundArrayMethodObject *bmeth =
- (PyBoundArrayMethodObject *)PyArrayMethod_FromSpec(spec);
- if (bmeth == NULL) {
- return -1;
- }
- int nargs = bmeth->method->nin + bmeth->method->nout;
- PyObject *dtypes = PyArray_TupleFromItems(
- nargs, (PyObject **)bmeth->dtypes, 1);
- if (dtypes == NULL) {
- return -1;
- }
- PyObject *info = PyTuple_Pack(2, dtypes, bmeth->method);
- Py_DECREF(bmeth);
- Py_DECREF(dtypes);
- if (info == NULL) {
- return -1;
- }
- return PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0);
-}
+/* Prototype must match the definition in umath/dispatching.c. */
+NPY_NO_EXPORT int
+PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec);
+
/*
* Function is defined in umath/wrapping_array_method.c
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 5209d6914..a3cb3e131 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -85,6 +85,10 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
NPY_NO_EXPORT int initscalarmath(PyObject *);
NPY_NO_EXPORT int set_matmul_flags(PyObject *d); /* in ufunc_object.c */
+/* From umath/string_ufuncs.cpp/h */
+NPY_NO_EXPORT PyObject *
+_umath_strings_richcompare(
+ PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip);
/*
* global variable to determine if legacy printing is enabled, accessible from
@@ -3726,6 +3730,12 @@ format_longfloat(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
TrimMode_LeaveOneZero, -1, -1);
}
+
+/*
+ * The only purpose of this function is that it allows the "rstrip".
+ * From my (@seberg's) perspective, this function should be deprecated
+ * and I do not think it matters if it is not particularly fast.
+ */
static PyObject *
compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
{
@@ -3791,7 +3801,7 @@ compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
return NULL;
}
if (PyArray_ISSTRING(newarr) && PyArray_ISSTRING(newoth)) {
- res = _strings_richcompare(newarr, newoth, cmp_op, rstrip != 0);
+ res = _umath_strings_richcompare(newarr, newoth, cmp_op, rstrip != 0);
}
else {
PyErr_SetString(PyExc_TypeError,
diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c
index b8f102b3d..620335d88 100644
--- a/numpy/core/src/umath/dispatching.c
+++ b/numpy/core/src/umath/dispatching.c
@@ -145,6 +145,38 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
}
+/*
+ * Add loop directly to a ufunc from a given ArrayMethod spec.
+ *
+ * Returns -1 with a TypeError set if `ufunc` is not a ufunc instance, and
+ * -1 if creating the bound method, the dtype tuple, or the info tuple
+ * fails.  Otherwise returns the result of `PyUFunc_AddLoop`.
+ */
+NPY_NO_EXPORT int
+PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec)
+{
+    if (!PyObject_TypeCheck(ufunc, &PyUFunc_Type)) {
+        PyErr_SetString(PyExc_TypeError,
+                "ufunc object passed is not a ufunc!");
+        return -1;
+    }
+    PyBoundArrayMethodObject *bmeth =
+            (PyBoundArrayMethodObject *)PyArrayMethod_FromSpec(spec);
+    if (bmeth == NULL) {
+        return -1;
+    }
+    int nargs = bmeth->method->nin + bmeth->method->nout;
+    PyObject *dtypes = PyArray_TupleFromItems(
+            nargs, (PyObject **)bmeth->dtypes, 1);
+    if (dtypes == NULL) {
+        /* Release the bound method; it was leaked on this path before. */
+        Py_DECREF(bmeth);
+        return -1;
+    }
+    /* The info tuple holds its own references to both members. */
+    PyObject *info = PyTuple_Pack(2, dtypes, bmeth->method);
+    Py_DECREF(bmeth);
+    Py_DECREF(dtypes);
+    if (info == NULL) {
+        return -1;
+    }
+    return PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0);
+}
+
+
/**
* Resolves the implementation to use, this uses typical multiple dispatching
* methods of finding the best matching implementation or resolver.
diff --git a/numpy/core/src/umath/dispatching.h b/numpy/core/src/umath/dispatching.h
index a7e9e88d0..f2ab0be2e 100644
--- a/numpy/core/src/umath/dispatching.h
+++ b/numpy/core/src/umath/dispatching.h
@@ -6,6 +6,9 @@
#include <numpy/ufuncobject.h>
#include "array_method.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
typedef int promoter_function(PyUFuncObject *ufunc,
PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
@@ -14,6 +17,9 @@ typedef int promoter_function(PyUFuncObject *ufunc,
NPY_NO_EXPORT int
PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate);
+NPY_NO_EXPORT int
+PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec);
+
NPY_NO_EXPORT PyArrayMethodObject *
promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
PyArrayObject *const ops[],
@@ -41,5 +47,8 @@ object_only_ufunc_promoter(PyUFuncObject *ufunc,
NPY_NO_EXPORT int
install_logical_ufunc_promoter(PyObject *ufunc);
+#ifdef __cplusplus
+}
+#endif
#endif /*_NPY_DISPATCHING_H */
diff --git a/numpy/core/src/umath/string_ufuncs.cpp b/numpy/core/src/umath/string_ufuncs.cpp
new file mode 100644
index 000000000..1b45ad71f
--- /dev/null
+++ b/numpy/core/src/umath/string_ufuncs.cpp
@@ -0,0 +1,449 @@
+#include <Python.h>
+
+#define NPY_NO_DEPRECATED_API NPY_API_VERSION
+#define _MULTIARRAYMODULE
+#define _UMATHMODULE
+
+#include "numpy/ndarraytypes.h"
+
+#include "numpyos.h"
+#include "dispatching.h"
+#include "dtypemeta.h"
+#include "common_dtype.h"
+#include "convert_datatype.h"
+
+#include "string_ufuncs.h"
+
+
+/*
+ * Three-way comparison of two characters: -1 if `a` orders before `b`,
+ * 0 if they are equal and 1 otherwise.
+ */
+template <typename character>
+static NPY_INLINE int
+character_cmp(character a, character b)
+{
+    if (a < b) {
+        return -1;
+    }
+    return (a == b) ? 0 : 1;
+}
+
+
+/*
+ * Compare two strings of different length.  Note that either string may be
+ * zero padded (trailing zeros are ignored in other words, the shorter word
+ * is always padded with zeros).
+ *
+ * `len1` and `len2` are lengths in characters.  When `rstrip` is set,
+ * trailing zeros and ASCII whitespace are dropped from both strings before
+ * they are compared.
+ *
+ * Returns a negative value, zero, or a positive value when the first string
+ * orders before, equal to, or after the second one.
+ */
+template <bool rstrip, typename character>
+static NPY_INLINE int
+string_cmp(int len1, const character *str1, int len2, const character *str2)
+{
+    if (rstrip) {
+        /*
+         * Ignore/"trim" trailing whitespace (and 0s).  Note that this function
+         * does not support unicode whitespace (and never has).
+         */
+        while (len1 > 0) {
+            character c = str1[len1-1];
+            if (c != (character)0 && !NumPyOS_ascii_isspace(c)) {
+                break;
+            }
+            len1--;
+        }
+        while (len2 > 0) {
+            character c = str2[len2-1];
+            if (c != (character)0 && !NumPyOS_ascii_isspace(c)) {
+                break;
+            }
+            len2--;
+        }
+    }
+
+    int n = PyArray_MIN(len1, len2);
+
+    if (sizeof(character) == 1) {
+        /*
+         * TODO: `memcmp` makes things 2x faster for longer words that match
+         *       exactly, but at least 2x slower for short or mismatching ones.
+         */
+        int cmp = memcmp(str1, str2, n);
+        if (cmp != 0) {
+            return cmp;
+        }
+        str1 += n;
+        str2 += n;
+    }
+    else {
+        for (int i = 0; i < n; i++) {
+            int cmp = character_cmp(*str1, *str2);
+            if (cmp != 0) {
+                return cmp;
+            }
+            str1++;
+            str2++;
+        }
+    }
+    /*
+     * The common prefix is equal, so only the tail of the longer string can
+     * decide the result.  Any non-zero character orders after the implicit
+     * zero padding.  Testing for non-zero (instead of using `<`) keeps the
+     * result consistent with the unsigned byte ordering of `memcmp` above
+     * even when `character` is a signed type.
+     */
+    if (len1 > len2) {
+        for (int i = n; i < len1; i++) {
+            if (*str1 != 0) {
+                return 1;
+            }
+            str1++;
+        }
+    }
+    else if (len2 > len1) {
+        for (int i = n; i < len2; i++) {
+            if (*str2 != 0) {
+                return -1;
+            }
+            str2++;
+        }
+    }
+    return 0;
+}
+
+
+/*
+ * Helper for templating, avoids warnings about uncovered switch paths.
+ *
+ * `comp_name` below maps each COMP value to a ufunc name string ("equal",
+ * "not_equal", "less", "less_equal", "greater", "greater_equal").  Any other
+ * value trips the assert (and returns nullptr if asserts are compiled out).
+ */
+enum class COMP {
+ EQ, NE, LT, LE, GT, GE,
+};
+
+static char const *
+comp_name(COMP comp) {
+ switch(comp) {
+ case COMP::EQ: return "equal";
+ case COMP::NE: return "not_equal";
+ case COMP::LT: return "less";
+ case COMP::LE: return "less_equal";
+ case COMP::GT: return "greater";
+ case COMP::GE: return "greater_equal";
+ default:
+ assert(0);
+ return nullptr;
+ }
+}
+
+
+/*
+ * Strided inner loop: compares `dimensions[0]` pairs of fixed-width strings
+ * read from data[0] and data[1] using `string_cmp`, and stores the outcome
+ * of the `comp` comparison as an npy_bool into data[2].  Always returns 0.
+ */
+template <bool rstrip, COMP comp, typename character>
+static int
+string_comparison_loop(PyArrayMethod_Context *context,
+ char *const data[], npy_intp const dimensions[],
+ npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
+{
+ /*
+ * Note, fetching `elsize` from the descriptor is OK even without the GIL,
+ * however it may be that this should be moved into `auxdata` eventually,
+ * which may also be slightly faster/cleaner (but more involved).
+ *
+ * `elsize` is divided by the character size, so `len1`/`len2` count
+ * characters rather than bytes.
+ */
+ int len1 = context->descriptors[0]->elsize / sizeof(character);
+ int len2 = context->descriptors[1]->elsize / sizeof(character);
+
+ char *in1 = data[0];
+ char *in2 = data[1];
+ char *out = data[2];
+
+ npy_intp N = dimensions[0];
+
+ while (N--) {
+ int cmp = string_cmp<rstrip>(
+ len1, (character *)in1, len2, (character *)in2);
+ npy_bool res;
+ switch (comp) {
+ case COMP::EQ:
+ res = cmp == 0;
+ break;
+ case COMP::NE:
+ res = cmp != 0;
+ break;
+ case COMP::LT:
+ res = cmp < 0;
+ break;
+ case COMP::LE:
+ res = cmp <= 0;
+ break;
+ case COMP::GT:
+ res = cmp > 0;
+ break;
+ case COMP::GE:
+ res = cmp >= 0;
+ break;
+ }
+ *(npy_bool *)out = res;
+
+ in1 += strides[0];
+ in2 += strides[1];
+ out += strides[2];
+ }
+ return 0;
+}
+
+
+/*
+ * Machinery to add the string loops to the existing ufuncs.
+ */
+
+/*
+ * This function replaces the strided loop with the passed in one,
+ * and registers it with the given ufunc.
+ *
+ * The ufunc is fetched as `umath[ufunc_name]` and `spec->slots[0].pfunc` is
+ * overwritten with `loop`, so the same spec can be reused for several
+ * loops.  Returns -1 if the name object cannot be created or the lookup
+ * fails, otherwise the result of `PyUFunc_AddLoopFromSpec`.
+ */
+static int
+add_loop(PyObject *umath, const char *ufunc_name,
+ PyArrayMethod_Spec *spec, PyArrayMethod_StridedLoop *loop)
+{
+ PyObject *name = PyUnicode_FromString(ufunc_name);
+ if (name == nullptr) {
+ return -1;
+ }
+ PyObject *ufunc = PyObject_GetItem(umath, name);
+ Py_DECREF(name);
+ if (ufunc == nullptr) {
+ return -1;
+ }
+ spec->slots[0].pfunc = (void *)loop;
+
+ int res = PyUFunc_AddLoopFromSpec(ufunc, spec);
+ Py_DECREF(ufunc);
+ return res;
+}
+
+
+/*
+ * Recursive variadic helper: registers `string_comparison_loop` for the
+ * first COMP value through `add_loop` and then recurses on the remaining
+ * values.  The specialization without any COMP value ends the recursion and
+ * returns 0; the first failing registration stops it and returns -1.
+ */
+template<bool rstrip, typename character, COMP...>
+struct add_loops;
+
+template<bool rstrip, typename character>
+struct add_loops<rstrip, character> {
+ int operator()(PyObject*, PyArrayMethod_Spec*) {
+ return 0;
+ }
+};
+
+template<bool rstrip, typename character, COMP comp, COMP... comps>
+struct add_loops<rstrip, character, comp, comps...> {
+ int operator()(PyObject* umath, PyArrayMethod_Spec* spec) {
+ PyArrayMethod_StridedLoop* loop = string_comparison_loop<rstrip, comp, character>;
+
+ if (add_loop(umath, comp_name(comp), spec, loop) < 0) {
+ return -1;
+ }
+ else {
+ return add_loops<rstrip, character, comps...>()(umath, spec);
+ }
+ }
+};
+
+
+/*
+ * Register the templated comparison loops (without rstrip) for the String
+ * and the Unicode DType on the six comparison ufuncs looked up in `umath`.
+ * Returns 0 on success and -1 if any registration fails.
+ */
+NPY_NO_EXPORT int
+init_string_ufuncs(PyObject *umath)
+{
+ int res = -1;
+ /* NOTE: This should receive global symbols? */
+ PyArray_DTypeMeta *String = PyArray_DTypeFromTypeNum(NPY_STRING);
+ PyArray_DTypeMeta *Unicode = PyArray_DTypeFromTypeNum(NPY_UNICODE);
+ PyArray_DTypeMeta *Bool = PyArray_DTypeFromTypeNum(NPY_BOOL);
+
+ /* We start with the string loops: */
+ PyArray_DTypeMeta *dtypes[] = {String, String, Bool};
+ /*
+ * We only have one loop right now, the strided one. The default type
+ * resolver ensures native byte order/canonical representation.
+ * The nullptr in the first slot is filled in by `add_loop` per loop.
+ */
+ PyType_Slot slots[] = {
+ {NPY_METH_strided_loop, nullptr},
+ {0, nullptr}
+ };
+
+ PyArrayMethod_Spec spec = {};
+ spec.name = "templated_string_comparison";
+ spec.nin = 2;
+ spec.nout = 1;
+ spec.dtypes = dtypes;
+ spec.slots = slots;
+ spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
+
+ /* All String loops */
+ using string_looper = add_loops<false, npy_byte, COMP::EQ, COMP::NE, COMP::LT, COMP::LE, COMP::GT, COMP::GE>;
+ if (string_looper()(umath, &spec) < 0) {
+ goto finish;
+ }
+
+ /* All Unicode loops (the spec is reused with the input DTypes swapped) */
+ using ucs_looper = add_loops<false, npy_ucs4, COMP::EQ, COMP::NE, COMP::LT, COMP::LE, COMP::GT, COMP::GE>;
+ dtypes[0] = Unicode;
+ dtypes[1] = Unicode;
+ if (ucs_looper()(umath, &spec) < 0) {
+ goto finish;
+ }
+
+ res = 0;
+ finish:
+ Py_DECREF(String);
+ Py_DECREF(Unicode);
+ Py_DECREF(Bool);
+ return res;
+}
+
+
+/*
+ * Map a Python rich-comparison opcode (Py_EQ, Py_NE, Py_LT, Py_LE, Py_GT,
+ * Py_GE) to the matching `string_comparison_loop` instantiation.  Any other
+ * value trips the assert, and nullptr is returned if asserts are compiled
+ * out.
+ */
+template <bool rstrip, typename character>
+static PyArrayMethod_StridedLoop *
+get_strided_loop(int comp)
+{
+ switch (comp) {
+ case Py_EQ:
+ return string_comparison_loop<rstrip, COMP::EQ, character>;
+ case Py_NE:
+ return string_comparison_loop<rstrip, COMP::NE, character>;
+ case Py_LT:
+ return string_comparison_loop<rstrip, COMP::LT, character>;
+ case Py_LE:
+ return string_comparison_loop<rstrip, COMP::LE, character>;
+ case Py_GT:
+ return string_comparison_loop<rstrip, COMP::GT, character>;
+ case Py_GE:
+ return string_comparison_loop<rstrip, COMP::GE, character>;
+ default:
+ assert(false); /* caller ensures this */
+ }
+ return nullptr;
+}
+
+
+/*
+ * This function is used for `compare_chararrays` and currently also void
+ * comparisons (unstructured voids). The first could probably be deprecated
+ * and removed but is used by `np.char.chararray` the latter should also be
+ * moved to the ufunc probably (removing the need for manual looping).
+ *
+ * The `rstrip` mechanism is presumably for some fortran compat, but the
+ * question is whether it would not be better to have/use `rstrip` on such
+ * an array first...
+ *
+ * NOTE: This function is also used for unstructured voids, this works because
+ * `npy_byte` is correct.
+ *
+ * Returns a new reference to the boolean result array allocated by the
+ * iterator, a new reference to `Py_NotImplemented` when the type numbers of
+ * `self` and `other` differ, or nullptr on failure.  `cmp_op` must be one
+ * of the Py_EQ ... Py_GE opcodes; a non-zero `rstrip` selects the loops
+ * that ignore trailing whitespace.
+ */
+NPY_NO_EXPORT PyObject *
+_umath_strings_richcompare(
+ PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip)
+{
+ NpyIter *iter = nullptr;
+ PyObject *result = nullptr;
+
+ char **dataptr = nullptr;
+ npy_intp *strides = nullptr;
+ npy_intp *countptr = nullptr;
+ npy_intp size = 0;
+
+ PyArrayMethod_Context context = {};
+ NpyIter_IterNextFunc *iternext = nullptr;
+
+ npy_uint32 it_flags = (
+ NPY_ITER_EXTERNAL_LOOP | NPY_ITER_ZEROSIZE_OK |
+ NPY_ITER_BUFFERED | NPY_ITER_GROWINNER);
+ npy_uint32 op_flags[3] = {
+ NPY_ITER_READONLY | NPY_ITER_ALIGNED,
+ NPY_ITER_READONLY | NPY_ITER_ALIGNED,
+ NPY_ITER_WRITEONLY | NPY_ITER_ALLOCATE | NPY_ITER_ALIGNED};
+
+ PyArrayMethod_StridedLoop *strided_loop = nullptr;
+ NPY_BEGIN_THREADS_DEF;
+
+ if (PyArray_TYPE(self) != PyArray_TYPE(other)) {
+ /*
+ * Comparison between Bytes and Unicode is not defined in Py3K;
+ * we follow.
+ * TODO: This makes no sense at all for `compare_chararrays`, kept
+ * only under the assumption that we are more likely to deprecate
+ * than fix it to begin with.
+ */
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+
+ PyArrayObject *ops[3] = {self, other, nullptr};
+ PyArray_Descr *descrs[3] = {nullptr, nullptr, PyArray_DescrFromType(NPY_BOOL)};
+ /* TODO: ensuring native byte order is not really necessary for == and != */
+ descrs[0] = NPY_DT_CALL_ensure_canonical(PyArray_DESCR(self));
+ if (descrs[0] == nullptr) {
+ goto finish;
+ }
+ descrs[1] = NPY_DT_CALL_ensure_canonical(PyArray_DESCR(other));
+ if (descrs[1] == nullptr) {
+ goto finish;
+ }
+
+ /*
+ * Create the iterator (the third operand is nullptr and carries
+ * NPY_ITER_ALLOCATE, so the iterator allocates the result array):
+ */
+ iter = NpyIter_AdvancedNew(
+ 3, ops, it_flags, NPY_KEEPORDER, NPY_SAFE_CASTING, op_flags, descrs,
+ -1, nullptr, nullptr, 0);
+ if (iter == nullptr) {
+ goto finish;
+ }
+
+ size = NpyIter_GetIterSize(iter);
+ if (size == 0) {
+ result = (PyObject *)NpyIter_GetOperandArray(iter)[2];
+ Py_INCREF(result);
+ goto finish;
+ }
+
+ iternext = NpyIter_GetIterNext(iter, nullptr);
+ if (iternext == nullptr) {
+ goto finish;
+ }
+
+ /*
+ * Prepare the inner-loop and execute it (we only need descriptors to be
+ * passed in).
+ */
+ context.descriptors = descrs;
+
+ dataptr = NpyIter_GetDataPtrArray(iter);
+ strides = NpyIter_GetInnerStrideArray(iter);
+ countptr = NpyIter_GetInnerLoopSizePtr(iter);
+
+ if (rstrip == 0) {
+ /* NOTE: Also used for VOID, so can be STRING, UNICODE, or VOID: */
+ if (descrs[0]->type_num != NPY_UNICODE) {
+ strided_loop = get_strided_loop<false, npy_byte>(cmp_op);
+ }
+ else {
+ strided_loop = get_strided_loop<false, npy_ucs4>(cmp_op);
+ }
+ }
+ else {
+ if (descrs[0]->type_num != NPY_UNICODE) {
+ strided_loop = get_strided_loop<true, npy_byte>(cmp_op);
+ }
+ else {
+ strided_loop = get_strided_loop<true, npy_ucs4>(cmp_op);
+ }
+ }
+
+ NPY_BEGIN_THREADS_THRESHOLDED(size);
+
+ do {
+ /* We know the loop cannot fail (the loops above always return 0) */
+ strided_loop(&context, dataptr, countptr, strides, nullptr);
+ } while (iternext(iter) != 0);
+
+ NPY_END_THREADS;
+
+ result = (PyObject *)NpyIter_GetOperandArray(iter)[2];
+ Py_INCREF(result);
+
+ finish:
+ if (NpyIter_Deallocate(iter) < 0) {
+ Py_CLEAR(result);
+ }
+ Py_XDECREF(descrs[0]);
+ Py_XDECREF(descrs[1]);
+ Py_XDECREF(descrs[2]);
+ return result;
+}
diff --git a/numpy/core/src/umath/string_ufuncs.h b/numpy/core/src/umath/string_ufuncs.h
new file mode 100644
index 000000000..aa1719954
--- /dev/null
+++ b/numpy/core/src/umath/string_ufuncs.h
@@ -0,0 +1,19 @@
+#ifndef _NPY_CORE_SRC_UMATH_STRING_UFUNCS_H_
+#define _NPY_CORE_SRC_UMATH_STRING_UFUNCS_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+NPY_NO_EXPORT int
+init_string_ufuncs(PyObject *umath);
+
+NPY_NO_EXPORT PyObject *
+_umath_strings_richcompare(
+ PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  /* _NPY_CORE_SRC_UMATH_STRING_UFUNCS_H_ */
\ No newline at end of file
diff --git a/numpy/core/src/umath/umathmodule.c b/numpy/core/src/umath/umathmodule.c
index 49328d19e..17fedec6f 100644
--- a/numpy/core/src/umath/umathmodule.c
+++ b/numpy/core/src/umath/umathmodule.c
@@ -23,11 +23,13 @@
#include "numpy/npy_math.h"
#include "number.h"
#include "dispatching.h"
+#include "string_ufuncs.h"
/* Automatically generated code to define all ufuncs: */
#include "funcs.inc"
#include "__umath_generated.c"
+
static PyUFuncGenericFunction pyfunc_functions[] = {PyUFunc_On_Om};
static int
@@ -347,5 +349,10 @@ int initumath(PyObject *m)
if (install_logical_ufunc_promoter(s) < 0) {
return -1;
}
+
+ if (init_string_ufuncs(d) < 0) {
+ return -1;
+ }
+
return 0;
}
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index 2b7864433..2255cb2a3 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -166,7 +166,7 @@ class TestComparisonDeprecations(_DeprecationTestCase):
# For two string arrays, strings always raised the broadcasting error:
a = np.array(['a', 'b'])
b = np.array(['a', 'b', 'c'])
- assert_raises(ValueError, lambda x, y: x == y, a, b)
+ assert_warns(FutureWarning, lambda x, y: x == y, a, b)
# The empty list is not cast to string, and this used to pass due
# to dtype mismatch; now (2018-06-21) it correctly leads to a
diff --git a/numpy/core/tests/test_strings.py b/numpy/core/tests/test_strings.py
new file mode 100644
index 000000000..2b87ed654
--- /dev/null
+++ b/numpy/core/tests/test_strings.py
@@ -0,0 +1,85 @@
+import pytest
+
+import operator
+import numpy as np
+
+from numpy.testing import assert_array_equal
+
+
+COMPARISONS = [
+ (operator.eq, np.equal, "=="),
+ (operator.ne, np.not_equal, "!="),
+ (operator.lt, np.less, "<"),
+ (operator.le, np.less_equal, "<="),
+ (operator.gt, np.greater, ">"),
+ (operator.ge, np.greater_equal, ">="),
+]
+
+
+@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
+def test_mixed_string_comparison_ufuncs_fail(op, ufunc, sym):
+ arr_string = np.array(["a", "b"], dtype="S")
+ arr_unicode = np.array(["a", "c"], dtype="U")
+
+ with pytest.raises(TypeError, match="did not contain a loop"):
+ ufunc(arr_string, arr_unicode)
+
+ with pytest.raises(TypeError, match="did not contain a loop"):
+ ufunc(arr_unicode, arr_string)
+
+@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
+def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc, sym):
+ arr_string = np.array(["a", "b"], dtype="S")
+ arr_unicode = np.array(["a", "c"], dtype="U")
+
+ # While there is no loop, manual casting is acceptable:
+ res1 = ufunc(arr_string, arr_unicode, signature="UU->?", casting="unsafe")
+ res2 = ufunc(arr_string, arr_unicode, signature="SS->?", casting="unsafe")
+
+ expected = op(arr_string.astype('U'), arr_unicode)
+ assert_array_equal(res1, expected)
+ assert_array_equal(res2, expected)
+
+
+@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
+@pytest.mark.parametrize("dtypes", [
+ ("S2", "S2"), ("S2", "S10"),
+ ("<U1", "<U1"), ("<U1", ">U1"), (">U1", ">U1"),
+ ("<U1", "<U10"), ("<U1", ">U10")])
+@pytest.mark.parametrize("aligned", [True, False])
+def test_string_comparisons(op, ufunc, sym, dtypes, aligned):
+ # ensure native byte-order for the first view to stay within unicode range
+ native_dt = np.dtype(dtypes[0]).newbyteorder("=")
+ arr = np.arange(2**15).view(native_dt).astype(dtypes[0])
+ if not aligned:
+ # Make `arr` unaligned:
+ new = np.zeros(arr.nbytes + 1, dtype=np.uint8)[1:].view(dtypes[0])
+ new[...] = arr
+ arr = new
+
+ arr2 = arr.astype(dtypes[1], copy=True)
+ np.random.shuffle(arr2)
+ arr[0] = arr2[0] # make sure one matches
+
+ expected = [op(d1, d2) for d1, d2 in zip(arr.tolist(), arr2.tolist())]
+ assert_array_equal(op(arr, arr2), expected)
+ assert_array_equal(ufunc(arr, arr2), expected)
+ assert_array_equal(np.compare_chararrays(arr, arr2, sym, False), expected)
+
+ expected = [op(d2, d1) for d1, d2 in zip(arr.tolist(), arr2.tolist())]
+ assert_array_equal(op(arr2, arr), expected)
+ assert_array_equal(ufunc(arr2, arr), expected)
+ assert_array_equal(np.compare_chararrays(arr2, arr, sym, False), expected)
+
+
+@pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
+@pytest.mark.parametrize("dtypes", [
+ ("S2", "S2"), ("S2", "S10"), ("<U1", "<U1"), ("<U1", ">U10")])
+def test_string_comparisons_empty(op, ufunc, sym, dtypes):
+ arr = np.empty((1, 0, 1, 5), dtype=dtypes[0])
+ arr2 = np.empty((100, 1, 0, 1), dtype=dtypes[1])
+
+ expected = np.empty(np.broadcast_shapes(arr.shape, arr2.shape), dtype=bool)
+ assert_array_equal(op(arr, arr2), expected)
+ assert_array_equal(ufunc(arr, arr2), expected)
+ assert_array_equal(np.compare_chararrays(arr, arr2, sym, False), expected)
diff --git a/numpy/core/tests/test_unicode.py b/numpy/core/tests/test_unicode.py
index 8e0dd47cb..12de25771 100644
--- a/numpy/core/tests/test_unicode.py
+++ b/numpy/core/tests/test_unicode.py
@@ -1,3 +1,5 @@
+import pytest
+
import numpy as np
from numpy.testing import assert_, assert_equal, assert_array_equal
@@ -33,8 +35,11 @@ def test_string_cast():
uni_arr1 = str_arr.astype('>U')
uni_arr2 = str_arr.astype('<U')
- assert_(str_arr != uni_arr1)
- assert_(str_arr != uni_arr2)
+ with pytest.warns(FutureWarning):
+ assert str_arr != uni_arr1
+ with pytest.warns(FutureWarning):
+ assert str_arr != uni_arr2
+
assert_array_equal(uni_arr1, uni_arr2)