diff options
Diffstat (limited to 'numpy/testing/_private/utils.py')
-rw-r--r-- | numpy/testing/_private/utils.py | 66 |
1 files changed, 58 insertions, 8 deletions
diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index e4f8b9892..c553658cb 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -36,6 +36,7 @@ __all__ = [ 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY', 'HAS_REFCOUNT', 'suppress_warnings', 'assert_array_compare', 'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON', + '_OLD_PROMOTION' ] @@ -52,6 +53,8 @@ IS_PYSTON = hasattr(sys, "pyston_version_info") HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64 +_OLD_PROMOTION = lambda: np._get_promotion_state() == 'legacy' + def import_nose(): """ Import nose only when needed. @@ -473,6 +476,7 @@ def print_assert_equal(test_string, actual, desired): raise AssertionError(msg.getvalue()) +@np._no_nep50_warning() def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): """ Raises an AssertionError if two items are not equal up to desired @@ -485,7 +489,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): The test verifies that the elements of `actual` and `desired` satisfy. - ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` That is a looser test than originally documented, but agrees with what the actual implementation in `assert_array_almost_equal` did up to rounding @@ -595,10 +599,11 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): return except (NotImplementedError, TypeError): pass - if abs(desired - actual) >= 1.5 * 10.0**(-decimal): + if abs(desired - actual) >= np.float64(1.5 * 10.0**(-decimal)): raise AssertionError(_build_err_msg()) +@np._no_nep50_warning() def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): """ Raises an AssertionError if two items are not equal up to significant @@ -698,8 +703,10 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True): raise AssertionError(msg) +@np._no_nep50_warning() def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', - precision=6, equal_nan=True, equal_inf=True): + precision=6, equal_nan=True, equal_inf=True, + *, strict=False): __tracebackhide__ = True # Hide traceback for py.test from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_ @@ -753,11 +760,18 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', return y_id try: - cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if strict: + cond = x.shape == y.shape and x.dtype == y.dtype + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape if not cond: + if x.shape != y.shape: + reason = f'\n(shapes {x.shape}, {y.shape} mismatch)' + else: + reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)' msg = build_err_msg([x, y], err_msg - + f'\n(shapes {x.shape}, {y.shape} mismatch)', + + reason, verbose=verbose, header=header, names=('x', 'y'), precision=precision) raise AssertionError(msg) @@ -814,6 +828,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', # ignore errors for non-numeric types with contextlib.suppress(TypeError): error = abs(x - y) + if np.issubdtype(x.dtype, np.unsignedinteger): + error2 = abs(y - x) + np.minimum(error, error2, out=error) max_abs_error = max(error) if getattr(error, 'dtype', object_) == object_: remarks.append('Max absolute difference: ' @@ -852,7 +869,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', raise ValueError(msg) -def assert_array_equal(x, y, err_msg='', verbose=True): +def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False): """ Raises an AssertionError if two array_like objects are not equal. @@ -876,6 +893,10 @@ def assert_array_equal(x, y, err_msg='', verbose=True): The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. Raises ------ @@ -892,7 +913,7 @@ def assert_array_equal(x, y, err_msg='', verbose=True): ----- When one of `x` and `y` is a scalar and the other is array_like, the function checks that each element of the array_like object is equal to - the scalar. + the scalar. This behaviour can be disabled with the `strict` parameter. Examples -------- @@ -929,12 +950,41 @@ def assert_array_equal(x, y, err_msg='', verbose=True): >>> x = np.full((2, 5), fill_value=3) >>> np.testing.assert_array_equal(x, 3) + Use `strict` to raise an AssertionError when comparing a scalar with an + array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + <BLANKLINE> + (shapes (2, 5), () mismatch) + x: array([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: array(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + <BLANKLINE> + (dtypes int64, float32 mismatch) + x: array([2, 2, 2]) + y: array([2., 2., 2.], dtype=float32) """ __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__eq__, x, y, err_msg=err_msg, - verbose=verbose, header='Arrays are not equal') + verbose=verbose, header='Arrays are not equal', + strict=strict) +@np._no_nep50_warning() def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): """ Raises an AssertionError if two objects are not equal up to desired |