summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/tests/test_umath.py78
1 files changed, 24 insertions, 54 deletions
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 01ef9365e..7b6e2ee92 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -185,9 +185,19 @@ class TestOut:
class TestComparisons:
+ import operator
+
@pytest.mark.parametrize('dtype', np.sctypes['uint'] + np.sctypes['int'] +
np.sctypes['float'] + [np.bool_])
- def test_comparison_functions(self, dtype):
+ @pytest.mark.parametrize('py_comp,np_comp', [
+ (operator.lt, np.less),
+ (operator.le, np.less_equal),
+ (operator.gt, np.greater),
+ (operator.ge, np.greater_equal),
+ (operator.eq, np.equal),
+ (operator.ne, np.not_equal)
+ ])
+ def test_comparison_functions(self, dtype, py_comp, np_comp):
# Initialize input arrays
if dtype == np.bool_:
a = np.random.choice(a=[False, True], size=1000)
@@ -197,69 +207,29 @@ class TestComparisons:
a = np.random.randint(low=1, high=10, size=1000).astype(dtype)
b = np.random.randint(low=1, high=10, size=1000).astype(dtype)
scalar = 5
- scalar_np = np.dtype(dtype).type(scalar)
+ np_scalar = np.dtype(dtype).type(scalar)
a_lst = a.tolist()
b_lst = b.tolist()
# (Binary) Comparison (x1=array, x2=array)
- lt_b = np.less(a, b)
- le_b = np.less_equal(a, b)
- gt_b = np.greater(a, b)
- ge_b = np.greater_equal(a, b)
- eq_b = np.equal(a, b)
- ne_b = np.not_equal(a, b)
- lt_b_lst = [x < y for x, y in zip(a_lst, b_lst)]
- le_b_lst = [x <= y for x, y in zip(a_lst, b_lst)]
- gt_b_lst = [x > y for x, y in zip(a_lst, b_lst)]
- ge_b_lst = [x >= y for x, y in zip(a_lst, b_lst)]
- eq_b_lst = [x == y for x, y in zip(a_lst, b_lst)]
- ne_b_lst = [x != y for x, y in zip(a_lst, b_lst)]
+ comp_b = np_comp(a, b)
+ comp_b_list = [py_comp(x, y) for x, y in zip(a_lst, b_lst)]
# (Scalar1) Comparison (x1=scalar, x2=array)
- lt_s1 = np.less(scalar_np, b)
- le_s1 = np.less_equal(scalar_np, b)
- gt_s1 = np.greater(scalar_np, b)
- ge_s1 = np.greater_equal(scalar_np, b)
- eq_s1 = np.equal(scalar_np, b)
- ne_s1 = np.not_equal(scalar_np, b)
- lt_s1_lst = [scalar < x for x in b_lst]
- le_s1_lst = [scalar <= x for x in b_lst]
- gt_s1_lst = [scalar > x for x in b_lst]
- ge_s1_lst = [scalar >= x for x in b_lst]
- eq_s1_lst = [scalar == x for x in b_lst]
- ne_s1_lst = [scalar != x for x in b_lst]
+ comp_s1 = np_comp(np_scalar, b)
+ comp_s1_list = [py_comp(scalar, x) for x in b_lst]
# (Scalar2) Comparison (x1=array, x2=scalar)
- lt_s2 = np.less(a, scalar_np)
- le_s2 = np.less_equal(a, scalar_np)
- gt_s2 = np.greater(a, scalar_np)
- ge_s2 = np.greater_equal(a, scalar_np)
- eq_s2 = np.equal(a, scalar_np)
- ne_s2 = np.not_equal(a, scalar_np)
- lt_s2_lst = [x < scalar for x in a_lst]
- le_s2_lst = [x <= scalar for x in a_lst]
- gt_s2_lst = [x > scalar for x in a_lst]
- ge_s2_lst = [x >= scalar for x in a_lst]
- eq_s2_lst = [x == scalar for x in a_lst]
- ne_s2_lst = [x != scalar for x in a_lst]
-
- # Compare comparison functions (Python vs NumPy) using native Python
- def compare(lt, le, gt, ge, eq, ne, lt_lst, le_lst, gt_lst, ge_lst,
- eq_lst, ne_lst):
- assert_(lt.tolist() == lt_lst, "Comparison function check (lt)")
- assert_(le.tolist() == le_lst, "Comparison function check (le)")
- assert_(gt.tolist() == gt_lst, "Comparison function check (gt)")
- assert_(ge.tolist() == ge_lst, "Comparison function check (ge)")
- assert_(eq.tolist() == eq_lst, "Comparison function check (eq)")
- assert_(ne.tolist() == ne_lst, "Comparison function check (ne)")
+ comp_s2 = np_comp(a, np_scalar)
+ comp_s2_list = [py_comp(x, scalar) for x in a_lst]
# Sequence: Binary, Scalar1 and Scalar2
- compare(lt_b, le_b, gt_b, ge_b, eq_b, ne_b, lt_b_lst, le_b_lst,
- gt_b_lst, ge_b_lst, eq_b_lst, ne_b_lst)
- compare(lt_s1, le_s1, gt_s1, ge_s1, eq_s1, ne_s1, lt_s1_lst, le_s1_lst,
- gt_s1_lst, ge_s1_lst, eq_s1_lst, ne_s1_lst)
- compare(lt_s2, le_s2, gt_s2, ge_s2, eq_s2, ne_s2, lt_s2_lst, le_s2_lst,
- gt_s2_lst, ge_s2_lst, eq_s2_lst, ne_s2_lst)
+ assert_(comp_b.tolist() == comp_b_list,
+ f"Failed comparision ({py_comp.__name__})")
+ assert_(comp_s1.tolist() == comp_s1_list,
+ f"Failed comparision ({py_comp.__name__})")
+ assert_(comp_s2.tolist() == comp_s2_list,
+ f"Failed comparision ({py_comp.__name__})")
def test_ignore_object_identity_in_equal(self):
# Check comparing identical objects whose comparison