summaryrefslogtreecommitdiff
path: root/numpy/core/tests
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2019-11-04 17:32:10 -0700
committerGitHub <noreply@github.com>2019-11-04 17:32:10 -0700
commit4393e0c0cb1772d0f9691117d238d000b209decd (patch)
tree5dda9c9a885af59ecdcd4c3c947e69bd6bab48bf /numpy/core/tests
parent7b8513fff4c3673d8a967a1d72a73feab2072e8f (diff)
parent9aa8c4756d7282dd49a7093ce7a6725258106b9a (diff)
downloadnumpy-4393e0c0cb1772d0f9691117d238d000b209decd.tar.gz
Merge pull request #14800 from mattip/reorder-obj-comparison-loop
ENH: change object-array comparisons to prefer OO->O unfuncs
Diffstat (limited to 'numpy/core/tests')
-rw-r--r--numpy/core/tests/test_deprecations.py5
-rw-r--r--numpy/core/tests/test_ufunc.py14
-rw-r--r--numpy/core/tests/test_umath.py10
3 files changed, 18 insertions, 11 deletions
diff --git a/numpy/core/tests/test_deprecations.py b/numpy/core/tests/test_deprecations.py
index 8bffaa9af..9bdcd8241 100644
--- a/numpy/core/tests/test_deprecations.py
+++ b/numpy/core/tests/test_deprecations.py
@@ -172,10 +172,11 @@ class TestComparisonDeprecations(_DeprecationTestCase):
# (warning is issued a couple of times here)
self.assert_deprecated(op, args=(a, a[:-1]), num=None)
- # Element comparison error (numpy array can't be compared).
+ # ragged array comparison returns True/False
a = np.array([1, np.array([1,2,3])], dtype=object)
b = np.array([1, np.array([1,2,3])], dtype=object)
- self.assert_deprecated(op, args=(a, b), num=None)
+ res = op(a, b)
+ assert res.dtype == 'object'
def test_string(self):
# For two string arrays, strings always raised the broadcasting error:
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 707c690dd..d9f961581 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -1090,14 +1090,18 @@ class TestUfunc(object):
return '=='
arr0d = np.array(HasComparisons())
- assert_equal(arr0d == arr0d, True)
- assert_equal(np.equal(arr0d, arr0d), True) # normal behavior is a cast
+ assert_equal(arr0d == arr0d, '==')
+ assert_equal(np.equal(arr0d, arr0d), '==')
+ assert_equal(np.equal(arr0d, arr0d, dtype=bool), True)
assert_equal(np.equal(arr0d, arr0d, dtype=object), '==')
arr1d = np.array([HasComparisons()])
- assert_equal(arr1d == arr1d, np.array([True]))
- assert_equal(np.equal(arr1d, arr1d), np.array([True])) # normal behavior is a cast
- assert_equal(np.equal(arr1d, arr1d, dtype=object), np.array(['==']))
+ ret_obj = np.array(['=='], dtype=object)
+ ret_bool = np.array([True])
+ assert_equal(arr1d == arr1d, ret_obj)
+ assert_equal(np.equal(arr1d, arr1d), ret_obj)
+ assert_equal(np.equal(arr1d, arr1d, dtype=object), ret_obj)
+ assert_equal(np.equal(arr1d, arr1d, dtype=bool), ret_bool)
def test_object_array_reduction(self):
# Reductions on object arrays
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 9b4ce9e47..96a9f1f8b 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -170,10 +170,11 @@ class TestOut(object):
class TestComparisons(object):
def test_ignore_object_identity_in_equal(self):
- # Check error raised when comparing identical objects whose comparison
+ # Check comparing identical objects whose comparison
# is not a simple boolean, e.g., arrays that are compared elementwise.
a = np.array([np.array([1, 2, 3]), None], dtype=object)
- assert_raises(ValueError, np.equal, a, a)
+ b = np.equal(a, a.copy())
+ assert b.shape == a.shape
# Check error raised when comparing identical non-comparable objects.
class FunkyType(object):
@@ -188,10 +189,11 @@ class TestComparisons(object):
assert_equal(np.equal(a, a), [False])
def test_ignore_object_identity_in_not_equal(self):
- # Check error raised when comparing identical objects whose comparison
+ # Check comparing identical objects whose comparison
# is not a simple boolean, e.g., arrays that are compared elementwise.
a = np.array([np.array([1, 2, 3]), None], dtype=object)
- assert_raises(ValueError, np.not_equal, a, a)
+ b = np.not_equal(a, a.copy())
+ assert b.shape == a.shape
# Check error raised when comparing identical non-comparable objects.
class FunkyType(object):