diff options
Diffstat (limited to 'numpy/testing/tests/test_utils.py')
-rw-r--r-- | numpy/testing/tests/test_utils.py | 70 |
1 files changed, 32 insertions, 38 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 44f93a693..b899e94f4 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -1,10 +1,7 @@ -from __future__ import division, absolute_import, print_function - import warnings import sys import os import itertools -import textwrap import pytest import weakref @@ -20,7 +17,7 @@ from numpy.testing import ( from numpy.core.overrides import ARRAY_FUNCTION_ENABLED -class _GenericTest(object): +class _GenericTest: def _test_equal(self, a, b): self._assert_func(a, b) @@ -90,6 +87,21 @@ class TestArrayEqual(_GenericTest): for t in ['S1', 'U1']: foo(t) + def test_0_ndim_array(self): + x = np.array(473963742225900817127911193656584771) + y = np.array(18535119325151578301457182298393896) + assert_raises(AssertionError, self._assert_func, x, y) + + y = x + self._assert_func(x, y) + + x = np.array(43) + y = np.array(10) + assert_raises(AssertionError, self._assert_func, x, y) + + y = x + self._assert_func(x, y) + def test_generic_rank3(self): """Test rank 3 array for all dtypes.""" def foo(t): @@ -196,7 +208,7 @@ class TestArrayEqual(_GenericTest): self._test_not_equal(b, a) -class TestBuildErrorMessage(object): +class TestBuildErrorMessage: def test_build_err_msg_defaults(self): x = np.array([1.00001, 2.00002, 3.00003]) @@ -328,24 +340,6 @@ class TestEqual(TestArrayEqual): self._assert_func(x, x) self._test_not_equal(x, y) - def test_error_message(self): - with pytest.raises(AssertionError) as exc_info: - self._assert_func(np.array([1, 2]), np.array([[1, 2]])) - msg = str(exc_info.value) - msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)") - msg_reference = textwrap.dedent("""\ - - Arrays are not equal - - (shapes (2,), (1, 2) mismatch) - x: array([1, 2]) - y: array([[1, 2]])""") - - try: - assert_equal(msg, msg_reference) - except AssertionError: - assert_equal(msg2, msg_reference) - def test_object(self): #gh-12942 import datetime @@ -603,14 +597,14 @@ class TestAlmostEqual(_GenericTest): self._assert_func(a, a) -class TestApproxEqual(object): +class TestApproxEqual: def setup(self): self._assert_func = assert_approx_equal - def test_simple_arrays(self): - x = np.array([1234.22]) - y = np.array([1234.23]) + def test_simple_0d_arrays(self): + x = np.array(1234.22) + y = np.array(1234.23) self._assert_func(x, y, significant=5) self._assert_func(x, y, significant=6) @@ -646,7 +640,7 @@ class TestApproxEqual(object): assert_raises(AssertionError, lambda: self._assert_func(ainf, anan)) -class TestArrayAssertLess(object): +class TestArrayAssertLess: def setup(self): self._assert_func = assert_array_less @@ -756,7 +750,7 @@ class TestArrayAssertLess(object): @pytest.mark.skip(reason="The raises decorator depends on Nose") -class TestRaises(object): +class TestRaises: def setup(self): class MyException(Exception): @@ -790,7 +784,7 @@ class TestRaises(object): raise AssertionError("should have raised an AssertionError") -class TestWarns(object): +class TestWarns: def test_warn(self): def f(): @@ -841,7 +835,7 @@ class TestWarns(object): raise AssertionError("wrong warning caught by assert_warn") -class TestAssertAllclose(object): +class TestAssertAllclose: def test_simple(self): x = 1e-3 @@ -911,7 +905,7 @@ class TestAssertAllclose(object): assert_('Max relative difference: 0.5' in msg) -class TestArrayAlmostEqualNulp(object): +class TestArrayAlmostEqualNulp: def test_float64_pass(self): # The number of units of least precision @@ -1108,7 +1102,7 @@ class TestArrayAlmostEqualNulp(object): xi, y + y*1j, nulp) -class TestULP(object): +class TestULP: def test_equal(self): x = np.random.randn(10) @@ -1164,7 +1158,7 @@ class TestULP(object): maxulp=maxulp)) -class TestStringEqual(object): +class TestStringEqual: def test_simple(self): assert_string_equal("hello", "hello") assert_string_equal("hello\nmultiline", "hello\nmultiline") @@ -1226,7 +1220,7 @@ def test_warn_len_equal_call_scenarios(): # check that no assertion is uncaught # parallel scenario -- no warning issued yet - class mod(object): + class mod: pass mod_inst = mod() @@ -1236,7 +1230,7 @@ def test_warn_len_equal_call_scenarios(): # serial test scenario -- the __warningregistry__ # attribute should be present - class mod(object): + class mod: def __init__(self): self.__warningregistry__ = {'warning1':1, 'warning2':2} @@ -1511,7 +1505,7 @@ def test_clear_and_catch_warnings_inherit(): @pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts") -class TestAssertNoGcCycles(object): +class TestAssertNoGcCycles: """ Test assert_no_gc_cycles """ def test_passes(self): def no_cycle(): @@ -1545,7 +1539,7 @@ class TestAssertNoGcCycles(object): error, instead of hanging forever trying to clear it. """ - class ReferenceCycleInDel(object): + class ReferenceCycleInDel: """ An object that not only contains a reference cycle, but creates new cycles whenever it's garbage-collected and its __del__ runs |