diff options
| -rw-r--r-- | numpy/core/tests/test_umath.py | 78 |
1 files changed, 24 insertions, 54 deletions
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 01ef9365e..7b6e2ee92 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -185,9 +185,19 @@ class TestOut: class TestComparisons: + import operator + @pytest.mark.parametrize('dtype', np.sctypes['uint'] + np.sctypes['int'] + np.sctypes['float'] + [np.bool_]) - def test_comparison_functions(self, dtype): + @pytest.mark.parametrize('py_comp,np_comp', [ + (operator.lt, np.less), + (operator.le, np.less_equal), + (operator.gt, np.greater), + (operator.ge, np.greater_equal), + (operator.eq, np.equal), + (operator.ne, np.not_equal) + ]) + def test_comparison_functions(self, dtype, py_comp, np_comp): # Initialize input arrays if dtype == np.bool_: a = np.random.choice(a=[False, True], size=1000) @@ -197,69 +207,29 @@ class TestComparisons: a = np.random.randint(low=1, high=10, size=1000).astype(dtype) b = np.random.randint(low=1, high=10, size=1000).astype(dtype) scalar = 5 - scalar_np = np.dtype(dtype).type(scalar) + np_scalar = np.dtype(dtype).type(scalar) a_lst = a.tolist() b_lst = b.tolist() # (Binary) Comparison (x1=array, x2=array) - lt_b = np.less(a, b) - le_b = np.less_equal(a, b) - gt_b = np.greater(a, b) - ge_b = np.greater_equal(a, b) - eq_b = np.equal(a, b) - ne_b = np.not_equal(a, b) - lt_b_lst = [x < y for x, y in zip(a_lst, b_lst)] - le_b_lst = [x <= y for x, y in zip(a_lst, b_lst)] - gt_b_lst = [x > y for x, y in zip(a_lst, b_lst)] - ge_b_lst = [x >= y for x, y in zip(a_lst, b_lst)] - eq_b_lst = [x == y for x, y in zip(a_lst, b_lst)] - ne_b_lst = [x != y for x, y in zip(a_lst, b_lst)] + comp_b = np_comp(a, b) + comp_b_list = [py_comp(x, y) for x, y in zip(a_lst, b_lst)] # (Scalar1) Comparison (x1=scalar, x2=array) - lt_s1 = np.less(scalar_np, b) - le_s1 = np.less_equal(scalar_np, b) - gt_s1 = np.greater(scalar_np, b) - ge_s1 = np.greater_equal(scalar_np, b) - eq_s1 = np.equal(scalar_np, b) - ne_s1 = np.not_equal(scalar_np, b) - lt_s1_lst = [scalar < x for x in b_lst] - le_s1_lst = [scalar <= x for x in b_lst] - gt_s1_lst = [scalar > x for x in b_lst] - ge_s1_lst = [scalar >= x for x in b_lst] - eq_s1_lst = [scalar == x for x in b_lst] - ne_s1_lst = [scalar != x for x in b_lst] + comp_s1 = np_comp(np_scalar, b) + comp_s1_list = [py_comp(scalar, x) for x in b_lst] # (Scalar2) Comparison (x1=array, x2=scalar) - lt_s2 = np.less(a, scalar_np) - le_s2 = np.less_equal(a, scalar_np) - gt_s2 = np.greater(a, scalar_np) - ge_s2 = np.greater_equal(a, scalar_np) - eq_s2 = np.equal(a, scalar_np) - ne_s2 = np.not_equal(a, scalar_np) - lt_s2_lst = [x < scalar for x in a_lst] - le_s2_lst = [x <= scalar for x in a_lst] - gt_s2_lst = [x > scalar for x in a_lst] - ge_s2_lst = [x >= scalar for x in a_lst] - eq_s2_lst = [x == scalar for x in a_lst] - ne_s2_lst = [x != scalar for x in a_lst] - - # Compare comparison functions (Python vs NumPy) using native Python - def compare(lt, le, gt, ge, eq, ne, lt_lst, le_lst, gt_lst, ge_lst, - eq_lst, ne_lst): - assert_(lt.tolist() == lt_lst, "Comparison function check (lt)") - assert_(le.tolist() == le_lst, "Comparison function check (le)") - assert_(gt.tolist() == gt_lst, "Comparison function check (gt)") - assert_(ge.tolist() == ge_lst, "Comparison function check (ge)") - assert_(eq.tolist() == eq_lst, "Comparison function check (eq)") - assert_(ne.tolist() == ne_lst, "Comparison function check (ne)") + comp_s2 = np_comp(a, np_scalar) + comp_s2_list = [py_comp(x, scalar) for x in a_lst] # Sequence: Binary, Scalar1 and Scalar2 - compare(lt_b, le_b, gt_b, ge_b, eq_b, ne_b, lt_b_lst, le_b_lst, - gt_b_lst, ge_b_lst, eq_b_lst, ne_b_lst) - compare(lt_s1, le_s1, gt_s1, ge_s1, eq_s1, ne_s1, lt_s1_lst, le_s1_lst, - gt_s1_lst, ge_s1_lst, eq_s1_lst, ne_s1_lst) - compare(lt_s2, le_s2, gt_s2, ge_s2, eq_s2, ne_s2, lt_s2_lst, le_s2_lst, - gt_s2_lst, ge_s2_lst, eq_s2_lst, ne_s2_lst) + assert_(comp_b.tolist() == comp_b_list, + f"Failed comparision ({py_comp.__name__})") + assert_(comp_s1.tolist() == comp_s1_list, + f"Failed comparision ({py_comp.__name__})") + assert_(comp_s2.tolist() == comp_s2_list, + f"Failed comparision ({py_comp.__name__})") def test_ignore_object_identity_in_equal(self): # Check comparing identical objects whose comparison |
