diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-04-11 11:04:07 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-11 11:04:07 -0400 |
commit | c6bee683578d6549efb8cd972ebb2592b44ac2cc (patch) | |
tree | 554565673894fdbd80071671627cbe56747c3f5c /numpy/core | |
parent | 40ef8a6ade7282c613d58ed24c68915d61dcc07b (diff) | |
parent | 358f8220a298807c48963125d941dcd93a27ccc0 (diff) | |
download | numpy-c6bee683578d6549efb8cd972ebb2592b44ac2cc.tar.gz |
Merge pull request #10745 from eric-wieser/comparison-object-loop
ENH: Add object loops to the comparison ufuncs
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 6 | ||||
-rw-r--r-- | numpy/core/code_generators/ufunc_docstrings.py | 30 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.c.src | 25 | ||||
-rw-r--r-- | numpy/core/src/umath/loops.h.src | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 15 |
5 files changed, 61 insertions, 21 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index 1d3550e06..7492baf9d 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -420,36 +420,42 @@ defdict = { docstrings.get('numpy.core.umath.greater'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', TD(all, out='?', simd=[('avx2', ints)]), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'greater_equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.greater_equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', TD(all, out='?', simd=[('avx2', ints)]), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'less': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.less'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', TD(all, out='?', simd=[('avx2', ints)]), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'less_equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.less_equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', TD(all, out='?', simd=[('avx2', ints)]), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', TD(all, out='?', simd=[('avx2', ints)]), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'not_equal': Ufunc(2, 1, None, docstrings.get('numpy.core.umath.not_equal'), 'PyUFunc_SimpleBinaryComparisonTypeResolver', TD(all, out='?', simd=[('avx2', ints)]), + [TypeDescription('O', FullTypeDescr, 'OO', 'O')], ), 'logical_and': Ufunc(2, 1, One, diff --git a/numpy/core/code_generators/ufunc_docstrings.py b/numpy/core/code_generators/ufunc_docstrings.py index c7e5cf600..615970816 100644 --- a/numpy/core/code_generators/ufunc_docstrings.py +++ b/numpy/core/code_generators/ufunc_docstrings.py @@ -1077,8 +1077,9 @@ add_newdoc('numpy.core.umath', 'equal', Returns ------- - out : ndarray or bool - Output array of bools. + out : ndarray or scalar + Output array, element-wise comparison of `x1` and `x2`. + Typically of type bool, unless ``dtype=object`` is passed. $OUT_SCALAR_2 See Also @@ -1415,8 +1416,9 @@ add_newdoc('numpy.core.umath', 'greater', Returns ------- - out : bool or ndarray of bool - Array of bools. + out : ndarray or scalar + Output array, element-wise comparison of `x1` and `x2`. + Typically of type bool, unless ``dtype=object`` is passed. $OUT_SCALAR_2 @@ -1453,7 +1455,8 @@ add_newdoc('numpy.core.umath', 'greater_equal', Returns ------- out : bool or ndarray of bool - Array of bools. + Output array, element-wise comparison of `x1` and `x2`. + Typically of type bool, unless ``dtype=object`` is passed. $OUT_SCALAR_2 See Also @@ -1820,8 +1823,9 @@ add_newdoc('numpy.core.umath', 'less', Returns ------- - out : bool or ndarray of bool - Array of bools. + out : ndarray or scalar + Output array, element-wise comparison of `x1` and `x2`. + Typically of type bool, unless ``dtype=object`` is passed. $OUT_SCALAR_2 See Also @@ -1849,8 +1853,9 @@ add_newdoc('numpy.core.umath', 'less_equal', Returns ------- - out : bool or ndarray of bool - Array of bools. + out : ndarray or scalar + Output array, element-wise comparison of `x1` and `x2`. + Typically of type bool, unless ``dtype=object`` is passed. $OUT_SCALAR_2 See Also @@ -2668,12 +2673,11 @@ add_newdoc('numpy.core.umath', 'not_equal', Returns ------- - not_equal : ndarray bool, scalar bool - For each element in `x1, x2`, return True if `x1` is not equal - to `x2` and False otherwise. + out : ndarray or scalar + Output array, element-wise comparison of `x1` and `x2`. + Typically of type bool, unless ``dtype=object`` is passed. $OUT_SCALAR_2 - See Also -------- equal, greater, greater_equal, less, less_equal diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src index d196a8d4e..8b1c7e703 100644 --- a/numpy/core/src/umath/loops.c.src +++ b/numpy/core/src/umath/loops.c.src @@ -2840,10 +2840,14 @@ NPY_NO_EXPORT void * #OP = EQ, NE, GT, GE, LT, LE# * #identity = NPY_TRUE, NPY_FALSE, -1*4# */ + +/**begin repeat1 + * #suffix = , _OO_O# + * #as_bool = 1, 0# + */ NPY_NO_EXPORT void -OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) { +OBJECT@suffix@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)) { BINARY_LOOP { - int ret; PyObject *ret_obj; PyObject *in1 = *(PyObject **)ip1; PyObject *in2 = *(PyObject **)ip2; @@ -2860,14 +2864,21 @@ OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUS if (ret_obj == NULL) { return; } - ret = PyObject_IsTrue(ret_obj); - Py_DECREF(ret_obj); - if (ret == -1) { - return; +#if @as_bool@ + { + int ret = PyObject_IsTrue(ret_obj); + Py_DECREF(ret_obj); + if (ret == -1) { + return; + } + *((npy_bool *)op1) = (npy_bool)ret; } - *((npy_bool *)op1) = (npy_bool)ret; +#else + *((PyObject **)op1) = ret_obj; +#endif } } +/**end repeat1**/ /**end repeat**/ NPY_NO_EXPORT void diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src index a01ef1529..5c2b2c22c 100644 --- a/numpy/core/src/umath/loops.h.src +++ b/numpy/core/src/umath/loops.h.src @@ -496,8 +496,12 @@ TIMEDELTA_mm_d_divide(char **args, npy_intp *dimensions, npy_intp *steps, void * * #kind = equal, not_equal, greater, greater_equal, less, less_equal# * #OP = EQ, NE, GT, GE, LT, LE# */ +/**begin repeat1 + * #suffix = , _OO_O# + */ NPY_NO_EXPORT void -OBJECT_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)); +OBJECT@suffix@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)); +/**end repeat1**/ /**end repeat**/ NPY_NO_EXPORT void diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index ea9ca021c..7a276c04d 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -814,6 +814,21 @@ class TestUfunc(object): assert_equal(np.logical_or.reduce(a), 3) assert_equal(np.logical_and.reduce(a), None) + def test_object_comparison(self): + class HasComparisons(object): + def __eq__(self, other): + return '==' + + arr0d = np.array(HasComparisons()) + assert_equal(arr0d == arr0d, True) + assert_equal(np.equal(arr0d, arr0d), True) # normal behavior is a cast + assert_equal(np.equal(arr0d, arr0d, dtype=object), '==') + + arr1d = np.array([HasComparisons()]) + assert_equal(arr1d == arr1d, np.array([True])) + assert_equal(np.equal(arr1d, arr1d), np.array([True])) # normal behavior is a cast + assert_equal(np.equal(arr1d, arr1d, dtype=object), np.array(['=='])) + def test_object_array_reduction(self): # Reductions on object arrays a = np.array(['a', 'b', 'c'], dtype=object) |