diff options
author | Eric V. Smith <ericvsmith@users.noreply.github.com> | 2018-01-27 19:07:40 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-27 19:07:40 -0500 |
commit | ea8fc52e75363276db23c6a8d7a689f79efce4f9 (patch) | |
tree | ca662ba631df1f6e6e32b5b0d95a6b5458d5699c /Lib/test/test_dataclasses.py | |
parent | 2a2247ce5e1984eb2f2c41b269b38dbb795a60cf (diff) | |
download | cpython-git-ea8fc52e75363276db23c6a8d7a689f79efce4f9.tar.gz |
bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)
Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same thing applies for the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder methods is defined in the class, it will not overwrite it.
Diffstat (limited to 'Lib/test/test_dataclasses.py')
-rwxr-xr-x | Lib/test/test_dataclasses.py | 664 |
1 files changed, 452 insertions, 212 deletions
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 69819ea450..53281f9dd9 100755 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -9,6 +9,7 @@ import unittest from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar from collections import deque, OrderedDict, namedtuple +from functools import total_ordering # Just any custom exception we can catch. class CustomError(Exception): pass @@ -82,68 +83,12 @@ class TestCase(unittest.TestCase): class C(B): x: int = 0 - def test_overwriting_init(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __init__ ' - 'in C'): - @dataclass - class C: - x: int - def __init__(self, x): - self.x = 2 * x - - @dataclass(init=False) - class C: - x: int - def __init__(self, x): - self.x = 2 * x - self.assertEqual(C(5).x, 10) - - def test_overwriting_repr(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __repr__ ' - 'in C'): - @dataclass - class C: - x: int - def __repr__(self): - pass - - @dataclass(repr=False) - class C: - x: int - def __repr__(self): - return 'x' - self.assertEqual(repr(C(0)), 'x') - - def test_overwriting_cmp(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __eq__ ' - 'in C'): - # This will generate the comparison functions, make sure we can't - # overwrite them. - @dataclass(hash=False, frozen=False) - class C: - x: int - def __eq__(self): - pass - - @dataclass(order=False, eq=False) + def test_overwriting_hash(self): + @dataclass(frozen=True) class C: x: int - def __eq__(self, other): - return True - self.assertEqual(C(0), 'x') - - def test_overwriting_hash(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __hash__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __hash__(self): - pass + def __hash__(self): + pass @dataclass(frozen=True,hash=False) class C: @@ -152,14 +97,11 @@ class TestCase(unittest.TestCase): return 600 self.assertEqual(hash(C(0)), 600) - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __hash__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __hash__(self): - pass + @dataclass(frozen=True) + class C: + x: int + def __hash__(self): + pass @dataclass(frozen=True, hash=False) class C: @@ -168,33 +110,6 @@ class TestCase(unittest.TestCase): return 600 self.assertEqual(hash(C(0)), 600) - def test_overwriting_frozen(self): - # frozen uses __setattr__ and __delattr__ - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __setattr__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __setattr__(self): - pass - - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __delattr__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __delattr__(self): - pass - - @dataclass(frozen=False) - class C: - x: int - def __setattr__(self, name, value): - self.__dict__['x'] = value * 2 - self.assertEqual(C(10).x, 20) - def test_overwrite_fields_in_derived_class(self): # Note that x from C1 replaces x in Base, but the order remains # the same as defined in Base. @@ -239,34 +154,6 @@ class TestCase(unittest.TestCase): first = next(iter(sig.parameters)) self.assertEqual('self', first) - def test_repr(self): - @dataclass - class B: - x: int - - @dataclass - class C(B): - y: int = 10 - - o = C(4) - self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)') - - @dataclass - class D(C): - x: int = 20 - self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)') - - @dataclass - class C: - @dataclass - class D: - i: int - @dataclass - class E: - pass - self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)') - self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()') - def test_0_field_compare(self): # Ensure that order=False is the default. @dataclass @@ -420,80 +307,8 @@ class TestCase(unittest.TestCase): self.assertEqual(hash(C(4)), hash((4,))) self.assertEqual(hash(C(42)), hash((42,))) - def test_hash(self): - @dataclass(hash=True) - class C: - x: int - y: str - self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) - - def test_no_hash(self): - @dataclass(hash=None) - class C: - x: int - with self.assertRaisesRegex(TypeError, - "unhashable type: 'C'"): - hash(C(1)) - - def test_hash_rules(self): - # There are 24 cases of: - # hash=True/False/None - # eq=True/False - # order=True/False - # frozen=True/False - for (hash, eq, order, frozen, result ) in [ - (False, False, False, False, 'absent'), - (False, False, False, True, 'absent'), - (False, False, True, False, 'exception'), - (False, False, True, True, 'exception'), - (False, True, False, False, 'absent'), - (False, True, False, True, 'absent'), - (False, True, True, False, 'absent'), - (False, True, True, True, 'absent'), - (True, False, False, False, 'fn'), - (True, False, False, True, 'fn'), - (True, False, True, False, 'exception'), - (True, False, True, True, 'exception'), - (True, True, False, False, 'fn'), - (True, True, False, True, 'fn'), - (True, True, True, False, 'fn'), - (True, True, True, True, 'fn'), - (None, False, False, False, 'absent'), - (None, False, False, True, 'absent'), - (None, False, True, False, 'exception'), - (None, False, True, True, 'exception'), - (None, True, False, False, 'none'), - (None, True, False, True, 'fn'), - (None, True, True, False, 'none'), - (None, True, True, True, 'fn'), - ]: - with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen): - if result == 'exception': - with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): - @dataclass(hash=hash, eq=eq, order=order, frozen=frozen) - class C: - pass - else: - @dataclass(hash=hash, eq=eq, order=order, frozen=frozen) - class C: - pass - - # See if the result matches what's expected. - if result == 'fn': - # __hash__ contains the function we generated. - self.assertIn('__hash__', C.__dict__) - self.assertIsNotNone(C.__dict__['__hash__']) - elif result == 'absent': - # __hash__ is not present in our class. - self.assertNotIn('__hash__', C.__dict__) - elif result == 'none': - # __hash__ is set to None. - self.assertIn('__hash__', C.__dict__) - self.assertIsNone(C.__dict__['__hash__']) - else: - assert False, f'unknown result {result!r}' - def test_eq_order(self): + # Test combining eq and order. for (eq, order, result ) in [ (False, False, 'neither'), (False, True, 'exception'), @@ -513,21 +328,18 @@ class TestCase(unittest.TestCase): if result == 'neither': self.assertNotIn('__eq__', C.__dict__) - self.assertNotIn('__ne__', C.__dict__) self.assertNotIn('__lt__', C.__dict__) self.assertNotIn('__le__', C.__dict__) self.assertNotIn('__gt__', C.__dict__) self.assertNotIn('__ge__', C.__dict__) elif result == 'both': self.assertIn('__eq__', C.__dict__) - self.assertIn('__ne__', C.__dict__) self.assertIn('__lt__', C.__dict__) self.assertIn('__le__', C.__dict__) self.assertIn('__gt__', C.__dict__) self.assertIn('__ge__', C.__dict__) elif result == 'eq_only': self.assertIn('__eq__', C.__dict__) - self.assertIn('__ne__', C.__dict__) self.assertNotIn('__lt__', C.__dict__) self.assertNotIn('__le__', C.__dict__) self.assertNotIn('__gt__', C.__dict__) @@ -811,19 +623,6 @@ class TestCase(unittest.TestCase): y: int self.assertNotEqual(Point(1, 3), C(1, 3)) - def test_base_has_init(self): - class B: - def __init__(self): - pass - - # Make sure that declaring this class doesn't raise an error. - # The issue is that we can't override __init__ in our class, - # but it should be okay to add __init__ to us if our base has - # an __init__. - @dataclass - class C(B): - x: int = 0 - def test_frozen(self): @dataclass(frozen=True) class C: @@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase): 'y': int, 'z': 'typing.Any'}) + class TestDocString(unittest.TestCase): def assertDocStrEqual(self, a, b): # Because 3.6 and 3.7 differ in how inspect.signature work @@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase): self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") +class TestInit(unittest.TestCase): + def test_base_has_init(self): + class B: + def __init__(self): + self.z = 100 + pass + + # Make sure that declaring this class doesn't raise an error. + # The issue is that we can't override __init__ in our class, + # but it should be okay to add __init__ to us if our base has + # an __init__. + @dataclass + class C(B): + x: int = 0 + c = C(10) + self.assertEqual(c.x, 10) + self.assertNotIn('z', vars(c)) + + # Make sure that if we don't add an init, the base __init__ + # gets called. + @dataclass(init=False) + class C(B): + x: int = 10 + c = C() + self.assertEqual(c.x, 10) + self.assertEqual(c.z, 100) + + def test_no_init(self): + dataclass(init=False) + class C: + i: int = 0 + self.assertEqual(C().i, 0) + + dataclass(init=False) + class C: + i: int = 2 + def __init__(self): + self.i = 3 + self.assertEqual(C().i, 3) + + def test_overwriting_init(self): + # If the class has __init__, use it no matter the value of + # init=. + + @dataclass + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(3).x, 6) + + @dataclass(init=True) + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(4).x, 8) + + @dataclass(init=False) + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(5).x, 10) + + +class TestRepr(unittest.TestCase): + def test_repr(self): + @dataclass + class B: + x: int + + @dataclass + class C(B): + y: int = 10 + + o = C(4) + self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') + + @dataclass + class D(C): + x: int = 20 + self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') + + @dataclass + class C: + @dataclass + class D: + i: int + @dataclass + class E: + pass + self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') + self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') + + def test_no_repr(self): + # Test a class with no __repr__ and repr=False. + @dataclass(repr=False) + class C: + x: int + self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at', + repr(C(3))) + + # Test a class with a __repr__ and repr=False. + @dataclass(repr=False) + class C: + x: int + def __repr__(self): + return 'C-class' + self.assertEqual(repr(C(3)), 'C-class') + + def test_overwriting_repr(self): + # If the class has __repr__, use it no matter the value of + # repr=. + + @dataclass + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + @dataclass(repr=True) + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + @dataclass(repr=False) + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + +class TestFrozen(unittest.TestCase): + def test_overwriting_frozen(self): + # frozen uses __setattr__ and __delattr__ + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __setattr__'): + @dataclass(frozen=True) + class C: + x: int + def __setattr__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __delattr__'): + @dataclass(frozen=True) + class C: + x: int + def __delattr__(self): + pass + + @dataclass(frozen=False) + class C: + x: int + def __setattr__(self, name, value): + self.__dict__['x'] = value * 2 + self.assertEqual(C(10).x, 20) + + +class TestEq(unittest.TestCase): + def test_no_eq(self): + # Test a class with no __eq__ and eq=False. + @dataclass(eq=False) + class C: + x: int + self.assertNotEqual(C(0), C(0)) + c = C(3) + self.assertEqual(c, c) + + # Test a class with an __eq__ and eq=False. + @dataclass(eq=False) + class C: + x: int + def __eq__(self, other): + return other == 10 + self.assertEqual(C(3), 10) + + def test_overwriting_eq(self): + # If the class has __eq__, use it no matter the value of + # eq=. + + @dataclass + class C: + x: int + def __eq__(self, other): + return other == 3 + self.assertEqual(C(1), 3) + self.assertNotEqual(C(1), 1) + + @dataclass(eq=True) + class C: + x: int + def __eq__(self, other): + return other == 4 + self.assertEqual(C(1), 4) + self.assertNotEqual(C(1), 1) + + @dataclass(eq=False) + class C: + x: int + def __eq__(self, other): + return other == 5 + self.assertEqual(C(1), 5) + self.assertNotEqual(C(1), 1) + + +class TestOrdering(unittest.TestCase): + def test_functools_total_ordering(self): + # Test that functools.total_ordering works with this class. + @total_ordering + @dataclass + class C: + x: int + def __lt__(self, other): + # Perform the test "backward", just to make + # sure this is being called. + return self.x >= other + + self.assertLess(C(0), -1) + self.assertLessEqual(C(0), -1) + self.assertGreater(C(0), 1) + self.assertGreaterEqual(C(0), 1) + + def test_no_order(self): + # Test that no ordering functions are added by default. + @dataclass(order=False) + class C: + x: int + # Make sure no order methods are added. + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__lt__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + + # Test that __lt__ is still called + @dataclass(order=False) + class C: + x: int + def __lt__(self, other): + return False + # Make sure other methods aren't added. + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + + def test_overwriting_order(self): + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __lt__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __lt__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __le__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __le__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __gt__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __gt__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __ge__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __ge__(self): + pass + +class TestHash(unittest.TestCase): + def test_hash(self): + @dataclass(hash=True) + class C: + x: int + y: str + self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) + + def test_hash_false(self): + @dataclass(hash=False) + class C: + x: int + y: str + self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo'))) + + def test_hash_none(self): + @dataclass(hash=None) + class C: + x: int + with self.assertRaisesRegex(TypeError, + "unhashable type: 'C'"): + hash(C(1)) + + def test_hash_rules(self): + def non_bool(value): + # Map to something else that's True, but not a bool. + if value is None: + return None + if value: + return (3,) + return 0 + + def test(case, hash, eq, frozen, with_hash, result): + with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen): + if with_hash: + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + def __hash__(self): + return 0 + else: + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + pass + + # See if the result matches what's expected. + if result in ('fn', 'fn-x'): + # __hash__ contains the function we generated. + self.assertIn('__hash__', C.__dict__) + self.assertIsNotNone(C.__dict__['__hash__']) + + if result == 'fn-x': + # This is the "auto-hash test" case. We + # should overwrite __hash__ iff there's an + # __eq__ and if __hash__=None. + + # There are two ways of getting __hash__=None: + # explicitely, and by defining __eq__. If + # __eq__ is defined, python will add __hash__ + # when the class is created. + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + def __eq__(self, other): pass + __hash__ = None + + # Hash should be overwritten (non-None). + self.assertIsNotNone(C.__dict__['__hash__']) + + # Same test as above, but we don't provide + # __hash__, it will implicitely set to None. + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + def __eq__(self, other): pass + + # Hash should be overwritten (non-None). + self.assertIsNotNone(C.__dict__['__hash__']) + + elif result == '': + # __hash__ is not present in our class. + if not with_hash: + self.assertNotIn('__hash__', C.__dict__) + elif result == 'none': + # __hash__ is set to None. + self.assertIn('__hash__', C.__dict__) + self.assertIsNone(C.__dict__['__hash__']) + else: + assert False, f'unknown result {result!r}' + + # There are 12 cases of: + # hash=True/False/None + # eq=True/False + # frozen=True/False + # And for each of these, a different result if + # __hash__ is defined or not. + for case, (hash, eq, frozen, result_no, result_yes) in enumerate([ + (None, False, False, '', ''), + (None, False, True, '', ''), + (None, True, False, 'none', ''), + (None, True, True, 'fn', 'fn-x'), + (False, False, False, '', ''), + (False, False, True, '', ''), + (False, True, False, '', ''), + (False, True, True, '', ''), + (True, False, False, 'fn', 'fn-x'), + (True, False, True, 'fn', 'fn-x'), + (True, True, False, 'fn', 'fn-x'), + (True, True, True, 'fn', 'fn-x'), + ], 1): + test(case, hash, eq, frozen, False, result_no) + test(case, hash, eq, frozen, True, result_yes) + + # Test non-bool truth values, too. This is just to + # make sure the data-driven table in the decorator + # handles non-bool values. + test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no) + test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes) + + + def test_eq_only(self): + # If a class defines __eq__, __hash__ is automatically added + # and set to None. This is normal Python behavior, not + # related to dataclasses. Make sure we don't interfere with + # that (see bpo=32546). + + @dataclass + class C: + i: int + def __eq__(self, other): + return self.i == other.i + self.assertEqual(C(1), C(1)) + self.assertNotEqual(C(1), C(4)) + + # And make sure things work in this case if we specify + # hash=True. + @dataclass(hash=True) + class C: + i: int + def __eq__(self, other): + return self.i == other.i + self.assertEqual(C(1), C(1.0)) + self.assertEqual(hash(C(1)), hash(C(1.0))) + + # And check that the classes __eq__ is being used, despite + # specifying eq=True. + @dataclass(hash=True, eq=True) + class C: + i: int + def __eq__(self, other): + return self.i == 3 and self.i == other.i + self.assertEqual(C(3), C(3)) + self.assertNotEqual(C(1), C(1)) + self.assertEqual(hash(C(1)), hash(C(1.0))) + + if __name__ == '__main__': unittest.main() |