summaryrefslogtreecommitdiff
path: root/numpy/testing/_private/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r--numpy/testing/_private/utils.py69
1 files changed, 49 insertions, 20 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py
index 8a31fcf15..b14c776d9 100644
--- a/numpy/testing/_private/utils.py
+++ b/numpy/testing/_private/utils.py
@@ -284,6 +284,10 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
check that all elements of these objects are equal. An exception is raised
at the first conflicting values.
+ When one of `actual` and `desired` is a scalar and the other is array_like,
+ the function checks that each element of the array_like object is equal to
+ the scalar.
+
This function handles NaN comparisons as if NaN was a "normal" number.
That is, no assertion is raised if both objects have NaNs in the same
positions. This is in contrast to the IEEE standard on NaNs, which says
@@ -374,21 +378,6 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
if isscalar(desired) != isscalar(actual):
raise AssertionError(msg)
- # Inf/nan/negative zero handling
- try:
- isdesnan = gisnan(desired)
- isactnan = gisnan(actual)
- if isdesnan and isactnan:
- return # both nan, so equal
-
- # handle signed zero specially for floats
- if desired == 0 and actual == 0:
- if not signbit(desired) == signbit(actual):
- raise AssertionError(msg)
-
- except (TypeError, ValueError, NotImplementedError):
- pass
-
try:
isdesnat = isnat(desired)
isactnat = isnat(actual)
@@ -404,6 +393,33 @@ def assert_equal(actual, desired, err_msg='', verbose=True):
except (TypeError, ValueError, NotImplementedError):
pass
+ # Inf/nan/negative zero handling
+ try:
+ isdesnan = gisnan(desired)
+ isactnan = gisnan(actual)
+ if isdesnan and isactnan:
+ return # both nan, so equal
+
+ # handle signed zero specially for floats
+ array_actual = array(actual)
+ array_desired = array(desired)
+ if (array_actual.dtype.char in 'Mm' or
+ array_desired.dtype.char in 'Mm'):
+ # version 1.18
+ # until this version, gisnan failed for datetime64 and timedelta64.
+ # Now it succeeds but comparison to scalar with a different type
+ # emits a DeprecationWarning.
+ # Avoid that by skipping the next check
+ raise NotImplementedError('cannot compare to a scalar '
+ 'with a different type')
+
+ if desired == 0 and actual == 0:
+ if not signbit(desired) == signbit(actual):
+ raise AssertionError(msg)
+
+ except (TypeError, ValueError, NotImplementedError):
+ pass
+
try:
# Explicitly use __eq__ for comparison, gh-2552
if not (desired == actual):
@@ -841,10 +857,11 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
Raises an AssertionError if two array_like objects are not equal.
Given two array_like objects, check that the shape is equal and all
- elements of these objects are equal. An exception is raised at
- shape mismatch or conflicting values. In contrast to the standard usage
- in numpy, NaNs are compared like numbers, no assertion is raised if
- both objects have NaNs in the same positions.
+ elements of these objects are equal (but see the Notes for the special
+ handling of a scalar). An exception is raised at shape mismatch or
+ conflicting values. In contrast to the standard usage in numpy, NaNs
+ are compared like numbers, no assertion is raised if both objects have
+ NaNs in the same positions.
The usual caution for verifying equality with floating point numbers is
advised.
@@ -871,6 +888,12 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
relative and/or absolute precision.
assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal
+ Notes
+ -----
+ When one of `x` and `y` is a scalar and the other is array_like, the
+ function checks that each element of the array_like object is equal to
+ the scalar.
+
Examples
--------
The first assert does not raise an exception:
@@ -878,7 +901,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
>>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
... [np.exp(0),2.33333, np.nan])
- Assert fails with numerical inprecision with floats:
+ Assert fails with numerical imprecision with floats:
>>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
... [1, np.sqrt(np.pi)**2, np.nan])
@@ -899,6 +922,12 @@ def assert_array_equal(x, y, err_msg='', verbose=True):
... [1, np.sqrt(np.pi)**2, np.nan],
... rtol=1e-10, atol=0)
+ As mentioned in the Notes section, `assert_array_equal` has special
+ handling for scalars. Here the test checks that each value in `x` is 3:
+
+ >>> x = np.full((2, 5), fill_value=3)
+ >>> np.testing.assert_array_equal(x, 3)
+
"""
__tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,