summary refs log tree commit diff
path: root/test/test_transaction.py
blob: e3d13c8a6d4056ba7686747c2fbdf85f0aaec015 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
KafkaTransaction tests.
"""
from unittest2 import TestCase

from mock import MagicMock, patch

from kafka.common import OffsetOutOfRangeError
from kafka.transaction import KafkaTransaction


class TestKafkaTransaction(TestCase):
    """
    KafkaTransaction tests.

    Each test wraps a ``MagicMock`` consumer (whose ``client`` is also a
    ``MagicMock``) in a ``KafkaTransaction``, mutates the consumer's
    ``offsets`` inside the ``with`` block, and then inspects ``offsets`` and
    the recorded ``seek`` / ``send_offset_commit_request`` calls.
    """

    def setUp(self):
        self.client = MagicMock()
        self.consumer = MagicMock()
        self.topic = "topic"
        self.group = "group"
        self.partition = 0
        self.consumer.topic = self.topic
        self.consumer.group = self.group
        self.consumer.client = self.client
        # single partition under test, starting at offset 0
        self.consumer.offsets = {self.partition: 0}
        self.transaction = KafkaTransaction(self.consumer)

    def _assert_seek_relative_zero(self):
        """
        Assert ``consumer.seek`` was called exactly once, with positional
        arguments ``(0, 1)`` (a relative seek with zero delta).
        """
        self.assertEqual(self.consumer.seek.call_count, 1)
        self.assertEqual(self.consumer.seek.call_args[0], (0, 1))

    def test_noop(self):
        """
        Should revert consumer after transaction with no mark() call.
        """
        with self.transaction:
            # advance offset
            self.consumer.offsets = {self.partition: 1}

        # nothing was marked, so no offset commit is sent to the client
        self.assertEqual(self.client.send_offset_commit_request.call_count, 0)
        # offset restored
        self.assertEqual(self.consumer.offsets, {self.partition: 0})
        # and seek called with relative zero delta
        self._assert_seek_relative_zero()

    def test_mark(self):
        """
        Should remain at marked location and commit once.
        """
        with self.transaction as transaction:
            transaction.mark(self.partition, 0)
            # advance offset
            self.consumer.offsets = {self.partition: 1}

        # offset sent to client
        self.assertEqual(self.client.send_offset_commit_request.call_count, 1)

        # offset remains advanced
        self.assertEqual(self.consumer.offsets, {self.partition: 1})

        # and seek called with relative zero delta
        self._assert_seek_relative_zero()

    def test_mark_multiple(self):
        """
        Should remain at highest marked location, committing only once.
        """
        with self.transaction as transaction:
            transaction.mark(self.partition, 0)
            transaction.mark(self.partition, 1)
            transaction.mark(self.partition, 2)
            # advance offset
            self.consumer.offsets = {self.partition: 3}

        # offset sent to client (once, not once per mark)
        self.assertEqual(self.client.send_offset_commit_request.call_count, 1)

        # offset remains advanced
        self.assertEqual(self.consumer.offsets, {self.partition: 3})

        # and seek called with relative zero delta
        self._assert_seek_relative_zero()

    def test_rollback(self):
        """
        Should rollback to beginning of transaction when the body raises,
        and let the body's exception propagate.
        """
        with self.assertRaises(Exception) as raised:
            with self.transaction as transaction:
                transaction.mark(self.partition, 0)
                # advance offset
                self.consumer.offsets = {self.partition: 1}

                raise Exception("Intentional failure")

        # the exception that escaped must be the intentional one, not some
        # other error raised while the transaction was rolling back
        self.assertEqual(str(raised.exception), "Intentional failure")

        # the mark is discarded, so no offset commit is sent to the client
        self.assertEqual(self.client.send_offset_commit_request.call_count, 0)

        # offset rolled back (ignoring mark)
        self.assertEqual(self.consumer.offsets, {self.partition: 0})

        # and seek called with relative zero delta
        self._assert_seek_relative_zero()

    def test_out_of_range(self):
        """
        Should remain at beginning of range.

        An ``OffsetOutOfRangeError`` raised inside the transaction must not
        escape the ``with`` block (this test has no ``assertRaises``). The
        consumer must be sought with ``(0, 0)``, and the offsets that seek
        produced must be left in place.
        """
        seek_calls = []

        def _seek(offset, whence):
            # record the call instead of asserting here: this runs inside the
            # transaction's exception handling, where a failing assertion
            # could be swallowed or replaced; the checks happen below instead
            seek_calls.append((offset, whence))
            # set offsets to something different
            self.consumer.offsets = {self.partition: 100}

        with patch.object(self.consumer, "seek", _seek):
            with self.transaction:
                raise OffsetOutOfRangeError()

        # seek must be called with 0, 0 to find the beginning of the range
        self.assertTrue(seek_calls, "seek() was never called")
        for call in seek_calls:
            self.assertEqual(call, (0, 0))

        # offsets set by the seek are not overwritten afterwards
        self.assertEqual(self.consumer.offsets, {self.partition: 100})