summaryrefslogtreecommitdiff
path: root/Lib/test/test_dataclasses.py
diff options
context:
space:
mode:
authorEric V. Smith <ericvsmith@users.noreply.github.com>2018-01-27 19:07:40 -0500
committerGitHub <noreply@github.com>2018-01-27 19:07:40 -0500
commitea8fc52e75363276db23c6a8d7a689f79efce4f9 (patch)
treeca662ba631df1f6e6e32b5b0d95a6b5458d5699c /Lib/test/test_dataclasses.py
parent2a2247ce5e1984eb2f2c41b269b38dbb795a60cf (diff)
downloadcpython-git-ea8fc52e75363276db23c6a8d7a689f79efce4f9.tar.gz
bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)
Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same thing applies for the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder methods is defined in the class, it will not overwrite it.
Diffstat (limited to 'Lib/test/test_dataclasses.py')
-rwxr-xr-xLib/test/test_dataclasses.py664
1 files changed, 452 insertions, 212 deletions
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index 69819ea450..53281f9dd9 100755
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -9,6 +9,7 @@ import unittest
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
from collections import deque, OrderedDict, namedtuple
+from functools import total_ordering
# Just any custom exception we can catch.
class CustomError(Exception): pass
@@ -82,68 +83,12 @@ class TestCase(unittest.TestCase):
class C(B):
x: int = 0
- def test_overwriting_init(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __init__ '
- 'in C'):
- @dataclass
- class C:
- x: int
- def __init__(self, x):
- self.x = 2 * x
-
- @dataclass(init=False)
- class C:
- x: int
- def __init__(self, x):
- self.x = 2 * x
- self.assertEqual(C(5).x, 10)
-
- def test_overwriting_repr(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __repr__ '
- 'in C'):
- @dataclass
- class C:
- x: int
- def __repr__(self):
- pass
-
- @dataclass(repr=False)
- class C:
- x: int
- def __repr__(self):
- return 'x'
- self.assertEqual(repr(C(0)), 'x')
-
- def test_overwriting_cmp(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __eq__ '
- 'in C'):
- # This will generate the comparison functions, make sure we can't
- # overwrite them.
- @dataclass(hash=False, frozen=False)
- class C:
- x: int
- def __eq__(self):
- pass
-
- @dataclass(order=False, eq=False)
+ def test_overwriting_hash(self):
+ @dataclass(frozen=True)
class C:
x: int
- def __eq__(self, other):
- return True
- self.assertEqual(C(0), 'x')
-
- def test_overwriting_hash(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __hash__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __hash__(self):
- pass
+ def __hash__(self):
+ pass
@dataclass(frozen=True,hash=False)
class C:
@@ -152,14 +97,11 @@ class TestCase(unittest.TestCase):
return 600
self.assertEqual(hash(C(0)), 600)
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __hash__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __hash__(self):
- pass
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __hash__(self):
+ pass
@dataclass(frozen=True, hash=False)
class C:
@@ -168,33 +110,6 @@ class TestCase(unittest.TestCase):
return 600
self.assertEqual(hash(C(0)), 600)
- def test_overwriting_frozen(self):
- # frozen uses __setattr__ and __delattr__
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __setattr__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __setattr__(self):
- pass
-
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __delattr__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __delattr__(self):
- pass
-
- @dataclass(frozen=False)
- class C:
- x: int
- def __setattr__(self, name, value):
- self.__dict__['x'] = value * 2
- self.assertEqual(C(10).x, 20)
-
def test_overwrite_fields_in_derived_class(self):
# Note that x from C1 replaces x in Base, but the order remains
# the same as defined in Base.
@@ -239,34 +154,6 @@ class TestCase(unittest.TestCase):
first = next(iter(sig.parameters))
self.assertEqual('self', first)
- def test_repr(self):
- @dataclass
- class B:
- x: int
-
- @dataclass
- class C(B):
- y: int = 10
-
- o = C(4)
- self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
-
- @dataclass
- class D(C):
- x: int = 20
- self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
-
- @dataclass
- class C:
- @dataclass
- class D:
- i: int
- @dataclass
- class E:
- pass
- self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
- self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
-
def test_0_field_compare(self):
# Ensure that order=False is the default.
@dataclass
@@ -420,80 +307,8 @@ class TestCase(unittest.TestCase):
self.assertEqual(hash(C(4)), hash((4,)))
self.assertEqual(hash(C(42)), hash((42,)))
- def test_hash(self):
- @dataclass(hash=True)
- class C:
- x: int
- y: str
- self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
-
- def test_no_hash(self):
- @dataclass(hash=None)
- class C:
- x: int
- with self.assertRaisesRegex(TypeError,
- "unhashable type: 'C'"):
- hash(C(1))
-
- def test_hash_rules(self):
- # There are 24 cases of:
- # hash=True/False/None
- # eq=True/False
- # order=True/False
- # frozen=True/False
- for (hash, eq, order, frozen, result ) in [
- (False, False, False, False, 'absent'),
- (False, False, False, True, 'absent'),
- (False, False, True, False, 'exception'),
- (False, False, True, True, 'exception'),
- (False, True, False, False, 'absent'),
- (False, True, False, True, 'absent'),
- (False, True, True, False, 'absent'),
- (False, True, True, True, 'absent'),
- (True, False, False, False, 'fn'),
- (True, False, False, True, 'fn'),
- (True, False, True, False, 'exception'),
- (True, False, True, True, 'exception'),
- (True, True, False, False, 'fn'),
- (True, True, False, True, 'fn'),
- (True, True, True, False, 'fn'),
- (True, True, True, True, 'fn'),
- (None, False, False, False, 'absent'),
- (None, False, False, True, 'absent'),
- (None, False, True, False, 'exception'),
- (None, False, True, True, 'exception'),
- (None, True, False, False, 'none'),
- (None, True, False, True, 'fn'),
- (None, True, True, False, 'none'),
- (None, True, True, True, 'fn'),
- ]:
- with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
- if result == 'exception':
- with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
- @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
- class C:
- pass
- else:
- @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
- class C:
- pass
-
- # See if the result matches what's expected.
- if result == 'fn':
- # __hash__ contains the function we generated.
- self.assertIn('__hash__', C.__dict__)
- self.assertIsNotNone(C.__dict__['__hash__'])
- elif result == 'absent':
- # __hash__ is not present in our class.
- self.assertNotIn('__hash__', C.__dict__)
- elif result == 'none':
- # __hash__ is set to None.
- self.assertIn('__hash__', C.__dict__)
- self.assertIsNone(C.__dict__['__hash__'])
- else:
- assert False, f'unknown result {result!r}'
-
def test_eq_order(self):
+ # Test combining eq and order.
for (eq, order, result ) in [
(False, False, 'neither'),
(False, True, 'exception'),
@@ -513,21 +328,18 @@ class TestCase(unittest.TestCase):
if result == 'neither':
self.assertNotIn('__eq__', C.__dict__)
- self.assertNotIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
self.assertNotIn('__ge__', C.__dict__)
elif result == 'both':
self.assertIn('__eq__', C.__dict__)
- self.assertIn('__ne__', C.__dict__)
self.assertIn('__lt__', C.__dict__)
self.assertIn('__le__', C.__dict__)
self.assertIn('__gt__', C.__dict__)
self.assertIn('__ge__', C.__dict__)
elif result == 'eq_only':
self.assertIn('__eq__', C.__dict__)
- self.assertIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
@@ -811,19 +623,6 @@ class TestCase(unittest.TestCase):
y: int
self.assertNotEqual(Point(1, 3), C(1, 3))
- def test_base_has_init(self):
- class B:
- def __init__(self):
- pass
-
- # Make sure that declaring this class doesn't raise an error.
- # The issue is that we can't override __init__ in our class,
- # but it should be okay to add __init__ to us if our base has
- # an __init__.
- @dataclass
- class C(B):
- x: int = 0
-
def test_frozen(self):
@dataclass(frozen=True)
class C:
@@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase):
'y': int,
'z': 'typing.Any'})
+
class TestDocString(unittest.TestCase):
def assertDocStrEqual(self, a, b):
# Because 3.6 and 3.7 differ in how inspect.signature work
@@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase):
self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
+class TestInit(unittest.TestCase):
+ def test_base_has_init(self):
+ class B:
+ def __init__(self):
+ self.z = 100
+ pass
+
+ # Make sure that declaring this class doesn't raise an error.
+ # The issue is that we can't override __init__ in our class,
+ # but it should be okay to add __init__ to us if our base has
+ # an __init__.
+ @dataclass
+ class C(B):
+ x: int = 0
+ c = C(10)
+ self.assertEqual(c.x, 10)
+ self.assertNotIn('z', vars(c))
+
+ # Make sure that if we don't add an init, the base __init__
+ # gets called.
+ @dataclass(init=False)
+ class C(B):
+ x: int = 10
+ c = C()
+ self.assertEqual(c.x, 10)
+ self.assertEqual(c.z, 100)
+
+ def test_no_init(self):
+ dataclass(init=False)
+ class C:
+ i: int = 0
+ self.assertEqual(C().i, 0)
+
+ dataclass(init=False)
+ class C:
+ i: int = 2
+ def __init__(self):
+ self.i = 3
+ self.assertEqual(C().i, 3)
+
+ def test_overwriting_init(self):
+ # If the class has __init__, use it no matter the value of
+ # init=.
+
+ @dataclass
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(3).x, 6)
+
+ @dataclass(init=True)
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(4).x, 8)
+
+ @dataclass(init=False)
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(5).x, 10)
+
+
+class TestRepr(unittest.TestCase):
+ def test_repr(self):
+ @dataclass
+ class B:
+ x: int
+
+ @dataclass
+ class C(B):
+ y: int = 10
+
+ o = C(4)
+ self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
+
+ @dataclass
+ class D(C):
+ x: int = 20
+ self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
+
+ @dataclass
+ class C:
+ @dataclass
+ class D:
+ i: int
+ @dataclass
+ class E:
+ pass
+ self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
+ self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
+
+ def test_no_repr(self):
+ # Test a class with no __repr__ and repr=False.
+ @dataclass(repr=False)
+ class C:
+ x: int
+ self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
+ repr(C(3)))
+
+ # Test a class with a __repr__ and repr=False.
+ @dataclass(repr=False)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'C-class'
+ self.assertEqual(repr(C(3)), 'C-class')
+
+ def test_overwriting_repr(self):
+ # If the class has __repr__, use it no matter the value of
+ # repr=.
+
+ @dataclass
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+ @dataclass(repr=True)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+ @dataclass(repr=False)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+
+class TestFrozen(unittest.TestCase):
+ def test_overwriting_frozen(self):
+ # frozen uses __setattr__ and __delattr__
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __setattr__'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __setattr__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __delattr__'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __delattr__(self):
+ pass
+
+ @dataclass(frozen=False)
+ class C:
+ x: int
+ def __setattr__(self, name, value):
+ self.__dict__['x'] = value * 2
+ self.assertEqual(C(10).x, 20)
+
+
+class TestEq(unittest.TestCase):
+ def test_no_eq(self):
+ # Test a class with no __eq__ and eq=False.
+ @dataclass(eq=False)
+ class C:
+ x: int
+ self.assertNotEqual(C(0), C(0))
+ c = C(3)
+ self.assertEqual(c, c)
+
+ # Test a class with an __eq__ and eq=False.
+ @dataclass(eq=False)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 10
+ self.assertEqual(C(3), 10)
+
+ def test_overwriting_eq(self):
+ # If the class has __eq__, use it no matter the value of
+ # eq=.
+
+ @dataclass
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 3
+ self.assertEqual(C(1), 3)
+ self.assertNotEqual(C(1), 1)
+
+ @dataclass(eq=True)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 4
+ self.assertEqual(C(1), 4)
+ self.assertNotEqual(C(1), 1)
+
+ @dataclass(eq=False)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 5
+ self.assertEqual(C(1), 5)
+ self.assertNotEqual(C(1), 1)
+
+
+class TestOrdering(unittest.TestCase):
+ def test_functools_total_ordering(self):
+ # Test that functools.total_ordering works with this class.
+ @total_ordering
+ @dataclass
+ class C:
+ x: int
+ def __lt__(self, other):
+ # Perform the test "backward", just to make
+ # sure this is being called.
+ return self.x >= other
+
+ self.assertLess(C(0), -1)
+ self.assertLessEqual(C(0), -1)
+ self.assertGreater(C(0), 1)
+ self.assertGreaterEqual(C(0), 1)
+
+ def test_no_order(self):
+ # Test that no ordering functions are added by default.
+ @dataclass(order=False)
+ class C:
+ x: int
+ # Make sure no order methods are added.
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__lt__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+
+ # Test that __lt__ is still called
+ @dataclass(order=False)
+ class C:
+ x: int
+ def __lt__(self, other):
+ return False
+ # Make sure other methods aren't added.
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+
+ def test_overwriting_order(self):
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __lt__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __lt__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __le__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __le__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __gt__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __gt__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __ge__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __ge__(self):
+ pass
+
+class TestHash(unittest.TestCase):
+ def test_hash(self):
+ @dataclass(hash=True)
+ class C:
+ x: int
+ y: str
+ self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+ def test_hash_false(self):
+ @dataclass(hash=False)
+ class C:
+ x: int
+ y: str
+ self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+ def test_hash_none(self):
+ @dataclass(hash=None)
+ class C:
+ x: int
+ with self.assertRaisesRegex(TypeError,
+ "unhashable type: 'C'"):
+ hash(C(1))
+
+ def test_hash_rules(self):
+ def non_bool(value):
+ # Map to something else that's True, but not a bool.
+ if value is None:
+ return None
+ if value:
+ return (3,)
+ return 0
+
+ def test(case, hash, eq, frozen, with_hash, result):
+ with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
+ if with_hash:
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ def __hash__(self):
+ return 0
+ else:
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ pass
+
+ # See if the result matches what's expected.
+ if result in ('fn', 'fn-x'):
+ # __hash__ contains the function we generated.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNotNone(C.__dict__['__hash__'])
+
+ if result == 'fn-x':
+ # This is the "auto-hash test" case. We
+ # should overwrite __hash__ iff there's an
+ # __eq__ and if __hash__=None.
+
+ # There are two ways of getting __hash__=None:
+ # explicitely, and by defining __eq__. If
+ # __eq__ is defined, python will add __hash__
+ # when the class is created.
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ def __eq__(self, other): pass
+ __hash__ = None
+
+ # Hash should be overwritten (non-None).
+ self.assertIsNotNone(C.__dict__['__hash__'])
+
+ # Same test as above, but we don't provide
+ # __hash__, it will implicitely set to None.
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ def __eq__(self, other): pass
+
+ # Hash should be overwritten (non-None).
+ self.assertIsNotNone(C.__dict__['__hash__'])
+
+ elif result == '':
+ # __hash__ is not present in our class.
+ if not with_hash:
+ self.assertNotIn('__hash__', C.__dict__)
+ elif result == 'none':
+ # __hash__ is set to None.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNone(C.__dict__['__hash__'])
+ else:
+ assert False, f'unknown result {result!r}'
+
+ # There are 12 cases of:
+ # hash=True/False/None
+ # eq=True/False
+ # frozen=True/False
+ # And for each of these, a different result if
+ # __hash__ is defined or not.
+ for case, (hash, eq, frozen, result_no, result_yes) in enumerate([
+ (None, False, False, '', ''),
+ (None, False, True, '', ''),
+ (None, True, False, 'none', ''),
+ (None, True, True, 'fn', 'fn-x'),
+ (False, False, False, '', ''),
+ (False, False, True, '', ''),
+ (False, True, False, '', ''),
+ (False, True, True, '', ''),
+ (True, False, False, 'fn', 'fn-x'),
+ (True, False, True, 'fn', 'fn-x'),
+ (True, True, False, 'fn', 'fn-x'),
+ (True, True, True, 'fn', 'fn-x'),
+ ], 1):
+ test(case, hash, eq, frozen, False, result_no)
+ test(case, hash, eq, frozen, True, result_yes)
+
+ # Test non-bool truth values, too. This is just to
+ # make sure the data-driven table in the decorator
+ # handles non-bool values.
+ test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
+ test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes)
+
+
+ def test_eq_only(self):
+ # If a class defines __eq__, __hash__ is automatically added
+ # and set to None. This is normal Python behavior, not
+ # related to dataclasses. Make sure we don't interfere with
+ # that (see bpo=32546).
+
+ @dataclass
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == other.i
+ self.assertEqual(C(1), C(1))
+ self.assertNotEqual(C(1), C(4))
+
+ # And make sure things work in this case if we specify
+ # hash=True.
+ @dataclass(hash=True)
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == other.i
+ self.assertEqual(C(1), C(1.0))
+ self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+ # And check that the classes __eq__ is being used, despite
+ # specifying eq=True.
+ @dataclass(hash=True, eq=True)
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == 3 and self.i == other.i
+ self.assertEqual(C(3), C(3))
+ self.assertNotEqual(C(1), C(1))
+ self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+
if __name__ == '__main__':
unittest.main()