summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-04-11 11:04:07 -0400
committerGitHub <noreply@github.com>2018-04-11 11:04:07 -0400
commitc6bee683578d6549efb8cd972ebb2592b44ac2cc (patch)
tree554565673894fdbd80071671627cbe56747c3f5c /numpy/core
parent40ef8a6ade7282c613d58ed24c68915d61dcc07b (diff)
parent358f8220a298807c48963125d941dcd93a27ccc0 (diff)
downloadnumpy-c6bee683578d6549efb8cd972ebb2592b44ac2cc.tar.gz
Merge pull request #10745 from eric-wieser/comparison-object-loop
ENH: Add object loops to the comparison ufuncs
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/code_generators/generate_umath.py6
-rw-r--r--numpy/core/code_generators/ufunc_docstrings.py30
-rw-r--r--numpy/core/src/umath/loops.c.src25
-rw-r--r--numpy/core/src/umath/loops.h.src6
-rw-r--r--numpy/core/tests/test_ufunc.py15
5 files changed, 61 insertions, 21 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index 1d3550e06..7492baf9d 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -420,36 +420,42 @@ defdict = {
docstrings.get('numpy.core.umath.greater'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'greater_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.greater_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'less':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'less_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'not_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.not_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
+ [TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'logical_and':
Ufunc(2, 1, One,
diff --git a/numpy/core/code_generators/ufunc_docstrings.py b/numpy/core/code_generators/ufunc_docstrings.py
index c7e5cf600..615970816 100644
--- a/numpy/core/code_generators/ufunc_docstrings.py
+++ b/numpy/core/code_generators/ufunc_docstrings.py
@@ -1077,8 +1077,9 @@ add_newdoc('numpy.core.umath', 'equal',
Returns
-------
- out : ndarray or bool
- Output array of bools.
+ out : ndarray or scalar
+ Output array, element-wise comparison of `x1` and `x2`.
+ Typically of type bool, unless ``dtype=object`` is passed.
$OUT_SCALAR_2
See Also
@@ -1415,8 +1416,9 @@ add_newdoc('numpy.core.umath', 'greater',
Returns
-------
- out : bool or ndarray of bool
- Array of bools.
+ out : ndarray or scalar
+ Output array, element-wise comparison of `x1` and `x2`.
+ Typically of type bool, unless ``dtype=object`` is passed.
$OUT_SCALAR_2
@@ -1453,7 +1455,8 @@ add_newdoc('numpy.core.umath', 'greater_equal',
Returns
-------
out : bool or ndarray of bool
- Array of bools.
+ Output array, element-wise comparison of `x1` and `x2`.
+ Typically of type bool, unless ``dtype=object`` is passed.
$OUT_SCALAR_2
See Also
@@ -1820,8 +1823,9 @@ add_newdoc('numpy.core.umath', 'less',
Returns
-------
- out : bool or ndarray of bool
- Array of bools.
+ out : ndarray or scalar
+ Output array, element-wise comparison of `x1` and `x2`.
+ Typically of type bool, unless ``dtype=object`` is passed.
$OUT_SCALAR_2
See Also
@@ -1849,8 +1853,9 @@ add_newdoc('numpy.core.umath', 'less_equal',
Returns
-------
- out : bool or ndarray of bool
- Array of bools.
+ out : ndarray or scalar
+ Output array, element-wise comparison of `x1` and `x2`.
+ Typically of type bool, unless ``dtype=object`` is passed.
$OUT_SCALAR_2
See Also
@@ -2668,12 +2673,11 @@ add_newdoc('numpy.core.umath', 'not_equal',
Returns
-------
- not_equal : ndarray bool, scalar bool
- For each element in `x1, x2`, return True if `x1` is not equal
- to `x2` and False otherwise.
+ out : ndarray or scalar
+ Output array, element-wise comparison of `x1` and `x2`.
+ Typically of type bool, unless ``dtype=object`` is passed.
$OUT_SCALAR_2
-
See Also
--------
equal, greater, greater_equal, less, less_equal
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index d196a8d4e..8b1c7e703 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -2840,10 +2840,14 @@ NPY_NO_EXPORT void
* #OP = EQ, NE, GT, GE, LT, LE#
* #identity = NPY_TRUE, NPY_FALSE, -1*4#
*/
+
+/**begin repeat1
+ * #suffix = , _OO_O#
+ * #as_bool = 1, 0#
+ */
NPY_NO_EXPORT void
-OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) {
+OBJECT@suffix@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) {
BINARY_LOOP {
- int ret;
PyObject *ret_obj;
PyObject *in1 = *(PyObject **)ip1;
PyObject *in2 = *(PyObject **)ip2;
@@ -2860,14 +2864,21 @@ OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUS
if (ret_obj == NULL) {
return;
}
- ret = PyObject_IsTrue(ret_obj);
- Py_DECREF(ret_obj);
- if (ret == -1) {
- return;
+#if @as_bool@
+ {
+ int ret = PyObject_IsTrue(ret_obj);
+ Py_DECREF(ret_obj);
+ if (ret == -1) {
+ return;
+ }
+ *((npy_bool *)op1) = (npy_bool)ret;
}
- *((npy_bool *)op1) = (npy_bool)ret;
+#else
+ *((PyObject **)op1) = ret_obj;
+#endif
}
}
+/**end repeat1**/
/**end repeat**/
NPY_NO_EXPORT void
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index a01ef1529..5c2b2c22c 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -496,8 +496,12 @@ TIMEDELTA_mm_d_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *
* #kind = equal, not_equal, greater, greater_equal, less, less_equal#
* #OP = EQ, NE, GT, GE, LT, LE#
*/
+/**begin repeat1
+ * #suffix = , _OO_O#
+ */
NPY_NO_EXPORT void
-OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
+OBJECT@suffix@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
+/**end repeat1**/
/**end repeat**/
NPY_NO_EXPORT void
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index ea9ca021c..7a276c04d 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -814,6 +814,21 @@ class TestUfunc(object):
assert_equal(np.logical_or.reduce(a), 3)
assert_equal(np.logical_and.reduce(a), None)
+ def test_object_comparison(self):
+ class HasComparisons(object):
+ def __eq__(self, other):
+ return '=='
+
+ arr0d = np.array(HasComparisons())
+ assert_equal(arr0d == arr0d, True)
+ assert_equal(np.equal(arr0d, arr0d), True) # normal behavior is a cast
+ assert_equal(np.equal(arr0d, arr0d, dtype=object), '==')
+
+ arr1d = np.array([HasComparisons()])
+ assert_equal(arr1d == arr1d, np.array([True]))
+ assert_equal(np.equal(arr1d, arr1d), np.array([True])) # normal behavior is a cast
+ assert_equal(np.equal(arr1d, arr1d, dtype=object), np.array(['==']))
+
def test_object_array_reduction(self):
# Reductions on object arrays
a = np.array(['a', 'b', 'c'], dtype=object)