author     Dana Powers <dana.powers@gmail.com>  2016-07-31 11:55:58 -0700
committer  Dana Powers <dana.powers@gmail.com>  2016-09-24 14:01:05 -0700
commit     26fe1f2e296aa78e0fe79c01f0b974dfc8741246 (patch)
tree       f0ae7d24e3b341b550797000b2a80d84b52cefb8
parent     2a7aca1630b81669595d753083239ec9fbf66ff5 (diff)
download   kafka-python-26fe1f2e296aa78e0fe79c01f0b974dfc8741246.tar.gz
First scratch commit kafka.streams
-rw-r--r--  kafka/streams/__init__.py  0
-rw-r--r--  kafka/streams/errors.py  23
-rw-r--r--  kafka/streams/kafka.py  185
-rw-r--r--  kafka/streams/processor/__init__.py  0
-rw-r--r--  kafka/streams/processor/_partition_grouper.py  62
-rw-r--r--  kafka/streams/processor/assignment/__init__.py  0
-rw-r--r--  kafka/streams/processor/assignment/assignment_info.py  75
-rw-r--r--  kafka/streams/processor/assignment/client_state.py  66
-rw-r--r--  kafka/streams/processor/assignment/subscription_info.py  61
-rw-r--r--  kafka/streams/processor/assignment/task_assignor.py  166
-rw-r--r--  kafka/streams/processor/context.py  101
-rw-r--r--  kafka/streams/processor/internal_topic_manager.py  188
-rw-r--r--  kafka/streams/processor/node.py  132
-rw-r--r--  kafka/streams/processor/partition_group.py  158
-rw-r--r--  kafka/streams/processor/processor.py  80
-rw-r--r--  kafka/streams/processor/processor_state_manager.py  306
-rw-r--r--  kafka/streams/processor/punctuation.py  32
-rw-r--r--  kafka/streams/processor/quick_union.py  62
-rw-r--r--  kafka/streams/processor/record_collector.py  68
-rw-r--r--  kafka/streams/processor/record_queue.py  166
-rw-r--r--  kafka/streams/processor/stream_partition_assignor.py  435
-rw-r--r--  kafka/streams/processor/stream_task.py  277
-rw-r--r--  kafka/streams/processor/stream_thread.py  697
-rw-r--r--  kafka/streams/processor/task.py  333
-rw-r--r--  kafka/streams/processor/topology.py  21
-rw-r--r--  kafka/streams/processor/topology_builder.py  642
-rw-r--r--  kafka/streams/utils.py  20
27 files changed, 4356 insertions, 0 deletions
diff --git a/kafka/streams/__init__.py b/kafka/streams/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/kafka/streams/__init__.py
diff --git a/kafka/streams/errors.py b/kafka/streams/errors.py
new file mode 100644
index 0000000..e35007e
--- /dev/null
+++ b/kafka/streams/errors.py
@@ -0,0 +1,23 @@
+from __future__ import absolute_import
+
+from kafka.errors import KafkaError, IllegalStateError
+
+
+class StreamsError(KafkaError):
+ pass
+
+
+class ProcessorStateError(StreamsError):
+ pass
+
+
+class TopologyBuilderError(StreamsError):
+ pass
+
+
+class NoSuchElementError(StreamsError):
+ pass
+
+
+class TaskAssignmentError(StreamsError):
+ pass
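Every streams exception above derives from StreamsError, which itself derives from
kafka.errors.KafkaError, so callers can catch at either level. A minimal sketch:

    from kafka.errors import KafkaError
    from kafka.streams.errors import StreamsError, TaskAssignmentError

    try:
        raise TaskAssignmentError('failed to find an assignable client')
    except StreamsError as e:
        print('streams-level failure: %s' % (e,))   # catches any kafka.streams error
    except KafkaError:
        pass                                        # would catch non-streams client errors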
diff --git a/kafka/streams/kafka.py b/kafka/streams/kafka.py
new file mode 100644
index 0000000..bc7dcec
--- /dev/null
+++ b/kafka/streams/kafka.py
@@ -0,0 +1,185 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import copy
+import logging
+import uuid
+
+import kafka.streams.errors as Errors
+
+from .processor.stream_thread import StreamThread
+from .utils import AtomicInteger
+
+log = logging.getLogger(__name__)
+
+
+# container states
+CREATED = 0
+RUNNING = 1
+STOPPED = 2
+
+
+class KafkaStreams(object):
+ """
+ Kafka Streams allows for performing continuous computation on input coming
+    from one or more input topics and sending output to zero or more output
+ topics.
+
+ The computational logic can be specified either by using the TopologyBuilder
+    class to define a DAG topology of Processors or by using the
+ KStreamBuilder class which provides the high-level KStream DSL to define
+ the transformation.
+
+ The KafkaStreams class manages the lifecycle of a Kafka Streams instance.
+ One stream instance can contain one or more threads specified in the configs
+ for the processing work.
+
+ A KafkaStreams instance can co-ordinate with any other instances with the
+ same application ID (whether in this same process, on other processes on
+ this machine, or on remote machines) as a single (possibly distributed)
+ stream processing client. These instances will divide up the work based on
+ the assignment of the input topic partitions so that all partitions are
+    being consumed. If instances are added or fail, all instances will
+ rebalance the partition assignment among themselves to balance processing
+ load.
+
+    Internally, a KafkaStreams instance contains a normal KafkaConsumer and
+    KafkaProducer that are used for reading input and writing output.
+
+ A simple example might look like this:
+
+ builder = (KStreamBuilder().stream('my-input-topic')
+               .map_values(lambda value: str(len(value)))
+ .to('my-output-topic'))
+
+ streams = KafkaStreams(builder,
+ application_id='my-stream-processing-application',
+ bootstrap_servers=['localhost:9092'],
+ key_serializer=json.dumps,
+ key_deserializer=json.loads,
+ value_serializer=json.dumps,
+ value_deserializer=json.loads)
+ streams.start()
+ """
+ STREAM_CLIENT_ID_SEQUENCE = AtomicInteger(0)
+ METRICS_PREFIX = 'kafka.streams'
+
+ DEFAULT_CONFIG = {
+ 'application_id': None,
+ 'bootstrap_servers': None,
+ 'num_stream_threads': 1,
+ }
+
+ def __init__(self, builder, **configs):
+ """Construct the stream instance.
+
+ Arguments:
+ builder (...): The processor topology builder specifying the computational logic
+ """
+ self.config = copy.copy(self.DEFAULT_CONFIG)
+ for key in self.config:
+ if key in configs:
+ self.config[key] = configs.pop(key)
+
+        # Only check for extra config keys in top-level class
+        if configs:
+            log.warning('Unrecognized configs: %s', configs.keys())
+
+ self._state = CREATED
+
+        # process_id is expected to be unique across processes and is sent in
+        # the userdata of the subscription request so the assignor is aware of
+        # the co-location of each stream thread's consumers. It is for internal
+        # usage only and should not be exposed to users at all.
+ self.config['process_id'] = uuid.uuid4().hex
+
+ # The application ID is a required config and hence should always have value
+ if 'application_id' not in self.config:
+ raise Errors.StreamsError('application_id is a required parameter')
+
+ builder.set_application_id(self.config['application_id'])
+
+ if 'client_id' not in self.config:
+ next_id = self.STREAM_CLIENT_ID_SEQUENCE.increment()
+ self.config['client_id'] = self.config['application_id'] + "-" + str(next_id)
+
+ # reporters = self.config['metric_reporters']
+
+ #MetricConfig metricConfig = new MetricConfig().samples(config.getInt(StreamsConfig.METRICS_NUM_SAMPLES_CONFIG))
+ # .timeWindow(config.getLong(StreamsConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG),
+ # TimeUnit.MILLISECONDS);
+
+ #self._metrics = new Metrics(metricConfig, reporters, time);
+
+ self._threads = [StreamThread(builder, **self.config)
+ for _ in range(self.config['num_stream_threads'])]
+
+ #synchronized
+ def start(self):
+ """Start the stream instance by starting all its threads.
+
+ Raises:
+            IllegalStateError if the process was already started or closed
+ """
+ log.debug('Starting Kafka Stream process')
+
+ if self._state == CREATED:
+ for thread in self._threads:
+ thread.start()
+
+ self._state = RUNNING
+
+ log.info('Started Kafka Stream process')
+ elif self._state == RUNNING:
+ raise Errors.IllegalStateError('This process was already started.')
+ else:
+ raise Errors.IllegalStateError('Cannot restart after closing.')
+
+ #synchronized
+ def close(self):
+ """Shutdown this stream instance.
+
+ Signals all the threads to stop, and then waits for them to join.
+        """
+ log.debug('Stopping Kafka Stream process')
+
+ if self._state == RUNNING:
+ # signal the threads to stop and wait
+ for thread in self._threads:
+ thread.close()
+
+ for thread in self._threads:
+ thread.join()
+
+ if self._state != STOPPED:
+ #metrics.close()
+ self._state = STOPPED
+ log.info('Stopped Kafka Stream process')
+
+    def set_uncaught_exception_handler(self, handler):
+        """Set the handler invoked when a stream thread abruptly terminates
+        due to an uncaught exception.
+
+        Arguments:
+            handler: the object to use as the uncaught exception handler for
+                each stream thread. If None, threads have no explicit handler.
+        """
+        for thread in self._threads:
+            thread.set_uncaught_exception_handler(handler)
diff --git a/kafka/streams/processor/__init__.py b/kafka/streams/processor/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/kafka/streams/processor/__init__.py
diff --git a/kafka/streams/processor/_partition_grouper.py b/kafka/streams/processor/_partition_grouper.py
new file mode 100644
index 0000000..453438c
--- /dev/null
+++ b/kafka/streams/processor/_partition_grouper.py
@@ -0,0 +1,62 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+from kafka.structs import TopicPartition
+
+import kafka.streams.errors as Errors
+from kafka.streams.processor.partition_group import TaskId
+
+
+class DefaultPartitionGrouper(object):
+    """Default partition grouper that groups partitions by partition id.
+
+    Join operations require that the topics of the joining entities are
+    co-partitioned, i.e., partitioned by the same key and with the same number
+    of partitions. Co-partitioning is ensured by having the same number of
+    partitions on joined topics and by using the default serialization and the
+    producer's default partitioner.
+    """
+
+    def partition_groups(self, topic_groups, metadata):
+        """Generate tasks with the assigned topic partitions.
+
+        Arguments:
+            topic_groups ({topic_group_id: [topics]}): groups of topics that
+                need to be joined together
+            metadata (kafka.Cluster): metadata of the consuming cluster
+
+        Returns: {TaskId: set([TopicPartition])} mapping generated task ids
+            to their assigned partitions
+        """
+ groups = {}
+
+ for topic_group_id, topic_group in topic_groups.items():
+
+ max_num_partitions = self.max_num_partitions(metadata, topic_group)
+
+ for partition_id in range(max_num_partitions):
+ group = set()
+
+ for topic in topic_group:
+ if partition_id < len(metadata.partitions_for_topic(topic)):
+ group.add(TopicPartition(topic, partition_id))
+                groups[TaskId(topic_group_id, partition_id)] = group
+
+ return groups
+
+ def max_num_partitions(self, metadata, topics):
+ max_num_partitions = 0
+ for topic in topics:
+ partitions = metadata.partitions_for_topic(topic)
+
+ if not partitions:
+ raise Errors.StreamsError("Topic not found during partition assignment: " + topic)
+
+ num_partitions = len(partitions)
+ if num_partitions > max_num_partitions:
+ max_num_partitions = num_partitions
+ return max_num_partitions
diff --git a/kafka/streams/processor/assignment/__init__.py b/kafka/streams/processor/assignment/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/kafka/streams/processor/assignment/__init__.py
diff --git a/kafka/streams/processor/assignment/assignment_info.py b/kafka/streams/processor/assignment/assignment_info.py
new file mode 100644
index 0000000..791aa19
--- /dev/null
+++ b/kafka/streams/processor/assignment/assignment_info.py
@@ -0,0 +1,75 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import json
+import logging
+
+from kafka import TopicPartition
+from kafka.streams.errors import TaskAssignmentError
+from kafka.streams.processor.partition_group import TaskId
+
+log = logging.getLogger(__name__)
+
+
+class AssignmentInfo(object):
+ CURRENT_VERSION = 1
+
+ def __init__(self, active_tasks, standby_tasks, version=None):
+ self.version = self.CURRENT_VERSION if version is None else version
+ self.active_tasks = active_tasks
+ self.standby_tasks = standby_tasks
+
+ def encode(self):
+
+ try:
+ if self.version == self.CURRENT_VERSION:
+ data = {
+ 'version': self.version,
+ 'active_tasks': [list(task) for task in self.active_tasks],
+ 'standby_tasks': [[list(task), [list(tp) for tp in partitions]]
+ for task, partitions in self.standby_tasks.items()]
+ }
+ return json.dumps(data).encode('utf-8')
+
+ else:
+ raise TaskAssignmentError('Unable to encode assignment data: version=' + str(self.version))
+
+ except Exception as ex:
+ raise TaskAssignmentError('Failed to encode AssignmentInfo', ex)
+
+ @classmethod
+ def decode(cls, data):
+ try:
+ decoded = json.loads(data.decode('utf-8'))
+
+ if decoded['version'] == cls.CURRENT_VERSION:
+ decoded['active_tasks'] = [TaskId(*task) for task in decoded['active_tasks']]
+ decoded['standby_tasks'] = dict([
+ (TaskId(*task), set([TopicPartition(*partition) for partition in partitions]))
+ for task, partitions in decoded['standby_tasks']])
+
+ return AssignmentInfo(decoded['active_tasks'], decoded['standby_tasks'])
+
+ else:
+                raise TaskAssignmentError('Unknown assignment data version: ' + str(decoded['version']))
+ except Exception as ex:
+ raise TaskAssignmentError('Failed to decode AssignmentInfo', ex)
+
+ def __str__(self):
+ return "[version=%d, active_tasks=%d, standby_tasks=%d]" % (
+ self.version, len(self.active_tasks), len(self.standby_tasks))
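A round-trip sketch of the assignment userdata defined above: active tasks are encoded
as [group, partition] pairs, standby tasks as [task, [[topic, partition], ...]] pairs,
and decode() rebuilds the TaskId and TopicPartition tuples.

    from kafka import TopicPartition
    from kafka.streams.processor.assignment.assignment_info import AssignmentInfo
    from kafka.streams.processor.partition_group import TaskId

    active = [TaskId(0, 0), TaskId(0, 1)]
    standby = {TaskId(0, 2): set([TopicPartition('my-input-topic', 2)])}

    data = AssignmentInfo(active, standby).encode()   # JSON bytes carried in the group assignment
    decoded = AssignmentInfo.decode(data)
    assert decoded.active_tasks == active
    assert decoded.standby_tasks == standby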
diff --git a/kafka/streams/processor/assignment/client_state.py b/kafka/streams/processor/assignment/client_state.py
new file mode 100644
index 0000000..cec8d68
--- /dev/null
+++ b/kafka/streams/processor/assignment/client_state.py
@@ -0,0 +1,66 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import copy
+
+class ClientState(object):
+
+ COST_ACTIVE = 0.1
+ COST_STANDBY = 0.2
+ COST_LOAD = 0.5
+
+ def __init__(self, active_tasks=None, assigned_tasks=None,
+ prev_active_tasks=None, prev_assigned_tasks=None,
+ capacity=0.0):
+ self.active_tasks = active_tasks if active_tasks else set([])
+ self.assigned_tasks = assigned_tasks if assigned_tasks else set([])
+ self.prev_active_tasks = prev_active_tasks if prev_active_tasks else set([])
+ self.prev_assigned_tasks = prev_assigned_tasks if prev_assigned_tasks else set([])
+ self.capacity = capacity
+ self.cost = 0.0
+
+ def copy(self):
+ return ClientState(copy.deepcopy(self.active_tasks),
+ copy.deepcopy(self.assigned_tasks),
+ copy.deepcopy(self.prev_active_tasks),
+ copy.deepcopy(self.prev_assigned_tasks),
+ self.capacity)
+
+ def assign(self, task_id, active):
+ if active:
+ self.active_tasks.add(task_id)
+
+ self.assigned_tasks.add(task_id)
+
+ cost = self.COST_LOAD
+ try:
+ self.prev_assigned_tasks.remove(task_id)
+ cost = self.COST_STANDBY
+ except KeyError:
+ pass
+ try:
+ self.prev_active_tasks.remove(task_id)
+ cost = self.COST_ACTIVE
+ except KeyError:
+ pass
+
+ self.cost += cost
+
+ def __str__(self):
+ return "[active_tasks: (%s) assigned_tasks: (%s) prev_active_tasks: (%s) prev_assigned_tasks: (%s) capacity: (%s) cost: (%s)]" % (
+ self.active_tasks, self.assigned_tasks, self.prev_active_tasks, self.prev_assigned_tasks, self.capacity, self.cost)
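A small sketch of the cost bookkeeping used by the task assignor: re-assigning a task the
client previously ran actively is cheapest, a previously assigned (standby) task costs more,
and a task the client has never seen costs the most.

    from kafka.streams.processor.assignment.client_state import ClientState
    from kafka.streams.processor.partition_group import TaskId

    state = ClientState(prev_active_tasks=set([TaskId(0, 0)]),
                        prev_assigned_tasks=set([TaskId(0, 0), TaskId(0, 1)]),
                        capacity=1.0)
    state.assign(TaskId(0, 0), True)    # previously active  -> cost += COST_ACTIVE (0.1)
    state.assign(TaskId(0, 1), True)    # previously standby -> cost += COST_STANDBY (0.2)
    state.assign(TaskId(0, 2), False)   # never seen before  -> cost += COST_LOAD (0.5)
    assert abs(state.cost - 0.8) < 1e-9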
diff --git a/kafka/streams/processor/assignment/subscription_info.py b/kafka/streams/processor/assignment/subscription_info.py
new file mode 100644
index 0000000..c7fce14
--- /dev/null
+++ b/kafka/streams/processor/assignment/subscription_info.py
@@ -0,0 +1,61 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import json
+import logging
+
+from kafka.streams.errors import TaskAssignmentError
+from kafka.streams.processor.partition_group import TaskId
+
+log = logging.getLogger(__name__)
+
+
+class SubscriptionInfo(object):
+ CURRENT_VERSION = 1
+
+ def __init__(self, process_id, prev_tasks, standby_tasks, version=None):
+ self.version = self.CURRENT_VERSION if version is None else version
+ self.process_id = process_id
+ self.prev_tasks = prev_tasks
+ self.standby_tasks = standby_tasks
+
+ def encode(self):
+ if self.version == self.CURRENT_VERSION:
+ data = {
+ 'version': self.version,
+ 'process_id': self.process_id,
+ 'prev_tasks': list(self.prev_tasks),
+ 'standby_tasks': list(self.standby_tasks)
+ }
+ return json.dumps(data).encode('utf-8')
+
+ else:
+ raise TaskAssignmentError('unable to encode subscription data: version=' + str(self.version))
+
+ @classmethod
+ def decode(cls, data):
+ try:
+ decoded = json.loads(data.decode('utf-8'))
+ if decoded['version'] != cls.CURRENT_VERSION:
+                raise TaskAssignmentError('unable to decode subscription data: version=' + str(decoded['version']))
+
+            decoded['prev_tasks'] = set([TaskId(*task) for task in decoded['prev_tasks']])
+            decoded['standby_tasks'] = set([TaskId(*task) for task in decoded['standby_tasks']])
+ return cls(decoded['process_id'], decoded['prev_tasks'], decoded['standby_tasks'], decoded['version'])
+
+ except Exception as e:
+ raise TaskAssignmentError('unable to decode subscription data', e)
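A round-trip sketch of the subscription userdata, given the TaskId-aware decode above: each
consumer advertises its process id plus the tasks it previously owned so the assignor can
prefer sticky placement.

    from kafka.streams.processor.assignment.subscription_info import SubscriptionInfo
    from kafka.streams.processor.partition_group import TaskId

    info = SubscriptionInfo('some-process-uuid', set([TaskId(0, 0)]), set([TaskId(0, 1)]))
    decoded = SubscriptionInfo.decode(info.encode())
    assert decoded.process_id == 'some-process-uuid'
    assert decoded.prev_tasks == set([TaskId(0, 0)])
    assert decoded.standby_tasks == set([TaskId(0, 1)])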
diff --git a/kafka/streams/processor/assignment/task_assignor.py b/kafka/streams/processor/assignment/task_assignor.py
new file mode 100644
index 0000000..8cb6c4e
--- /dev/null
+++ b/kafka/streams/processor/assignment/task_assignor.py
@@ -0,0 +1,166 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import logging
+import random
+
+from kafka.streams.errors import TaskAssignmentError
+from kafka.streams.processor.assignment.client_state import ClientState
+
+log = logging.getLogger(__name__)
+
+
+class TaskAssignor(object):
+
+ @classmethod
+ def assign(cls, states, tasks, num_standby_replicas):
+ assignor = TaskAssignor(states, tasks)
+ log.info('Assigning tasks to clients: %s, prev_assignment_balanced: %s,'
+                 ' prev_clients_unchanged: %s, tasks: %s, replicas: %s',
+ states, assignor.prev_assignment_balanced,
+ assignor.prev_clients_unchanged, tasks, num_standby_replicas)
+
+ assignor.assign_tasks()
+ if num_standby_replicas > 0:
+ assignor.assign_standby_tasks(num_standby_replicas)
+
+ log.info('Assigned with: %s', assignor.states)
+ return assignor.states
+
+ def __init__(self, states, tasks):
+ self.states = {}
+ self.task_pairs = set()
+ self.max_num_task_pairs = None
+ self.tasks = []
+ self.prev_assignment_balanced = True
+ self.prev_clients_unchanged = True
+
+ avg_num_tasks = len(tasks) // len(states)
+ existing_tasks = set()
+ for client, state in states.items():
+ self.states[client] = state.copy()
+ old_tasks = state.prev_assigned_tasks
+ # make sure the previous assignment is balanced
+ self.prev_assignment_balanced &= len(old_tasks) < (2 * avg_num_tasks)
+ self.prev_assignment_balanced &= len(old_tasks) > (avg_num_tasks / 2)
+
+ for task in old_tasks:
+ # Make sure there is no duplicates
+ self.prev_clients_unchanged &= task not in existing_tasks
+ existing_tasks.update(old_tasks)
+
+ # Make sure the existing assignment didn't miss out any task
+ self.prev_clients_unchanged &= existing_tasks == tasks
+
+ self.tasks = list(tasks)
+
+ num_tasks = len(tasks)
+        self.max_num_task_pairs = num_tasks * (num_tasks - 1) // 2
+ #self.taskPairs = set(range(self.max_num_task_pairs)) # XXX
+
+ def assign_standby_tasks(self, num_standby_replicas):
+ num_replicas = min(num_standby_replicas, len(self.states) - 1)
+ for _ in range(num_replicas):
+ self.assign_tasks(active=False)
+
+ def assign_tasks(self, active=True):
+ random.shuffle(self.tasks)
+
+ for task in self.tasks:
+ state = self.find_client_for(task)
+
+ if state:
+ state.assign(task, active)
+ else:
+ raise TaskAssignmentError('failed to find an assignable client')
+
+ def find_client_for(self, task):
+ check_task_pairs = len(self.task_pairs) < self.max_num_task_pairs
+
+ state = self.find_client_by_addition_cost(task, check_task_pairs)
+
+ if state is None and check_task_pairs:
+ state = self.find_client_by_addition_cost(task, False)
+
+ if state:
+ self.add_task_pairs(task, state)
+
+ return state
+
+ def find_client_by_addition_cost(self, task, check_task_pairs):
+ candidate = None
+ candidate_addition_cost = 0.0
+
+ for state in self.states.values():
+ if (self.prev_assignment_balanced and self.prev_clients_unchanged
+ and task in state.prev_assigned_tasks):
+                return state
+ if task not in state.assigned_tasks:
+ # if check_task_pairs is True, skip this client if this task doesn't introduce a new task combination
+ if (check_task_pairs and state.assigned_tasks
+ and not self.has_new_task_pair(task, state)):
+ continue
+ addition_cost = self.compute_addition_cost(task, state)
+ if (candidate is None or
+ (addition_cost < candidate_addition_cost or
+ (addition_cost == candidate_addition_cost
+ and state.cost < candidate.cost))):
+ candidate = state
+ candidate_addition_cost = addition_cost
+ return candidate
+
+ def add_task_pairs(self, task, state):
+ for other in state.assigned_tasks:
+ self.task_pairs.add(self.pair(task, other))
+
+ def has_new_task_pair(self, task, state):
+ for other in state.assigned_tasks:
+ if self.pair(task, other) not in self.task_pairs:
+ return True
+ return False
+
+ def compute_addition_cost(self, task, state):
+        cost = len(state.assigned_tasks) / state.capacity
+
+ if task in state.prev_assigned_tasks:
+ if task in state.prev_active_tasks:
+ cost += ClientState.COST_ACTIVE
+ else:
+ cost += ClientState.COST_STANDBY
+ else:
+ cost += ClientState.COST_LOAD
+ return cost
+
+ def pair(self, task1, task2):
+ if task1 < task2:
+ return self.TaskPair(task1, task2)
+ else:
+ return self.TaskPair(task2, task1)
+
+ class TaskPair(object):
+ def __init__(self, task1, task2):
+ self.task1 = task1
+ self.task2 = task2
+
+ def __hash__(self):
+ return hash(self.task1) ^ hash(self.task2)
+
+ def __eq__(self, other):
+ if isinstance(other, type(self)):
+ return self.task1 == other.task1 and self.task2 == other.task2
+ return False
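A standalone sketch of the assignor (client ids and capacities are illustrative): two clients
of equal capacity split four tasks, and with one standby replica each client also picks up
the other's tasks as standbys.

    from kafka.streams.processor.assignment.client_state import ClientState
    from kafka.streams.processor.assignment.task_assignor import TaskAssignor
    from kafka.streams.processor.partition_group import TaskId

    states = {'client-a': ClientState(capacity=1.0),
              'client-b': ClientState(capacity=1.0)}
    tasks = set(TaskId(0, p) for p in range(4))

    assigned = TaskAssignor.assign(states, tasks, num_standby_replicas=1)
    for client, state in sorted(assigned.items()):
        print(client, sorted(state.active_tasks), sorted(state.assigned_tasks))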
diff --git a/kafka/streams/processor/context.py b/kafka/streams/processor/context.py
new file mode 100644
index 0000000..14c9bf2
--- /dev/null
+++ b/kafka/streams/processor/context.py
@@ -0,0 +1,101 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+
+from __future__ import absolute_import
+
+import kafka.streams.errors as Errors
+
+NONEXIST_TOPIC = '__null_topic__'
+
+
+class ProcessorContext(object):
+
+ def __init__(self, task_id, task, record_collector, state_mgr, **config):
+ self.config = config
+ self._task = task
+ self.record_collector = record_collector
+ self.task_id = task_id
+ self.state_mgr = state_mgr
+
+ #self.metrics = metrics
+ self.key_serializer = config['key_serializer']
+ self.value_serializer = config['value_serializer']
+ self.key_deserializer = config['key_deserializer']
+ self.value_deserializer = config['value_deserializer']
+ self._initialized = False
+
+ def initialized(self):
+ self._initialized = True
+
+ @property
+ def application_id(self):
+ return self._task.application_id
+
+ def state_dir(self):
+ return self.state_mgr.base_dir()
+
+ def register(self, state_store, logging_enabled, state_restore_callback):
+ if self._initialized:
+ raise Errors.IllegalStateError('Can only create state stores during initialization.')
+
+ self.state_mgr.register(state_store, logging_enabled, state_restore_callback)
+
+ def get_state_store(self, name):
+ """
+ Raises TopologyBuilderError if an attempt is made to access this state store from an unknown node
+ """
+ node = self._task.node()
+ if not node:
+ raise Errors.TopologyBuilderError('Accessing from an unknown node')
+
+ # TODO: restore this once we fix the ValueGetter initialization issue
+ #if (!node.stateStores.contains(name))
+ # throw new TopologyBuilderException("Processor " + node.name() + " has no access to StateStore " + name);
+
+ return self.state_mgr.get_store(name)
+
+ def topic(self):
+ if self._task.record() is None:
+ raise Errors.IllegalStateError('This should not happen as topic() should only be called while a record is processed')
+
+ topic = self._task.record().topic()
+ if topic == NONEXIST_TOPIC:
+ return None
+ else:
+ return topic
+
+ def partition(self):
+ if self._task.record() is None:
+ raise Errors.IllegalStateError('This should not happen as partition() should only be called while a record is processed')
+ return self._task.record().partition()
+
+ def offset(self):
+ if self._task.record() is None:
+ raise Errors.IllegalStateError('This should not happen as offset() should only be called while a record is processed')
+ return self._task.record().offset()
+
+ def timestamp(self):
+ if self._task.record() is None:
+ raise Errors.IllegalStateError('This should not happen as timestamp() should only be called while a record is processed')
+ return self._task.record().timestamp
+
+ def forward(self, key, value, child_index=None, child_name=None):
+ self._task.forward(key, value, child_index=child_index, child_name=child_name)
+
+ def commit(self):
+ self._task.need_commit()
+
+ def schedule(self, interval):
+ self._task.schedule(interval)
diff --git a/kafka/streams/processor/internal_topic_manager.py b/kafka/streams/processor/internal_topic_manager.py
new file mode 100644
index 0000000..6d2aa4b
--- /dev/null
+++ b/kafka/streams/processor/internal_topic_manager.py
@@ -0,0 +1,188 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import copy
+import json
+import logging
+
+"""
+import org.I0Itec.zkclient.ZkClient
+import org.I0Itec.zkclient.exception.ZkNoNodeException
+import org.I0Itec.zkclient.exception.ZkNodeExistsException
+import org.I0Itec.zkclient.serialize.ZkSerializer
+import org.apache.zookeeper.ZooDefs;
+"""
+
+from kafka.streams.errors import StreamsError
+
+log = logging.getLogger(__name__)
+
+
+class InternalTopicManager(object):
+
+ # TODO: the following ZK dependency should be removed after KIP-4
+ ZK_TOPIC_PATH = '/brokers/topics'
+ ZK_BROKER_PATH = '/brokers/ids'
+ ZK_DELETE_TOPIC_PATH = '/admin/delete_topics'
+ ZK_ENTITY_CONFIG_PATH = '/config/topics'
+ # TODO: the following LogConfig dependency should be removed after KIP-4
+ CLEANUP_POLICY_PROP = 'cleanup.policy'
+ COMPACT = 'compact'
+ ZK_ENCODING = 'utf-8'
+
+ def __init__(self, zk_connect=None, replication_factor=0):
+ if zk_connect:
+ #self.zk_client = ZkClient(zk_connect, 30 * 1000, 30 * 1000, self.ZK_ENCODING)
+ self.zk_client = None
+ else:
+ self.zk_client = None
+ self.replication_factor = replication_factor
+
+ def make_ready(self, topic, num_partitions, compact_topic):
+ topic_not_ready = True
+
+ while topic_not_ready:
+ topic_metadata = self.get_topic_metadata(topic)
+
+ if not topic_metadata:
+ try:
+ self.create_topic(topic, num_partitions, self.replication_factor, compact_topic)
+ except Exception: #ZkNodeExistsError:
+ # ignore and continue
+ pass
+ else:
+ if len(topic_metadata) > num_partitions:
+ # else if topic exists with more #.partitions than needed, delete in order to re-create it
+ try:
+ self.delete_topic(topic)
+ except Exception: #ZkNodeExistsError:
+ # ignore and continue
+ pass
+ elif len(topic_metadata) < num_partitions:
+ # else if topic exists with less #.partitions than needed, add partitions
+ try:
+ self.add_partitions(topic, num_partitions - len(topic_metadata), self.replication_factor, topic_metadata)
+ except Exception: #ZkNoNodeError:
+ # ignore and continue
+ pass
+ else:
+ topic_not_ready = False
+
+ def get_brokers(self):
+ brokers = []
+ for broker in self.zk_client.get_children(self.ZK_BROKER_PATH):
+ brokers.append(int(broker))
+ brokers.sort()
+ log.debug("Read brokers %s from ZK in partition assignor.", brokers)
+ return brokers
+
+ def get_topic_metadata(self, topic):
+ data = self.zk_client.read_data(self.ZK_TOPIC_PATH + "/" + topic, True)
+
+ if data is None:
+ return None
+
+ try:
+ partitions = json.loads(data).get('partitions')
+ log.debug("Read partitions %s for topic %s from ZK in partition assignor.", partitions, topic)
+ return partitions
+ except Exception as e:
+ raise StreamsError("Error while reading topic metadata from ZK for internal topic " + topic, e)
+
+ def create_topic(self, topic, num_partitions, replication_factor, compact_topic):
+ log.debug("Creating topic %s with %s partitions from ZK in partition assignor.", topic, num_partitions)
+ brokers = self.get_brokers()
+ num_brokers = len(brokers)
+ if num_brokers < replication_factor:
+            log.warning('Not enough brokers found. The replication factor is'
+                        ' reduced from %s to %s', replication_factor, num_brokers)
+ replication_factor = num_brokers
+
+ assignment = {}
+
+ for i in range(num_partitions):
+ broker_list = []
+ for r in range(replication_factor):
+                shift = r * num_brokers // replication_factor
+ broker_list.append(brokers[(i + shift) % num_brokers])
+ assignment[i] = broker_list
+
+ # write out config first just like in AdminUtils.scala createOrUpdateTopicPartitionAssignmentPathInZK()
+ if compact_topic:
+ try:
+ data_map = {
+ 'version': 1,
+ 'config': {
+ 'cleanup.policy': 'compact'
+ }
+ }
+ data = json.dumps(data_map)
+ #zk_client.create_persistent(self.ZK_ENTITY_CONFIG_PATH + "/" + topic, data, ZooDefs.Ids.OPEN_ACL_UNSAFE)
+ except Exception as e:
+ raise StreamsError('Error while creating topic config in ZK for internal topic ' + topic, e)
+
+ # try to write to ZK with open ACL
+ try:
+ data_map = {
+ 'version': 1,
+ 'partitions': assignment
+ }
+ data = json.dumps(data_map)
+
+ #zk_client.create_persistent(self.ZK_TOPIC_PATH + "/" + topic, data, ZooDefs.Ids.OPEN_ACL_UNSAFE)
+ except Exception as e:
+ raise StreamsError('Error while creating topic metadata in ZK for internal topic ' + topic, e)
+
+ def delete_topic(self, topic):
+ log.debug('Deleting topic %s from ZK in partition assignor.', topic)
+
+ #zk_client.create_persistent(self.ZK_DELETE_TOPIC_PATH + "/" + topic, "", ZooDefs.Ids.OPEN_ACL_UNSAFE)
+
+ def add_partitions(self, topic, num_partitions, replication_factor, existing_assignment):
+        log.debug('Adding %s partitions to topic %s from ZK with existing'
+ ' partitions assigned as %s in partition assignor.',
+ topic, num_partitions, existing_assignment)
+
+ brokers = self.get_brokers()
+ num_brokers = len(brokers)
+ if (num_brokers < replication_factor):
+ log.warning('Not enough brokers found. The replication factor is'
+ ' reduced from %s to %s', replication_factor, num_brokers)
+ replication_factor = num_brokers
+
+ start_index = len(existing_assignment)
+
+ new_assignment = copy.deepcopy(existing_assignment)
+
+ for i in range(num_partitions):
+ broker_list = []
+ for r in range(replication_factor):
+                shift = r * num_brokers // replication_factor
+ broker_list.append(brokers[(i + shift) % num_brokers])
+ new_assignment[i + start_index] = broker_list
+
+ # try to write to ZK with open ACL
+ try:
+ data_map = {
+ 'version': 1,
+ 'partitions': new_assignment
+ }
+ data = json.dumps(data_map)
+
+ #zk_client.write_data(ZK_TOPIC_PATH + "/" + topic, data)
+ except Exception as e:
+ raise StreamsError('Error while updating topic metadata in ZK for internal topic ' + topic, e)
diff --git a/kafka/streams/processor/node.py b/kafka/streams/processor/node.py
new file mode 100644
index 0000000..90b0a6a
--- /dev/null
+++ b/kafka/streams/processor/node.py
@@ -0,0 +1,132 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import kafka.errors as Errors
+from .processor import Processor
+
+
+class ProcessorNode(object):
+
+ def __init__(self, name, processor=None, state_stores=None):
+ self.name = name
+
+ # Could we construct a Processor here if the processor is just a function?
+        assert processor is None or isinstance(processor, Processor), 'processor must be None or a Processor subclass'
+
+ self.processor = processor
+ self.children = []
+ self.state_stores = state_stores
+
+ def add_child(self, child):
+ self.children.append(child)
+
+ def init(self, context):
+ self.processor.init(context)
+
+ def process(self, key, value):
+ self.processor.process(key, value)
+
+ def close(self):
+ self.processor.close()
+
+
+class SourceNode(ProcessorNode):
+
+ def __init__(self, name, key_deserializer, val_deserializer):
+ super(SourceNode, self).__init__(name)
+
+ self.key_deserializer = key_deserializer
+ self.val_deserializer = val_deserializer
+ self.context = None
+
+ def deserialize_key(self, topic, data):
+ if self.key_deserializer is None:
+ return data
+ return self.key_deserializer.deserialize(topic, data)
+
+ def deserialize_value(self, topic, data):
+        if self.val_deserializer is None:
+ return data
+ return self.val_deserializer.deserialize(topic, data)
+
+ def init(self, context):
+ self.context = context
+
+ # if deserializers are null, get the default ones from the context
+ if self.key_deserializer is None:
+ self.key_deserializer = self.context.key_deserializer
+ if self.val_deserializer is None:
+ self.val_deserializer = self.context.value_deserializer
+
+ """
+ // if value deserializers are for {@code Change} values, set the inner deserializer when necessary
+ if (this.valDeserializer instanceof ChangedDeserializer &&
+ ((ChangedDeserializer) this.valDeserializer).inner() == null)
+ ((ChangedDeserializer) this.valDeserializer).setInner(context.valueSerde().deserializer());
+ """
+
+ def process(self, key, value):
+ self.context.forward(key, value)
+
+ def close(self):
+ # do nothing
+ pass
+
+
+class SinkNode(ProcessorNode):
+
+ def __init__(self, name, topic, key_serializer, val_serializer, partitioner):
+ super(SinkNode, self).__init__(name)
+
+ self.topic = topic
+ self.key_serializer = key_serializer
+ self.val_serializer = val_serializer
+ self.partitioner = partitioner
+ self.context = None
+
+ def add_child(self, child):
+        raise NotImplementedError('SinkNode does not allow add_child')
+
+ def init(self, context):
+ self.context = context
+
+ # if serializers are null, get the default ones from the context
+ if self.key_serializer is None:
+ self.key_serializer = self.context.key_serializer
+ if self.val_serializer is None:
+ self.val_serializer = self.context.value_serializer
+
+ """
+ // if value serializers are for {@code Change} values, set the inner serializer when necessary
+ if (this.valSerializer instanceof ChangedSerializer &&
+ ((ChangedSerializer) this.valSerializer).inner() == null)
+ ((ChangedSerializer) this.valSerializer).setInner(context.valueSerde().serializer());
+ """
+
+ def process(self, key, value):
+ # send to all the registered topics
+ collector = self.context.record_collector
+ collector.send(self.topic, key=key, value=value,
+ timestamp_ms=self.context.timestamp(),
+ key_serializer=self.key_serializer,
+ val_serializer=self.val_serializer,
+ partitioner=self.partitioner)
+
+ def close(self):
+ # do nothing
+ pass
diff --git a/kafka/streams/processor/partition_group.py b/kafka/streams/processor/partition_group.py
new file mode 100644
index 0000000..6d5ea3f
--- /dev/null
+++ b/kafka/streams/processor/partition_group.py
@@ -0,0 +1,158 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import collections
+import heapq
+
+import kafka.errors as Errors
+from kafka.structs import TopicPartition
+from .record_queue import RecordQueue
+
+
+TaskId = collections.namedtuple('TaskId', 'topic_group_id partition_id')
+
+
+class RecordInfo(object):
+ def __init__(self):
+ self.queue = RecordQueue()
+
+ def node(self):
+ return self.queue.source()
+
+ def partition(self):
+ return self.queue.partition()
+
+ def queue(self):
+ return self.queue
+
+
+class PartitionGroup(object):
+ """A PartitionGroup is composed from a set of partitions. It also maintains
+ the timestamp of this group, hence the associated task as the min timestamp
+ across all partitions in the group.
+ """
+ def __init__(self, partition_queues, timestamp_extractor):
+ self._queues_by_time = [] # heapq
+ self._partition_queues = partition_queues
+ self._timestamp_extractor = timestamp_extractor
+ self._total_buffered = 0
+
+ def next_record(self, record_info):
+ """Get the next record and queue
+
+ Returns: (timestamp, ConsumerRecord)
+ """
+ record = None
+
+ if self._queues_by_time:
+ _, queue = heapq.heappop(self._queues_by_time)
+
+ # get the first record from this queue.
+ record = queue.poll()
+
+ if queue:
+ heapq.heappush(self._queues_by_time, (queue.timestamp(), queue))
+
+ record_info.queue = queue
+
+ if record:
+ self._total_buffered -= 1
+
+ return record
+
+ def add_raw_records(self, partition, raw_records):
+ """Adds raw records to this partition group
+
+ Arguments:
+ partition (TopicPartition): the partition
+ raw_records (list of ConsumerRecord): the raw records
+
+ Returns: the queue size for the partition
+ """
+ record_queue = self._partition_queues[partition]
+
+ old_size = record_queue.size()
+ new_size = record_queue.add_raw_records(raw_records, self._timestamp_extractor)
+
+ # add this record queue to be considered for processing in the future
+ # if it was empty before
+ if old_size == 0 and new_size > 0:
+ heapq.heappush(self._queues_by_time, (record_queue.timestamp(), record_queue))
+
+ self._total_buffered += new_size - old_size
+
+ return new_size
+
+ def partitions(self):
+ return set(self._partition_queues.keys())
+
+ def timestamp(self):
+ """Return the timestamp of this partition group
+ as the smallest partition timestamp among all its partitions
+ """
+ # we should always return the smallest timestamp of all partitions
+ # to avoid group partition time goes backward
+ timestamp = float('inf')
+ for queue in self._partition_queues.values():
+ if timestamp > queue.timestamp():
+ timestamp = queue.timestamp()
+ return timestamp
+
+ def num_buffered(self, partition=None):
+ if partition is None:
+ return self._total_buffered
+ record_queue = self._partition_queues.get(partition)
+ if not record_queue:
+ raise Errors.IllegalStateError('Record partition does not belong to this partition-group.')
+ return record_queue.size()
+
+ def top_queue_size(self):
+ if not self._queues_by_time:
+ return 0
+        return self._queues_by_time[0][1].size()
+
+ def close(self):
+ self._queues_by_time = []
+ self._partition_queues.clear()
+
+
+def partition_grouper(topic_groups, metadata):
+ """Assign partitions to task/topic groups
+
+ Arguments:
+ topic_groups ({topic_group_id: [topics]})
+ metadata (kafka.Cluster)
+
+ Returns: {TaskId: set([TopicPartition])}
+ """
+ groups = {}
+ for topic_group_id, topic_group in topic_groups.items():
+
+ partitions = set()
+ for topic in topic_group:
+ partitions.update(metadata.partitions_for_topic(topic))
+
+ for partition_id in partitions:
+ group = set()
+
+ for topic in topic_group:
+ if partition_id in metadata.partitions_for_topic(topic):
+ group.add(TopicPartition(topic, partition_id))
+ groups[TaskId(topic_group_id, partition_id)] = group
+
+ return groups
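A sketch of the default grouping using a hypothetical stand-in for the cluster metadata
(only partitions_for_topic is needed): two co-partitioned topics with two partitions each
collapse into two tasks, one per partition id.

    from kafka.streams.processor.partition_group import TaskId, partition_grouper

    class StubMetadata(object):
        # hypothetical stand-in for the kafka cluster metadata object
        def __init__(self, partitions_by_topic):
            self._partitions = partitions_by_topic

        def partitions_for_topic(self, topic):
            return self._partitions[topic]

    metadata = StubMetadata({'orders': set([0, 1]), 'payments': set([0, 1])})
    groups = partition_grouper({0: ['orders', 'payments']}, metadata)
    # each task gets one partition of every co-partitioned topic
    assert sorted(groups) == [TaskId(0, 0), TaskId(0, 1)]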
diff --git a/kafka/streams/processor/processor.py b/kafka/streams/processor/processor.py
new file mode 100644
index 0000000..115230e
--- /dev/null
+++ b/kafka/streams/processor/processor.py
@@ -0,0 +1,80 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import abc
+
+
+class Processor(object):
+ """A processor of key-value pair records."""
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def init(self, context):
+ """Initialize this processor with the given context.
+
+ The framework ensures this is called once per processor when the
+ topology that contains it is initialized.
+
+ If this processor is to be called periodically by the framework,
+ (via punctuate) then this method should schedule itself with the
+ provided context.
+
+ Arguments:
+ context (ProcessorContext): the context; may not be None
+
+ Returns: None
+ """
+ pass
+
+ @abc.abstractmethod
+ def process(self, key, value):
+ """Process the record with the given key and value.
+
+ Arguments:
+ key: the key for the record after deserialization
+ value: the value for the record after deserialization
+
+ Returns: None
+ """
+ pass
+
+ @abc.abstractmethod
+ def punctuate(self, timestamp):
+ """Perform any periodic operations
+
+ Requires that the processor scheduled itself with the context during
+ initialization
+
+ Arguments:
+ timestamp (int): stream time in ms when this method is called
+
+ Returns: None
+ """
+ pass
+
+ @abc.abstractmethod
+ def close(self):
+ """Close this processor and clean up any resources.
+
+ Be aware that close() is called after an internal cleanup.
+ Thus, it is not possible to write anything to Kafka as underlying
+ clients are already closed.
+
+ Returns: None
+ """
+ pass
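A hedged sketch of a concrete Processor built on this interface: it counts records per key,
asks its context to punctuate it periodically, and forwards the running counts downstream
from punctuate().

    from kafka.streams.processor.processor import Processor

    class CountingProcessor(Processor):
        def init(self, context):
            self.context = context
            self.counts = {}
            self.context.schedule(1000)           # request punctuate() every 1000 ms of stream time

        def process(self, key, value):
            self.counts[key] = self.counts.get(key, 0) + 1

        def punctuate(self, timestamp):
            for key, count in self.counts.items():
                self.context.forward(key, count)  # emit running counts to downstream nodes

        def close(self):
            self.counts = {}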
diff --git a/kafka/streams/processor/processor_state_manager.py b/kafka/streams/processor/processor_state_manager.py
new file mode 100644
index 0000000..0905c21
--- /dev/null
+++ b/kafka/streams/processor/processor_state_manager.py
@@ -0,0 +1,306 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import logging
+import time
+
+import kafka.errors as Errors
+from kafka.streams.errors import StreamsError
+
+log = logging.getLogger(__name__)
+
+STATE_CHANGELOG_TOPIC_SUFFIX = "-changelog"
+CHECKPOINT_FILE_NAME = ".checkpoint"
+
+
+class ProcessorStateManager(object):
+
+ def __init__(self, application_id, task_id, sources, restore_consumer, is_standby, state_directory):
+
+ self.application_id = application_id
+        self.default_partition = task_id.partition_id
+ self.task_id = task_id
+ self.state_directory = state_directory
+ self.partition_for_topic = {}
+ for source in sources:
+ self.partition_for_topic[source.topic] = source
+
+ self.stores = {}
+ self.logging_enabled = set()
+ self.restore_consumer = restore_consumer
+ self.restored_offsets = {}
+ self.is_standby = is_standby
+ if is_standby:
+ self.restore_callbacks = {}
+ else:
+ self.restore_callbacks = None
+ self.offset_limits = {}
+        self._base_dir = state_directory.directory_for_task(task_id)
+
+ if not state_directory.lock(task_id, 5):
+            raise IOError('Failed to lock the state directory: ' + self._base_dir)
+
+ """
+ # load the checkpoint information
+ checkpoint = OffsetCheckpoint(self.base_dir, CHECKPOINT_FILE_NAME)
+ self.checkpointed_offsets = checkpoint.read()
+
+ # delete the checkpoint file after finish loading its stored offsets
+ checkpoint.delete()
+ """
+
+
+ def store_changelog_topic(self, application_id, store_name):
+ return application_id + "-" + store_name + STATE_CHANGELOG_TOPIC_SUFFIX
+
+    def base_dir(self):
+        return self._base_dir
+
+ def register(self, store, logging_enabled, state_restore_callback):
+ """
+ * @throws IllegalArgumentException if the store name has already been registered or if it is not a valid name
+ * (e.g., when it conflicts with the names of internal topics, like the checkpoint file name)
+ * @throws StreamsException if the store's change log does not contain the partition
+ """
+ if (store.name == CHECKPOINT_FILE_NAME):
+ raise Errors.IllegalArgumentError("Illegal store name: " + CHECKPOINT_FILE_NAME)
+
+ if store.name in self.stores:
+ raise Errors.IllegalArgumentError("Store " + store.name + " has already been registered.")
+
+ if logging_enabled:
+ self.logging_enabled.add(store.name)
+
+ # check that the underlying change log topic exist or not
+ if logging_enabled:
+ topic = self.store_changelog_topic(self.application_id, store.name)
+ else:
+ topic = store.name
+
+ # block until the partition is ready for this state changelog topic or time has elapsed
+ partition = self.get_partition(topic)
+ partition_not_found = True
+ start_time = time.time() * 1000
+ wait_time = 5000 # hard-code the value since we should not block after KIP-4
+
+ while True:
+ try:
+ time.sleep(50)
+ except KeyboardInterrupt:
+ # ignore
+ pass
+
+ partitions = self.restore_consumer.partitions_for_topic(topic)
+ if partitions is None:
+                raise StreamsError('Could not find partition info for topic: ' + topic)
+
+ if partition in partitions:
+ partition_not_found = False
+ break
+
+ if partition_not_found and (time.time() * 1000) < (start_time + wait_time):
+ continue
+ break
+
+ if partition_not_found:
+            raise StreamsError("Store %s's change log (%s) does not contain partition %s" % (store.name, topic, partition))
+
+ self.stores[store.name] = store
+
+ if self.is_standby:
+ if store.persistent():
+ self.restore_callbacks[topic] = state_restore_callback
+ else:
+ self.restore_active_state(topic, state_restore_callback)
+
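+
+    def get_partition(self, topic):
+        # Sketch ported from the Java getPartition() helper in the block below;
+        # register() above calls it. partition_for_topic maps topic name to the
+        # source TopicPartition as populated in __init__.
+        partition = self.partition_for_topic.get(topic)
+        if partition is None:
+            return self.default_partition
+        return partition.partition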
+ """
+ def restore_active_state(self, topic_name, state_restore_callback):
+ # ---- try to restore the state from change-log ---- //
+
+ # subscribe to the store's partition
+ if (!restoreConsumer.subscription().isEmpty()) {
+ throw new IllegalStateException("Restore consumer should have not subscribed to any partitions beforehand");
+ }
+ TopicPartition storePartition = new TopicPartition(topicName, getPartition(topicName));
+ restoreConsumer.assign(Collections.singletonList(storePartition));
+
+ try {
+ // calculate the end offset of the partition
+ // TODO: this is a bit hacky to first seek then position to get the end offset
+ restoreConsumer.seekToEnd(singleton(storePartition));
+ long endOffset = restoreConsumer.position(storePartition);
+
+ // restore from the checkpointed offset of the change log if it is persistent and the offset exists;
+ // restore the state from the beginning of the change log otherwise
+ if (checkpointedOffsets.containsKey(storePartition)) {
+ restoreConsumer.seek(storePartition, checkpointedOffsets.get(storePartition));
+ } else {
+ restoreConsumer.seekToBeginning(singleton(storePartition));
+ }
+
+ // restore its state from changelog records
+ long limit = offsetLimit(storePartition);
+ while (true) {
+ long offset = 0L;
+ for (ConsumerRecord<byte[], byte[]> record : restoreConsumer.poll(100).records(storePartition)) {
+ offset = record.offset();
+ if (offset >= limit) break;
+ stateRestoreCallback.restore(record.key(), record.value());
+ }
+
+ if (offset >= limit) {
+ break;
+ } else if (restoreConsumer.position(storePartition) == endOffset) {
+ break;
+ } else if (restoreConsumer.position(storePartition) > endOffset) {
+ // For a logging enabled changelog (no offset limit),
+ // the log end offset should not change while restoring since it is only written by this thread.
+ throw new IllegalStateException("Log end offset should not change while restoring");
+ }
+ }
+
+ // record the restored offset for its change log partition
+ long newOffset = Math.min(limit, restoreConsumer.position(storePartition));
+ restoredOffsets.put(storePartition, newOffset);
+ } finally {
+ // un-assign the change log partition
+ restoreConsumer.assign(Collections.<TopicPartition>emptyList());
+ }
+ }
+
+ public Map<TopicPartition, Long> checkpointedOffsets() {
+ Map<TopicPartition, Long> partitionsAndOffsets = new HashMap<>();
+
+ for (Map.Entry<String, StateRestoreCallback> entry : restoreCallbacks.entrySet()) {
+ String topicName = entry.getKey();
+ int partition = getPartition(topicName);
+ TopicPartition storePartition = new TopicPartition(topicName, partition);
+
+ if (checkpointedOffsets.containsKey(storePartition)) {
+ partitionsAndOffsets.put(storePartition, checkpointedOffsets.get(storePartition));
+ } else {
+ partitionsAndOffsets.put(storePartition, -1L);
+ }
+ }
+ return partitionsAndOffsets;
+ }
+
+ public List<ConsumerRecord<byte[], byte[]>> updateStandbyStates(TopicPartition storePartition, List<ConsumerRecord<byte[], byte[]>> records) {
+ long limit = offsetLimit(storePartition);
+ List<ConsumerRecord<byte[], byte[]>> remainingRecords = null;
+
+ // restore states from changelog records
+
+ StateRestoreCallback restoreCallback = restoreCallbacks.get(storePartition.topic());
+
+ long lastOffset = -1L;
+ int count = 0;
+ for (ConsumerRecord<byte[], byte[]> record : records) {
+ if (record.offset() < limit) {
+ restoreCallback.restore(record.key(), record.value());
+ lastOffset = record.offset();
+ } else {
+ if (remainingRecords == null)
+ remainingRecords = new ArrayList<>(records.size() - count);
+
+ remainingRecords.add(record);
+ }
+ count++;
+ }
+ // record the restored offset for its change log partition
+ restoredOffsets.put(storePartition, lastOffset + 1);
+
+ return remainingRecords;
+ }
+
+ public void putOffsetLimit(TopicPartition partition, long limit) {
+ offsetLimits.put(partition, limit);
+ }
+
+ private long offsetLimit(TopicPartition partition) {
+ Long limit = offsetLimits.get(partition);
+ return limit != null ? limit : Long.MAX_VALUE;
+ }
+
+ public StateStore getStore(String name) {
+ return stores.get(name);
+ }
+
+ public void flush() {
+ if (!this.stores.isEmpty()) {
+ log.debug("Flushing stores.");
+ for (StateStore store : this.stores.values())
+ store.flush();
+ }
+ }
+
+ /**
+ * @throws IOException if any error happens when flushing or closing the state stores
+ */
+ public void close(Map<TopicPartition, Long> ackedOffsets) throws IOException {
+ try {
+ // attempting to flush and close the stores, just in case they
+ // are not closed by a ProcessorNode yet
+ if (!stores.isEmpty()) {
+ log.debug("Closing stores.");
+ for (Map.Entry<String, StateStore> entry : stores.entrySet()) {
+ log.debug("Closing storage engine {}", entry.getKey());
+ entry.getValue().flush();
+ entry.getValue().close();
+ }
+
+ Map<TopicPartition, Long> checkpointOffsets = new HashMap<>();
+ for (String storeName : stores.keySet()) {
+ TopicPartition part;
+ if (loggingEnabled.contains(storeName))
+ part = new TopicPartition(storeChangelogTopic(applicationId, storeName), getPartition(storeName));
+ else
+ part = new TopicPartition(storeName, getPartition(storeName));
+
+ // only checkpoint the offset to the offsets file if it is persistent;
+ if (stores.get(storeName).persistent()) {
+ Long offset = ackedOffsets.get(part);
+
+ if (offset != null) {
+ // store the last offset + 1 (the log position after restoration)
+ checkpointOffsets.put(part, offset + 1);
+ } else {
+ // if no record was produced. we need to check the restored offset.
+ offset = restoredOffsets.get(part);
+ if (offset != null)
+ checkpointOffsets.put(part, offset);
+ }
+ }
+ }
+
+ // write the checkpoint file before closing, to indicate clean shutdown
+ OffsetCheckpoint checkpoint = new OffsetCheckpoint(new File(this.baseDir, CHECKPOINT_FILE_NAME));
+ checkpoint.write(checkpointOffsets);
+ }
+ } finally {
+ // release the state directory directoryLock
+ stateDirectory.unlock(taskId);
+ }
+ }
+
+ private int getPartition(String topic) {
+ TopicPartition partition = partitionForTopic.get(topic);
+
+ return partition == null ? defaultPartition : partition.partition();
+ }
+}
+ """
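The Java block above is kept verbatim as a reference for the not-yet-ported restore logic: it seeks the restore consumer to the checkpointed offset (or to the beginning of the changelog), replays records through the state restore callback up to the offset limit, and finally un-assigns the changelog partition. Below is a minimal Python sketch of that bounded restore loop, assuming a kafka-python KafkaConsumer as the restore consumer; the function and argument names (restore_active_state, restore, limit) are illustrative only and not part of this commit.

    def restore_active_state(restore_consumer, store_partition,
                             checkpointed_offsets, restore, limit):
        # assign only the changelog partition while restoring
        restore_consumer.assign([store_partition])
        try:
            if store_partition in checkpointed_offsets:
                restore_consumer.seek(store_partition, checkpointed_offsets[store_partition])
            else:
                restore_consumer.seek_to_beginning(store_partition)

            while restore_consumer.position(store_partition) < limit:
                records = restore_consumer.poll(timeout_ms=100).get(store_partition, [])
                if not records:
                    break  # simplified: stop when the changelog is exhausted
                for record in records:
                    if record.offset >= limit:
                        return
                    restore(record.key, record.value)
        finally:
            # un-assign the changelog partition, mirroring the Java code above
            restore_consumer.assign([])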
diff --git a/kafka/streams/processor/punctuation.py b/kafka/streams/processor/punctuation.py
new file mode 100644
index 0000000..cedd64c
--- /dev/null
+++ b/kafka/streams/processor/punctuation.py
@@ -0,0 +1,32 @@
+from __future__ import absolute_import
+
+import heapq
+import threading
+
+
+class PunctuationQueue(object):
+
+ def __init__(self):
+ self._pq = []
+ self._lock = threading.Lock()
+
+ def schedule(self, sched):
+ with self._lock:
+ heapq.heappush(self._pq, sched)
+
+ def close(self):
+ with self._lock:
+ self._pq = []
+
+ def may_punctuate(self, timestamp, punctuator):
+ with self._lock:
+ punctuated = False
+ while (self._pq and self._pq[0][0] <= timestamp):
+ old_ts, node, interval_ms = heapq.heappop(self._pq)
+ if old_ts == 0:
+ old_ts = timestamp
+ punctuator.punctuate(node, timestamp)
+ sched = (old_ts + interval_ms, node, interval_ms)
+ heapq.heappush(self._pq, sched)
+ punctuated = True
+ return punctuated
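PunctuationQueue stores schedules as (timestamp, node, interval_ms) tuples in a heap; may_punctuate pops every schedule whose timestamp has been reached, invokes the punctuator, and re-enqueues the schedule one interval later. A small usage sketch follows; the Punctuator class is hypothetical and stands in for StreamTask, which supplies the punctuate(node, timestamp) callback.

    class Punctuator(object):
        def punctuate(self, node, timestamp):
            print('punctuating %s at %s' % (node, timestamp))

    queue = PunctuationQueue()
    queue.schedule((0, 'my-node', 1000))     # timestamp 0 means "fire at the first check"

    queue.may_punctuate(5000, Punctuator())  # fires and reschedules at 6000
    queue.may_punctuate(5500, Punctuator())  # returns False, nothing is due
    queue.may_punctuate(6000, Punctuator())  # fires again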
diff --git a/kafka/streams/processor/quick_union.py b/kafka/streams/processor/quick_union.py
new file mode 100644
index 0000000..6a2e3f4
--- /dev/null
+++ b/kafka/streams/processor/quick_union.py
@@ -0,0 +1,62 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import kafka.streams.errors as Errors
+
+
+class QuickUnion(object):
+
+ def __init__(self):
+ self.ids = {}
+
+ def add(self, foo):
+ self.ids[foo] = foo
+
+ def exists(self, foo):
+ return foo in self.ids
+
+    def root(self, foo):
+        """Return the root of the set containing `foo`, compressing the path.
+
+        Raises:
+            NoSuchElementError: if `foo` has not been added
+        """
+        current = foo
+        parent = self.ids.get(current)
+
+        if parent is None:
+            raise Errors.NoSuchElementError("id: " + str(foo))
+
+ while parent != current:
+ # do the path compression
+ grandparent = self.ids.get(parent)
+ self.ids[current] = grandparent
+
+ current = parent
+ parent = grandparent
+
+ return current
+
+ def unite(self, foo, foobars):
+ for bar in foobars:
+ self.unite_pair(foo, bar)
+
+ def unite_pair(self, foo, bar):
+ root1 = self.root(foo)
+ root2 = self.root(bar)
+
+ if root1 != root2:
+ self.ids[root1] = root2
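QuickUnion is a plain union-find (disjoint set) structure with path compression, mirroring the structure the topology builder uses to group connected processor nodes. A short usage sketch with made-up node names:

    qu = QuickUnion()
    for node in ('source-1', 'source-2', 'processor-1'):
        qu.add(node)

    qu.unite('source-1', ['processor-1'])    # source-1 and processor-1 now share a root
    assert qu.root('source-1') == qu.root('processor-1')
    assert qu.root('source-2') != qu.root('source-1')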
diff --git a/kafka/streams/processor/record_collector.py b/kafka/streams/processor/record_collector.py
new file mode 100644
index 0000000..de59b10
--- /dev/null
+++ b/kafka/streams/processor/record_collector.py
@@ -0,0 +1,68 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import logging
+
+from kafka.structs import TopicPartition
+
+log = logging.getLogger(__name__)
+
+
+def _handle_send_success(offsets, metadata):
+ tp = TopicPartition(metadata.topic, metadata.partition)
+ offsets[tp] = metadata.offset
+
+
+def _handle_send_failure(topic, exception):
+ log.error('Error sending record to topic %s: %s', topic, exception)
+
+
+class RecordCollector(object):
+ def __init__(self, producer):
+ self.producer = producer
+ self.offsets = {}
+
+ def send(self, topic, partition=None, key=None, value=None, timestamp_ms=None,
+ key_serializer=None, value_serializer=None, partitioner=None):
+ if key_serializer:
+ key_bytes = key_serializer(topic, key)
+ else:
+ key_bytes = key
+
+ if value_serializer:
+ val_bytes = value_serializer(topic, value)
+ else:
+ val_bytes = value
+
+ if partition is None and partitioner is not None:
+ partitions = self.producer.partitions_for(topic)
+ if partitions is not None:
+ partition = partitioner.partition(key, value, len(partitions))
+
+ future = self.producer.send(topic, partition=partition,
+ key=key_bytes, value=val_bytes,
+ timestamp_ms=timestamp_ms)
+
+ future.add_callback(_handle_send_success, self.offsets)
+ future.add_errback(_handle_send_failure, topic)
+
+ def flush(self):
+ self.producer.flush()
+
+ def close(self):
+ self.producer.close()
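RecordCollector wraps an asynchronous producer: it serializes the key and value, optionally resolves the partition through a partitioner, and records the offset of each acknowledged send in its offsets dict via the success callback. A minimal usage sketch; the broker address and the lambda serializers are placeholders, not part of this commit.

    from kafka import KafkaProducer

    producer = KafkaProducer(bootstrap_servers='localhost:9092')
    collector = RecordCollector(producer)

    collector.send('output-topic',
                   key='word', value='42',
                   key_serializer=lambda topic, k: k.encode('utf-8'),
                   value_serializer=lambda topic, v: v.encode('utf-8'))

    collector.flush()         # blocks until outstanding sends are acknowledged
    print(collector.offsets)  # e.g. {TopicPartition('output-topic', 0): <last offset>}
    collector.close()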
diff --git a/kafka/streams/processor/record_queue.py b/kafka/streams/processor/record_queue.py
new file mode 100644
index 0000000..ba22b76
--- /dev/null
+++ b/kafka/streams/processor/record_queue.py
@@ -0,0 +1,166 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+from collections import deque
+
+from kafka.consumer.fetcher import ConsumerRecord
+import kafka.errors as Errors
+
+
+class MinTimestampTracker(object):
+ """MinTimestampTracker maintains the min timestamp of timestamped elements."""
+ def __init__(self):
+ self.descending_subsequence = deque()
+
+        # in the case that incoming traffic is very small, the records may be
+        # put and polled within a single iteration; in this case we need to
+        # remember the last polled record's timestamp
+ self.last_known_time = -1
+
+ def add_element(self, elem):
+ if elem is None:
+ raise ValueError('elem must not be None')
+
+ while self.descending_subsequence:
+ min_elem = self.descending_subsequence[-1]
+ if min_elem[0] < elem[0]:
+ break
+ self.descending_subsequence.pop()
+ self.descending_subsequence.append(elem)
+
+ def remove_element(self, elem):
+ if elem is not None:
+ if self.descending_subsequence:
+ if self.descending_subsequence[0] == elem:
+ self.descending_subsequence.popleft()
+
+ if not self.descending_subsequence:
+ self.last_known_time = elem[0]
+
+ def size(self):
+ return len(self.descending_subsequence)
+
+ def get(self):
+ if not self.descending_subsequence:
+ return self.last_known_time
+ return self.descending_subsequence[0][0]
+
+
+class RecordQueue(object):
+ """
+ RecordQueue is a FIFO queue of (timestamp, ConsumerRecord).
+ It also keeps track of the partition timestamp defined as the minimum
+ timestamp of records in its queue; in addition, its partition timestamp
+ is monotonically increasing such that once it is advanced, it will not be
+ decremented.
+ """
+
+ def __init__(self, partition, source):
+ self.partition = partition
+ self.source = source
+
+ self.fifo_queue = deque()
+ self.time_tracker = MinTimestampTracker()
+
+ self._partition_time = -1
+
+ def add_raw_records(self, raw_records, timestamp_extractor):
+ """Add a batch of ConsumerRecord into the queue
+
+ Arguments:
+ raw_records (list of ConsumerRecord): the raw records
+ timestamp_extractor (TimestampExtractor)
+
+ Returns: the size of this queue
+ """
+ for raw_record in raw_records:
+ # deserialize the raw record, extract the timestamp and put into the queue
+ key = self.source.deserialize_key(raw_record.topic, raw_record.key)
+ value = self.source.deserialize_value(raw_record.topic, raw_record.value)
+
+            # ConsumerRecord fields are (topic, partition, offset, timestamp,
+            # timestamp_type, key, value, checksum, serialized_key_size,
+            # serialized_value_size)
+            record = ConsumerRecord(raw_record.topic,
+                                    raw_record.partition,
+                                    raw_record.offset,
+                                    raw_record.timestamp,
+                                    0,  # TimestampType.CREATE_TIME
+                                    key, value,
+                                    raw_record.checksum,
+                                    raw_record.serialized_key_size,
+                                    raw_record.serialized_value_size)
+
+ timestamp = timestamp_extractor.extract(record)
+
+ # validate that timestamp must be non-negative
+ if timestamp < 0:
+ raise Errors.StreamsError('Extracted timestamp value is negative, which is not allowed.')
+
+ stamped_record = (timestamp, record)
+
+ self.fifo_queue.append(stamped_record)
+ self.time_tracker.add_element(stamped_record)
+
+ # update the partition timestamp if its currently
+ # tracked min timestamp has exceeded its value; this will
+ # usually only take effect for the first added batch
+ timestamp = self.time_tracker.get()
+
+ if timestamp > self._partition_time:
+ self._partition_time = timestamp
+
+ return self.size()
+
+ def poll(self):
+ """Get the next StampedRecord from the queue
+
+ Returns: StampedRecord
+ """
+ if not self.fifo_queue:
+ return None
+
+ elem = self.fifo_queue.popleft()
+ self.time_tracker.remove_element(elem)
+
+ # only advance the partition timestamp if its currently
+ # tracked min timestamp has exceeded its value
+ timestamp = self.time_tracker.get()
+
+ if timestamp > self._partition_time:
+ self._partition_time = timestamp
+
+ return elem
+
+ def size(self):
+ """Returns the number of records in the queue
+
+ Returns: the number of records
+ """
+ return len(self.fifo_queue)
+
+ def is_empty(self):
+ """Tests if the queue is empty
+
+ Returns: True if the queue is empty, otherwise False
+ """
+ return not bool(self.fifo_queue)
+
+ def timestamp(self):
+ """Returns the tracked partition timestamp
+
+ Returns: timestamp
+ """
+ return self._partition_time
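MinTimestampTracker keeps a descending subsequence of the enqueued (timestamp, record) pairs so that get() returns the smallest timestamp still in the queue in O(1); that value is what drives the monotonically increasing partition timestamp in RecordQueue. A small sketch of the tracker on its own:

    tracker = MinTimestampTracker()
    tracker.add_element((5, 'record-a'))
    tracker.add_element((3, 'record-b'))     # 3 <= 5, so (5, 'record-a') is dropped from the tracker
    tracker.add_element((7, 'record-c'))

    assert tracker.get() == 3                # minimum timestamp still queued

    # elements are removed in arrival (FIFO) order as the queue is drained
    tracker.remove_element((5, 'record-a'))  # no-op: already superseded by a smaller timestamp
    tracker.remove_element((3, 'record-b'))
    assert tracker.get() == 7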
diff --git a/kafka/streams/processor/stream_partition_assignor.py b/kafka/streams/processor/stream_partition_assignor.py
new file mode 100644
index 0000000..dfe7e18
--- /dev/null
+++ b/kafka/streams/processor/stream_partition_assignor.py
@@ -0,0 +1,435 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import collections
+import logging
+import weakref
+
+from kafka import TopicPartition
+from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor
+from kafka.coordinator.protocol import (
+ ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment)
+import kafka.streams.errors as Errors
+from .internal_topic_manager import InternalTopicManager
+from .partition_group import TaskId
+from .assignment.assignment_info import AssignmentInfo
+from .assignment.client_state import ClientState
+from .assignment.subscription_info import SubscriptionInfo
+from .assignment.task_assignor import TaskAssignor
+
+log = logging.getLogger(__name__)
+
+
+class AssignedPartition(object):
+ def __init__(self, task_id, partition):
+ self.task_id = task_id
+ self.partition = partition
+
+    def __lt__(self, that):
+        # assigned partitions are ordered by their TopicPartition
+        return self.partition < that.partition
+
+
+class SubscriptionUpdates(object):
+ """
+    Used to capture topics subscribed via pattern that are discovered during
+    the partition assignment process.
+ """
+ def __init__(self):
+ self.updated_topic_subscriptions = set()
+
+ def update_topics(self, topic_names):
+        self.updated_topic_subscriptions.clear()
+ self.updated_topic_subscriptions.update(topic_names)
+
+ def get_updates(self):
+ return self.updated_topic_subscriptions
+
+ def has_updates(self):
+ return bool(self.updated_topic_subscriptions)
+
+
+class StreamPartitionAssignor(AbstractPartitionAssignor):
+ name = 'stream'
+ version = 0
+
+ def __init__(self, **configs):
+ """
+        We need the PartitionAssignor and its StreamThread to be mutually accessible,
+        since the former needs the latter's cached metadata while sending subscriptions,
+        and the latter needs the former's returned assignment when adding tasks.
+ """
+ self.stream_thread = None
+ self.num_standby_replicas = None
+ self.topic_groups = {}
+ self.partition_to_task_ids = {}
+ self.state_changelog_topic_to_task_ids = {}
+ self.internal_source_topic_to_task_ids = {}
+        self._standby_tasks = {}
+ self.internal_topic_manager = None
+
+ self.num_standby_replicas = configs.get('num_standby_replicas', 0)
+
+ o = configs.get('stream_thread_instance')
+ if o is None:
+ raise Errors.KafkaError("StreamThread is not specified")
+
+ #if not isinstance(o, StreamThread):
+ # raise Errors.KafkaError(o.__class__.__name__ + " is not an instance of StreamThread")
+
+ self.stream_thread = weakref.proxy(o)
+ self.stream_thread.partition_assignor = self
+
+ if 'zookeeper_connect_config' in configs:
+ self.internal_topic_manager = InternalTopicManager(configs['zookeeper_connect_config'], configs.get('replication_factor', 1))
+ else:
+ log.info("Config 'zookeeper_connect_config' isn't supplied and hence no internal topics will be created.")
+
+ def metadata(self, topics):
+ """Adds the following information to subscription
+ 1. Client UUID (a unique id assigned to an instance of KafkaStreams)
+ 2. Task ids of previously running tasks
+ 3. Task ids of valid local states on the client's state directory.
+
+ Returns: ConsumerProtocolMemberMetadata
+ """
+
+ prev_tasks = self.stream_thread.prev_tasks()
+ standby_tasks = self.stream_thread.cached_tasks()
+ standby_tasks.difference_update(prev_tasks)
+ data = SubscriptionInfo(self.stream_thread.process_id, prev_tasks, standby_tasks)
+
+ return ConsumerProtocolMemberMetadata(self.version, list(topics), data.encode())
+
+ def prepare_topic(self, topic_to_task_ids, compact_topic, post_partition_phase):
+ """Internal helper function that creates a Kafka topic
+
+ Arguments:
+ topic_to_task_ids (dict): that contains the topic names to be created
+ compact_topic (bool): If True, the topic should be a compacted topic.
+ This is used for change log topics usually.
+ post_partition_phase (bool): If True, the computation for calculating
+ the number of partitions is slightly different. Set to True after
+ the initial topic-to-partition assignment.
+
+ Returns:
+ set([TopicPartition])
+ """
+ partitions = set()
+ # if ZK is specified, prepare the internal source topic before calling partition grouper
+ if self.internal_topic_manager is not None:
+ log.debug("Starting to validate internal topics in partition assignor.")
+
+ for topic, tasks in topic_to_task_ids.items():
+ num_partitions = 0
+ if post_partition_phase:
+ # the expected number of partitions is the max value of
+ # TaskId.partition + 1
+ for task in tasks:
+ if num_partitions < task.partition + 1:
+ num_partitions = task.partition + 1
+ else:
+ # should have size 1 only
+ num_partitions = -1
+ for task in tasks:
+ num_partitions = task.partition
+
+ self.internal_topic_manager.make_ready(topic, num_partitions, compact_topic)
+
+ # wait until the topic metadata has been propagated to all brokers
+ partition_ints = []
+ while True:
+ partition_ints = self.stream_thread.restore_consumer.partitions_for_topic(topic)
+ if partition_ints and len(partition_ints) == num_partitions:
+ break
+
+ for partition in partition_ints:
+ partitions.add(TopicPartition(topic, partition))
+
+ log.info("Completed validating internal topics in partition assignor.")
+ else:
+ missing_topics = []
+ for topic in topic_to_task_ids:
+ partition_ints = self.stream_thread.restore_consumer.partitions_for_topic(topic)
+ if partition_ints is None:
+ missing_topics.append(topic)
+
+ if missing_topics:
+                log.warning("Topics %s do not exist and cannot be created because"
+                            " the config '%s' is not supplied",
+                            missing_topics, 'zookeeper_connect_config')
+
+ return partitions
+
+ def assign(self, metadata, subscriptions):
+ """Assigns tasks to consumer clients in two steps.
+
+ 1. using TaskAssignor to assign tasks to consumer clients.
+ - Assign a task to a client which was running it previously.
+ If there is no such client, assign a task to a client which has
+ its valid local state.
+ - A client may have more than one stream threads.
+ The assignor tries to assign tasks to a client proportionally to
+ the number of threads.
+           - We try not to assign the same set of tasks to two different clients.
+             We do the assignment in one pass; the result may not satisfy all of
+             the constraints above.
+ 2. within each client, tasks are assigned to consumer clients in
+ round-robin manner.
+
+ Returns:
+ {member_id: ConsumerProtocolMemberAssignment}
+ """
+ consumers_by_client = {}
+ states = {}
+ subscription_updates = SubscriptionUpdates()
+ # decode subscription info
+ for consumer_id, subscription in subscriptions.items():
+
+ if self.stream_thread.builder.source_topic_pattern() is not None:
+ # update the topic groups with the returned subscription list for regex pattern subscriptions
+                subscription_updates.update_topics(subscription.subscription)
+
+ info = SubscriptionInfo.decode(subscription.user_data)
+
+ consumers = consumers_by_client.get(info.process_id)
+ if consumers is None:
+ consumers = set()
+ consumers_by_client[info.process_id] = consumers
+ consumers.add(consumer_id)
+
+ state = states.get(info.process_id)
+ if state is None:
+ state = ClientState()
+ states[info.process_id] = state
+
+ state.prev_active_tasks.update(info.prev_tasks)
+ state.prev_assigned_tasks.update(info.prev_tasks)
+ state.prev_assigned_tasks.update(info.standby_tasks)
+ state.capacity = state.capacity + 1
+
+ self.stream_thread.builder.subscription_updates = subscription_updates
+ self.topic_groups = self.stream_thread.builder.topic_groups()
+
+ # ensure the co-partitioning topics within the group have the same
+ # number of partitions, and enforce the number of partitions for those
+ # internal topics.
+ source_topic_groups = {}
+ internal_source_topic_groups = {}
+ for key, value in self.topic_groups.items():
+ source_topic_groups[key] = value.source_topics
+ internal_source_topic_groups[key] = value.inter_source_topics
+
+ # for all internal source topics
+ # set the number of partitions to the maximum of the depending
+ # sub-topologies source topics
+ internal_partitions = set()
+ all_internal_topic_names = set()
+ for topic_group_id, topics_info in self.topic_groups.items():
+ internal_topics = topics_info.inter_source_topics
+ all_internal_topic_names.update(internal_topics)
+ for internal_topic in internal_topics:
+ tasks = self.internal_source_topic_to_task_ids.get(internal_topic)
+
+ if tasks is None:
+ num_partitions = -1
+ for other in self.topic_groups.values():
+ other_sink_topics = other.sink_topics
+
+ if internal_topic in other_sink_topics:
+ for topic in other.source_topics:
+ partitions = None
+ # It is possible the source_topic is another internal topic, i.e,
+ # map().join().join(map())
+ if topic in all_internal_topic_names:
+ task_ids = self.internal_source_topic_to_task_ids.get(topic)
+ if task_ids is not None:
+ for task_id in task_ids:
+ partitions = task_id.partition
+ else:
+ partitions = len(metadata.partitions_for_topic(topic))
+
+ if partitions is not None and partitions > num_partitions:
+ num_partitions = partitions
+
+ self.internal_source_topic_to_task_ids[internal_topic] = [TaskId(topic_group_id, num_partitions)]
+ for partition in range(num_partitions):
+ internal_partitions.add(TopicPartition(internal_topic, partition))
+
+ copartition_topic_groups = self.stream_thread.builder.copartition_groups()
+ self.ensure_copartitioning(copartition_topic_groups, internal_source_topic_groups,
+ metadata.with_partitions(internal_partitions))
+
+
+        internal_partitions = self.prepare_topic(self.internal_source_topic_to_task_ids, False, False)
+ self.internal_source_topic_to_task_ids.clear()
+
+ metadata_with_internal_topics = metadata
+ if self.internal_topic_manager:
+ metadata_with_internal_topics = metadata.with_partitions(internal_partitions)
+
+ # get the tasks as partition groups from the partition grouper
+ partitions_for_task = self.stream_thread.partition_grouper(source_topic_groups, metadata_with_internal_topics)
+
+ # add tasks to state change log topic subscribers
+ self.state_changelog_topic_to_task_ids = {}
+ for task in partitions_for_task:
+ for topic_name in self.topic_groups[task.topic_group_id].state_changelog_topics:
+ tasks = self.state_changelog_topic_to_task_ids.get(topic_name)
+ if tasks is None:
+ tasks = set()
+ self.state_changelog_topic_to_task_ids[topic_name] = tasks
+
+ tasks.add(task)
+
+ for topic_name in self.topic_groups[task.topic_group_id].inter_source_topics:
+ tasks = self.internal_source_topic_to_task_ids.get(topic_name)
+ if tasks is None:
+ tasks = set()
+ self.internal_source_topic_to_task_ids[topic_name] = tasks
+
+ tasks.add(task)
+
+ # assign tasks to clients
+ states = TaskAssignor.assign(states, set(partitions_for_task), self.num_standby_replicas)
+ assignment = {}
+
+ for process_id, consumers in consumers_by_client.items():
+ state = states[process_id]
+
+ task_ids = []
+ num_active_tasks = len(state.active_tasks)
+ for task_id in state.active_tasks:
+ task_ids.append(task_id)
+
+ for task_id in state.assigned_tasks:
+ if task_id not in state.active_tasks:
+ task_ids.append(task_id)
+
+ num_consumers = len(consumers)
+ standby = {}
+
+ i = 0
+ for consumer in consumers:
+ assigned_partitions = []
+
+ num_task_ids = len(task_ids)
+ j = i
+ while j < num_task_ids:
+ task_id = task_ids[j]
+ if j < num_active_tasks:
+ for partition in partitions_for_task[task_id]:
+ assigned_partitions.append(AssignedPartition(task_id, partition))
+ else:
+ standby_partitions = standby.get(task_id)
+ if standby_partitions is None:
+ standby_partitions = set()
+ standby[task_id] = standby_partitions
+ standby_partitions.update(partitions_for_task[task_id])
+ j += num_consumers
+
+ assigned_partitions.sort()
+ active = []
+ active_partitions = collections.defaultdict(list)
+ for partition in assigned_partitions:
+ active.append(partition.task_id)
+                    active_partitions[partition.partition.topic].append(partition.partition.partition)
+
+ data = AssignmentInfo(active, standby)
+ assignment[consumer] = ConsumerProtocolMemberAssignment(
+ self.version,
+ sorted(active_partitions.items()),
+ data.encode())
+ i += 1
+
+ active.clear()
+ standby.clear()
+
+ # if ZK is specified, validate the internal topics again
+ self.prepare_topic(self.internal_source_topic_to_task_ids, False, True)
+ # change log topics should be compacted
+ self.prepare_topic(self.state_changelog_topic_to_task_ids, True, True)
+
+ return assignment
+
+ def on_assignment(self, assignment):
+ partitions = [TopicPartition(topic, partition)
+                      for topic, topic_partitions in assignment.assignment
+ for partition in topic_partitions]
+
+ partitions.sort()
+
+ info = AssignmentInfo.decode(assignment.user_data)
+        self._standby_tasks = info.standby_tasks
+
+ partition_to_task_ids = {}
+ task_iter = iter(info.active_tasks)
+ for partition in partitions:
+            task_ids = partition_to_task_ids.get(partition)
+            if task_ids is None:
+                task_ids = set()
+                partition_to_task_ids[partition] = task_ids
+
+ try:
+ task_ids.add(next(task_iter))
+ except StopIteration:
+ raise Errors.TaskAssignmentError(
+ "failed to find a task id for the partition=%s"
+ ", partitions=%d, assignment_info=%s"
+ % (partition, len(partitions), info))
+ self.partition_to_task_ids = partition_to_task_ids
+
+ def ensure_copartitioning(self, copartition_groups, internal_topic_groups, metadata):
+ internal_topics = set()
+ for topics in internal_topic_groups.values():
+ internal_topics.update(topics)
+
+ for copartition_group in copartition_groups:
+ num_partitions = -1
+
+ for topic in copartition_group:
+ if topic not in internal_topics:
+ infos = metadata.partitions_for_topic(topic)
+
+ if infos is None:
+ raise Errors.TopologyBuilderError("External source topic not found: " + topic)
+
+ if num_partitions == -1:
+ num_partitions = len(infos)
+ elif num_partitions != len(infos):
+ raise Errors.TopologyBuilderError("Topics not copartitioned: [%s]" % copartition_group)
+
+ if num_partitions == -1:
+ for topic in internal_topics:
+ if topic in copartition_group:
+                        infos = metadata.partitions_for_topic(topic)
+                        if infos is not None and len(infos) > num_partitions:
+                            num_partitions = len(infos)
+
+ # enforce co-partitioning restrictions to internal topics reusing
+ # internalSourceTopicToTaskIds
+ for topic in internal_topics:
+ if topic in copartition_group:
+ self.internal_source_topic_to_task_ids[topic] = [TaskId(-1, num_partitions)]
+
+ def tasks_for_partition(self, partition):
+ return self.partition_to_task_ids.get(partition)
+
+    def standby_tasks(self):
+        # stored as _standby_tasks so this accessor is not shadowed by an
+        # instance attribute of the same name
+        return self._standby_tasks
+
+ def set_internal_topic_manager(self, internal_topic_manager):
+ self.internal_topic_manager = internal_topic_manager
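Step 2 of assign() spreads the ordered task list of one client across that client's consumers by interleaving: consumer i takes tasks i, i+n, i+2n, ... where n is the number of consumers, and any index at or beyond num_active_tasks becomes a standby assignment. A self-contained sketch of just that interleaving, using made-up task ids:

    task_ids = ['0_0', '0_1', '0_2', '1_0', '1_1']   # hypothetical task ids, actives first
    num_active_tasks = 3
    consumers = ['consumer-a', 'consumer-b']

    for i, consumer in enumerate(consumers):
        active, standby = [], []
        for j in range(i, len(task_ids), len(consumers)):
            (active if j < num_active_tasks else standby).append(task_ids[j])
        print(consumer, 'active:', active, 'standby:', standby)

    # consumer-a active: ['0_0', '0_2'] standby: ['1_1']
    # consumer-b active: ['0_1'] standby: ['1_0']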
diff --git a/kafka/streams/processor/stream_task.py b/kafka/streams/processor/stream_task.py
new file mode 100644
index 0000000..5c073d6
--- /dev/null
+++ b/kafka/streams/processor/stream_task.py
@@ -0,0 +1,277 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import logging
+import threading
+
+from kafka.consumer.fetcher import ConsumerRecord
+import kafka.errors as Errors
+from kafka.structs import OffsetAndMetadata
+from .context import ProcessorContext
+from .partition_group import PartitionGroup, RecordInfo
+from .punctuation import PunctuationQueue
+from .record_collector import RecordCollector
+from .record_queue import RecordQueue
+from .task import AbstractTask
+
+NONEXIST_TOPIC = '__null_topic__'
+DUMMY_RECORD = ConsumerRecord(NONEXIST_TOPIC, -1, -1, -1, -1, None, None, -1, -1, -1)
+
+log = logging.getLogger(__name__)
+
+
+class StreamTask(AbstractTask):
+ """A StreamTask is associated with a PartitionGroup,
+ and is assigned to a StreamThread for processing."""
+
+ def __init__(self, task_id, partitions, topology, consumer, producer, restore_consumer, **config):
+ """Create StreamTask with its assigned partitions
+
+ Arguments:
+ task_id (str): the ID of this task
+ partitions (list of TopicPartition): the assigned partitions
+ topology (ProcessorTopology): the instance of ProcessorTopology
+ consumer (Consumer): the instance of Consumer
+ producer (Producer): the instance of Producer
+ restore_consumer (Consumer): the instance of Consumer used when
+ restoring state
+ """
+ super(StreamTask, self).__init__(task_id, partitions, topology,
+ consumer, restore_consumer, False, **config)
+ self._punctuation_queue = PunctuationQueue()
+ self._record_info = RecordInfo()
+
+ self.max_buffered_size = config['buffered_records_per_partition']
+ self._process_lock = threading.Lock()
+
+ self._commit_requested = False
+ self._commit_offset_needed = False
+ self._curr_record = None
+ self._curr_node = None
+ self.requires_poll = True
+
+ # create queues for each assigned partition and associate them
+ # to corresponding source nodes in the processor topology
+ partition_queues = {}
+
+ for partition in partitions:
+            source = self.topology.source(partition.topic)
+            queue = self._create_record_queue(partition, source)
+ partition_queues[partition] = queue
+
+ self.partition_group = PartitionGroup(partition_queues, self.config['timestamp_extractor_class'])
+
+ # initialize the consumed offset cache
+ self.consumed_offsets = {}
+
+ # create the record recordCollector that maintains the produced offsets
+ self.record_collector = RecordCollector(self.producer)
+
+ log.info('Creating restoration consumer client for stream task #%s', self.id)
+
+ # initialize the topology with its own context
+ self.processor_context = ProcessorContext(self.id, self, self.record_collector, self.state_mgr, **config)
+
+ # initialize the state stores
+ self.initialize_state_stores()
+
+ # initialize the task by initializing all its processor nodes in the topology
+ for node in self.topology.processors():
+ self._curr_node = node
+ try:
+ node.init(self.processor_context)
+ finally:
+ self._curr_node = None
+
+ self.processor_context.initialized()
+
+
+ def add_records(self, partition, records):
+ """Adds records to queues"""
+ queue_size = self.partition_group.add_raw_records(partition, records)
+
+ # if after adding these records, its partition queue's buffered size has
+ # been increased beyond the threshold, we can then pause the consumption
+ # for this partition
+ if queue_size > self.max_buffered_size:
+ self.consumer.pause(partition)
+
+ def process(self):
+ """Process one record
+
+ Returns:
+ number of records left in the buffer of this task's partition group after the processing is done
+ """
+ with self._process_lock:
+ # get the next record to process
+ record = self.partition_group.next_record(self._record_info)
+
+ # if there is no record to process, return immediately
+ if record is None:
+ self.requires_poll = True
+ return 0
+
+ self.requires_poll = False
+
+ try:
+ # process the record by passing to the source node of the topology
+ self._curr_record = record
+ self._curr_node = self._record_info.node()
+ partition = self._record_info.partition()
+
+ log.debug('Start processing one record [%s]', self._curr_record)
+
+ self._curr_node.process(self._curr_record.key, self._curr_record.value)
+
+ log.debug('Completed processing one record [%s]', self._curr_record)
+
+ # update the consumed offset map after processing is done
+ self.consumed_offsets[partition] = self._curr_record.offset
+ self._commit_offset_needed = True
+
+ # after processing this record, if its partition queue's
+ # buffered size has been decreased to the threshold, we can then
+ # resume the consumption on this partition
+ if self._record_info.queue().size() == self.max_buffered_size:
+ self.consumer.resume(partition)
+ self.requires_poll = True
+
+ if self.partition_group.top_queue_size() <= self.max_buffered_size:
+ self.requires_poll = True
+
+ finally:
+ self._curr_record = None
+ self._curr_node = None
+
+ return self.partition_group.num_buffered()
+
+ def maybe_punctuate(self):
+ """Possibly trigger registered punctuation functions if
+ current partition group timestamp has reached the defined stamp
+ """
+ timestamp = self.partition_group.timestamp()
+
+        # if the timestamp is not known yet, meaning there is not enough data
+        # accumulated to reason about the stream partition time, then skip
+ if timestamp == -1:
+ return False
+ else:
+ return self._punctuation_queue.may_punctuate(timestamp, self)
+
+ def punctuate(self, node, timestamp):
+ if self._curr_node is not None:
+ raise Errors.IllegalStateError('Current node is not null')
+
+ self._curr_node = node
+ self._curr_record = (timestamp, DUMMY_RECORD)
+
+ try:
+ node.processor().punctuate(timestamp)
+ finally:
+ self._curr_node = None
+ self._curr_record = None
+
+ def record(self):
+ return self._curr_record
+
+ def node(self):
+ return self._curr_node
+
+ def commit(self):
+ """Commit the current task state"""
+ # 1) flush local state
+ self.state_mgr.flush()
+
+ # 2) flush produced records in the downstream and change logs of local states
+ self.record_collector.flush()
+
+ # 3) commit consumed offsets if it is dirty already
+ if self._commit_offset_needed:
+ consumed_offsets_and_metadata = {}
+ for partition, offset in self.consumed_offsets.items():
+                consumed_offsets_and_metadata[partition] = OffsetAndMetadata(offset + 1, '')
+                self.state_mgr.put_offset_limit(partition, offset + 1)
+            self.consumer.commit(consumed_offsets_and_metadata)
+ self._commit_offset_needed = False
+
+ self._commit_requested = False
+
+ def commit_needed(self):
+ """Whether or not a request has been made to commit the current state"""
+ return self._commit_requested
+
+ def need_commit(self):
+ """Request committing the current task's state"""
+ self._commit_requested = True
+
+ def schedule(self, interval_ms):
+ """Schedules a punctuation for the processor
+
+ Arguments:
+ interval_ms (int): the interval in milliseconds
+
+        Raises: IllegalStateError if there is no current node
+ """
+ if self._curr_node is None:
+ raise Errors.IllegalStateError('Current node is null')
+
+ schedule = (0, self._curr_node, interval_ms)
+ self._punctuation_queue.schedule(schedule)
+
+ def close(self):
+ self.partition_group.close()
+ self.consumed_offsets.clear()
+
+ # close the processors
+ # make sure close() is called for each node even when there is a RuntimeException
+ exception = None
+ for node in self.topology.processors():
+ self._curr_node = node
+ try:
+ node.close()
+ except RuntimeError as e:
+ exception = e
+ finally:
+ self._curr_node = None
+
+ super(StreamTask, self).close()
+
+ if exception is not None:
+ raise exception
+
+ def record_collector_offsets(self):
+        return self.record_collector.offsets
+
+ def _create_record_queue(self, partition, source):
+ return RecordQueue(partition, source)
+
+ def forward(self, key, value, child_index=None, child_name=None):
+ this_node = self._curr_node
+ try:
+ children = this_node.children()
+
+ if child_index is not None:
+ children = [children[child_index]]
+ elif child_name is not None:
+ children = [child for child in children if child.name == child_name]
+
+ for child_node in children:
+ self._curr_node = child_node
+ child_node.process(key, value)
+ finally:
+ self._curr_node = this_node
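StreamTask applies simple per-partition backpressure: add_records() pauses the consumer for a partition once its queue grows past buffered_records_per_partition, and process() resumes it (and asks the thread to poll again) once the queue drains back to that threshold. A tiny sketch of the threshold logic in isolation; the function names are illustrative, not part of this commit.

    MAX_BUFFERED = 1000  # stands in for the 'buffered_records_per_partition' config

    def on_records_added(consumer, partition, queue_size):
        if queue_size > MAX_BUFFERED:
            consumer.pause(partition)    # stop fetching this partition for now

    def on_record_processed(consumer, partition, queue_size):
        if queue_size == MAX_BUFFERED:   # queue just dropped back to the threshold
            consumer.resume(partition)   # fetching may continue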
diff --git a/kafka/streams/processor/stream_thread.py b/kafka/streams/processor/stream_thread.py
new file mode 100644
index 0000000..021fb0a
--- /dev/null
+++ b/kafka/streams/processor/stream_thread.py
@@ -0,0 +1,697 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import collections
+import copy
+import logging
+from multiprocessing import Process
+import os
+import time
+
+import six
+
+from kafka import KafkaConsumer, KafkaProducer
+from kafka.consumer.subscription_state import ConsumerRebalanceListener
+import kafka.errors as Errors
+from kafka.streams.errors import StreamsError
+from kafka.streams.utils import AtomicInteger
+from .partition_group import partition_grouper
+from .stream_partition_assignor import StreamPartitionAssignor
+from .stream_task import StreamTask
+
+log = logging.getLogger(__name__)
+
+STREAM_THREAD_ID_SEQUENCE = AtomicInteger(0)
+
+
+class StreamThread(object):#Process):
+ DEFAULT_CONFIG = {
+ 'application_id': None, # required
+ 'bootstrap_servers': None, # required
+ 'process_id': None, # required
+ 'client_id': 'kafka-python-streams',
+ 'poll_ms': 100,
+ 'num_stream_threads': 1,
+ 'commit_interval_ms': 30000,
+ 'partition_grouper': partition_grouper,
+ 'key_serializer': None,
+ 'value_serializer': None,
+ 'key_deserializer': None,
+ 'value_deserializer': None,
+ 'state_dir': '/tmp/kafka-streams',
+ #'client_supplier': ....,
+ 'state_cleanup_delay_ms': 60000,
+ 'linger_ms': 100,
+ 'auto_offset_reset': 'earliest',
+ 'enable_auto_commit': False,
+ 'max_poll_records': 1000,
+ }
+
+ def __init__(self, builder, **configs):
+ stream_id = STREAM_THREAD_ID_SEQUENCE.increment()
+ #super(StreamThread, self).__init__(name='StreamThread-' + str(stream_id))
+ self.name = 'StreamThread-' + str(stream_id)
+
+ self.config = copy.copy(self.DEFAULT_CONFIG)
+ for key in self.config:
+ if key in configs:
+ self.config[key] = configs.pop(key)
+
+ for key in ('application_id', 'process_id', 'bootstrap_servers'):
+ assert self.config[key], 'Required configuration: ' + key
+
+        # Only check for extra config keys in top-level class
+        if configs:
+            log.warning('Unrecognized configs: %s', list(configs.keys()))
+
+ self.builder = builder
+ self._partition_grouper = self.config.get('partition_grouper', partition_grouper)
+ self.source_topics = builder.source_topics()
+ self.topic_pattern = builder.source_topic_pattern()
+ self.partition_assignor = None
+
+ self._rebalance_exception = None
+ self._process_standby_records = False
+ self._rebalance_listener = KStreamsConsumerRebalanceListener(self)
+ self.config['thread_client_id'] = self.config['client_id'] + "-" + self.name
+ self._running = False
+
+ @property
+ def application_id(self):
+ return self.config['application_id']
+
+ @property
+ def process_id(self):
+ return self.config['process_id']
+
+ @property
+ def client_id(self):
+ return self.config['client_id']
+
+ def initialize(self):
+ assert not self._running
+ log.info('Creating producer client for stream thread [%s]', self.name)
+ #client_supplier = self.config['client_supplier']
+ self.producer = KafkaProducer(**self.config)
+ log.info('Creating consumer client for stream thread [%s]', self.name)
+ assignor = StreamPartitionAssignor(
+ stream_thread_instance=self, **self.config)
+ self.consumer = KafkaConsumer(
+ partition_assignment_strategy=[assignor], **self.config)
+ log.info('Creating restore consumer client for stream thread [%s]', self.name)
+ restore_assignor = StreamPartitionAssignor(
+ stream_thread_instance=self, **self.config)
+ self.restore_consumer = KafkaConsumer(
+ partition_assignment_strategy=[restore_assignor], **self.config)
+
+ # initialize the task list
+ self._active_tasks = {}
+ self._standby_tasks = {}
+ self._active_tasks_by_partition = {}
+ self._standby_tasks_by_partition = {}
+ self._prev_tasks = set()
+
+ # standby ktables
+ self._standby_records = {}
+
+ # read in task specific config values
+ """
+ self._state_dir = os.path.join(self.config['state_dir'], self.config['application_id'])
+ if not os.path.isdir(self._state_dir):
+ os.makedirs(self._state_dir)
+ """
+
+ # the cleaning cycle won't start until partition assignment
+ self._last_clean_ms = float('inf')
+ self._last_commit_ms = time.time() * 1000
+
+ #self._sensors = StreamsMetricsImpl(metrics)
+
+ self._running = True #new AtomicBoolean(true);
+
+ def run(self):
+ """Execute the stream processors.
+
+ Raises:
+ KafkaError for any Kafka-related exceptions
+ Exception for any other non-Kafka exceptions
+ """
+ self.initialize()
+ log.info('Starting stream thread [%s]', self.name)
+
+ try:
+ self._run_loop()
+ except Errors.KafkaError:
+ # just re-throw the exception as it should be logged already
+ raise
+ except Exception:
+ # we have caught all Kafka related exceptions, and other runtime exceptions
+ # should be due to user application errors
+ log.exception('Streams application error during processing in thread [%s]', self.name)
+ raise
+ finally:
+ self.shutdown()
+
+ def close(self):
+ """Shutdown this stream thread."""
+ self._running = False #.set(False)
+
+ def tasks(self):
+ return self._active_tasks
+
+ def shutdown(self):
+ log.info('Shutting down stream thread [%s]', self.name)
+
+ # Exceptions should not prevent this call from going through all shutdown steps
+ try:
+ self._commit_all()
+ except Exception:
+ # already logged in commitAll()
+ pass
+
+ # Close standby tasks before closing the restore consumer since closing
+ # standby tasks uses the restore consumer.
+ self._remove_standby_tasks()
+
+ # We need to first close the underlying clients before closing the state
+ # manager, for example we need to make sure producer's record sends
+ # have all been acked before the state manager records
+ # changelog sent offsets
+ try:
+ self.producer.close()
+ except Exception:
+ log.exception('Failed to close producer in thread [%s]', self.name)
+ try:
+ self.consumer.close()
+ except Exception:
+ log.exception('Failed to close consumer in thread [%s]', self.name)
+ try:
+ self.restore_consumer.close()
+ except Exception:
+ log.exception('Failed to close restore consumer in thread [%s]', self.name)
+
+ self._remove_stream_tasks()
+ log.info('Stream thread shutdown complete [%s]', self.name)
+
+ def _run_loop(self):
+ total_num_buffered = 0
+ last_poll = 0
+ requires_poll = True
+
+ if self.topic_pattern is not None:
+ self.consumer.subscribe(pattern=self.topic_pattern,
+ listener=self._rebalance_listener)
+ else:
+ self.consumer.subscribe(topics=self.source_topics,
+ listener=self._rebalance_listener)
+
+ while self._still_running():
+ # try to fetch some records if necessary
+ if requires_poll:
+ requires_poll = False
+
+ start_poll = time.time() * 1000
+
+ if total_num_buffered == 0:
+ poll_ms = self.config['poll_ms']
+ else:
+ poll_ms = 0
+ records = self.consumer.poll(poll_ms)
+ last_poll = time.time() * 1000
+
+ if self._rebalance_exception is not None:
+ raise StreamsError('Failed to rebalance',
+ self._rebalance_exception)
+
+ if records:
+ for partition in records:
+ task = self._active_tasks_by_partition[partition]
+ task.add_records(partition, records[partition])
+
+ end_poll = time.time()
+ #self._sensors.poll_time_sensor.record(end_poll - start_poll)
+
+ total_num_buffered = 0
+
+ # try to process one fetch record from each task via the topology,
+ # and also trigger punctuate functions if necessary, which may
+ # result in more records going through the topology in this loop
+ if self._active_tasks:
+ for task in six.itervalues(self._active_tasks):
+ start_process = time.time()
+
+ total_num_buffered += task.process()
+                    requires_poll = requires_poll or task.requires_poll
+
+ latency_ms = (time.time() - start_process) * 1000
+ #self._sensors.process_time_sensor.record(latency_ms)
+
+ self._maybe_punctuate(task)
+
+ if task.commit_needed():
+ self._commit_one(task)
+
+ # if poll_ms has passed since the last poll, we poll to respond
+ # to a possible rebalance even when we paused all partitions.
+ if (last_poll + self.config['poll_ms'] < time.time() * 1000):
+ requires_poll = True
+
+ else:
+ # even when no task is assigned, we must poll to get a task.
+ requires_poll = True
+
+ self._maybe_commit()
+ self._maybe_update_standby_tasks()
+ self._maybe_clean()
+
+ def _maybe_update_standby_tasks(self):
+ if self._standby_tasks:
+ if self._process_standby_records:
+ if self._standby_records:
+ remaining_standby_records = {}
+ for partition in self._standby_records:
+ remaining = self._standby_records[partition]
+ if remaining:
+ task = self._standby_tasks_by_partition[partition]
+ remaining = task.update(partition, remaining)
+ if remaining:
+ remaining_standby_records[partition] = remaining
+ else:
+ self.restore_consumer.resume(partition)
+                    self._standby_records = remaining_standby_records
+ self._process_standby_records = False
+
+ records = self.restore_consumer.poll(0)
+
+ if records:
+ for partition in records:
+ task = self._standby_tasks_by_partition.get(partition)
+
+ if task is None:
+ log.error('missing standby task for partition %s', partition)
+ raise StreamsError('missing standby task for partition %s' % partition)
+
+ remaining = task.update(partition, records[partition])
+ if remaining:
+ self.restore_consumer.pause(partition)
+ self._standby_records[partition] = remaining
+
+ def _still_running(self):
+ if not self._running:
+ log.debug('Shutting down at user request.')
+ return False
+ return True
+
+ def _maybe_punctuate(self, task):
+ try:
+ now = time.time()
+
+ # check whether we should punctuate based on the task's partition
+ # group timestamp which are essentially based on record timestamp.
+ if task.maybe_punctuate():
+ latency_ms = (time.time() - now) * 1000
+ #self._sensors.punctuate_time_sensor.record(latency_ms)
+
+ except Errors.KafkaError:
+ log.exception('Failed to punctuate active task #%s in thread [%s]',
+ task.id, self.name)
+ raise
+
+ def _maybe_commit(self):
+ now_ms = time.time() * 1000
+
+ if (self.config['commit_interval_ms'] >= 0 and
+ self._last_commit_ms + self.config['commit_interval_ms'] < now_ms):
+ log.log(0, 'Committing processor instances because the commit interval has elapsed.')
+
+ self._commit_all()
+ self._last_commit_ms = now_ms
+
+            self._process_standby_records = True
+
+ def _commit_all(self):
+ """Commit the states of all its tasks"""
+ for task in six.itervalues(self._active_tasks):
+ self._commit_one(task)
+ for task in six.itervalues(self._standby_tasks):
+ self._commit_one(task)
+
+ def _commit_one(self, task):
+ """Commit the state of a task"""
+ start = time.time()
+ try:
+ task.commit()
+ except Errors.CommitFailedError:
+ # commit failed. Just log it.
+ log.warning('Failed to commit %s #%s in thread [%s]',
+ task.__class__.__name__, task.id, self.name,
+ exc_info=True)
+ except Errors.KafkaError:
+ # commit failed due to an unexpected exception.
+ # Log it and rethrow the exception.
+ log.exception('Failed to commit %s #%s in thread [%s]',
+ task.__class__.__name__, task.id, self.name)
+ raise
+
+ timer_ms = (time.time() - start) * 1000
+ #self._sensors.commit_time_sensor.record(timer_ms)
+
+ def _maybe_clean(self):
+ """Cleanup any states of the tasks that have been removed from this thread"""
+ now_ms = time.time() * 1000
+
+ clean_time_ms = self.config['state_cleanup_delay_ms']
+ if now_ms > self._last_clean_ms + clean_time_ms:
+ """
+ File[] stateDirs = stateDir.listFiles();
+ if (stateDirs != null) {
+ for (File dir : stateDirs) {
+ try {
+ String dirName = dir.getName();
+ TaskId id = TaskId.parse(dirName.substring(dirName.lastIndexOf("-") + 1)); # task_id as (topic_group_id, partition_id)
+
+ // try to acquire the exclusive lock on the state directory
+ if (dir.exists()) {
+ FileLock directoryLock = null;
+ try {
+ directoryLock = ProcessorStateManager.lockStateDirectory(dir);
+ if (directoryLock != null) {
+ log.info("Deleting obsolete state directory {} for task {} after delayed {} ms.", dir.getAbsolutePath(), id, cleanTimeMs);
+ Utils.delete(dir);
+ }
+ } catch (FileNotFoundException e) {
+ // the state directory may be deleted by another thread
+ } catch (IOException e) {
+ log.error("Failed to lock the state directory due to an unexpected exception", e);
+ } finally {
+ if (directoryLock != null) {
+ try {
+ directoryLock.release();
+ directoryLock.channel().close();
+ } catch (IOException e) {
+ log.error("Failed to release the state directory lock");
+ }
+ }
+ }
+ }
+ } catch (TaskIdFormatException e) {
+ // there may be some unknown files that sits in the same directory,
+ // we should ignore these files instead trying to delete them as well
+ }
+ }
+ }
+ """
+ self._last_clean_ms = now_ms
+
+ def prev_tasks(self):
+ """Returns ids of tasks that were being executed before the rebalance."""
+ return self._prev_tasks
+
+ def cached_tasks(self):
+ """Returns ids of tasks whose states are kept on the local storage."""
+ # A client could contain some inactive tasks whose states are still
+ # kept on the local storage in the following scenarios:
+ # 1) the client is actively maintaining standby tasks by maintaining
+ # their states from the change log.
+ # 2) the client has just got some tasks migrated out of itself to other
+ # clients while these task states have not been cleaned up yet (this
+ # can happen in a rolling bounce upgrade, for example).
+
+ tasks = set()
+ """
+ File[] stateDirs = stateDir.listFiles();
+ if (stateDirs != null) {
+ for (File dir : stateDirs) {
+ try {
+ TaskId id = TaskId.parse(dir.getName());
+ // if the checkpoint file exists, the state is valid.
+ if (new File(dir, ProcessorStateManager.CHECKPOINT_FILE_NAME).exists())
+ tasks.add(id);
+
+ } catch (TaskIdFormatException e) {
+ // there may be some unknown files that sits in the same directory,
+ // we should ignore these files instead trying to delete them as well
+ }
+ }
+ }
+ """
+ return tasks
+
+ def _create_stream_task(self, task_id, partitions):
+ #self._sensors.task_creation_sensor.record()
+
+ topology = self.builder.build(self.config['application_id'],
+ task_id.topic_group_id)
+
+        return StreamTask(task_id, partitions, topology,
+                          self.consumer, self.producer, self.restore_consumer,
+                          **self.config)  # self._sensors
+
+ def _add_stream_tasks(self, assignment):
+ if self.partition_assignor is None:
+ raise Errors.IllegalStateError(
+ 'Partition assignor has not been initialized while adding'
+ ' stream tasks: this should not happen.')
+
+ partitions_for_task = collections.defaultdict(set)
+
+ for partition in assignment:
+ task_ids = self.partition_assignor.tasks_for_partition(partition)
+ for task_id in task_ids:
+                partitions_for_task[task_id].add(partition)
+
+ # create the active tasks
+ for task_id, partitions in partitions_for_task.items():
+ try:
+ task = self._create_stream_task(task_id, partitions)
+ self._active_tasks[task_id] = task
+
+ for partition in partitions:
+ self._active_tasks_by_partition[partition] = task
+ except StreamsError:
+ log.exception('Failed to create an active task #%s in thread [%s]',
+ task_id, self.name)
+ raise
+
+ def _remove_stream_tasks(self):
+ try:
+ for task in self._active_tasks.values():
+ self._close_one(task)
+ self._prev_tasks.clear()
+ self._prev_tasks.update(set(self._active_tasks.keys()))
+
+ self._active_tasks.clear()
+ self._active_tasks_by_partition.clear()
+
+ except Exception:
+ log.exception('Failed to remove stream tasks in thread [%s]', self.name)
+
+ def _close_one(self, task):
+ log.info('Removing a task %s', task.id)
+ try:
+ task.close()
+ except StreamsError:
+ log.exception('Failed to close a %s #%s in thread [%s]',
+ task.__class__.__name__, task.id, self.name)
+ #self._sensors.task_destruction_sensor.record()
+
+ def _create_standby_task(self, task_id, partitions):
+ #self._sensors.task_creation_sensor.record()
+ raise NotImplementedError('no standby tasks yet')
+
+ topology = self.builder.build(self.config['application_id'],
+ task_id.topic_group_id)
+
+ """
+ if topology.state_store_suppliers():
+ return StandbyTask(task_id, partitions, topology,
+ self.consumer, self.restore_consumer,
+ **self.config) # self._sensors
+ else:
+ return None
+ """
+
+ def _add_standby_tasks(self):
+ if self.partition_assignor is None:
+ raise Errors.IllegalStateError(
+ 'Partition assignor has not been initialized while adding'
+ ' standby tasks: this should not happen.')
+
+ checkpointed_offsets = {}
+
+ # create the standby tasks
+ for task_id, partitions in self.partition_assignor.standby_tasks().items():
+ task = self._create_standby_task(task_id, partitions)
+ if task:
+ self._standby_tasks[task_id] = task
+ for partition in partitions:
+ self._standby_tasks_by_partition[partition] = task
+
+ # collect checkpointed offsets to position the restore consumer
+ # this includes all partitions from which we restore states
+ for partition in task.checkpointed_offsets():
+ self._standby_tasks_by_partition[partition] = task
+
+ checkpointed_offsets.update(task.checkpointed_offsets())
+
+ self.restore_consumer.assign(checkpointed_offsets.keys())
+
+ for partition, offset in checkpointed_offsets.items():
+ if offset >= 0:
+ self.restore_consumer.seek(partition, offset)
+ else:
+ self.restore_consumer.seek_to_beginning(partition)
+
+ def _remove_standby_tasks(self):
+ try:
+ for task in self._standby_tasks.values():
+ self._close_one(task)
+ self._standby_tasks.clear()
+ self._standby_tasks_by_partition.clear()
+ self._standby_records.clear()
+
+ # un-assign the change log partitions
+ self.restore_consumer.assign([])
+
+ except Exception:
+ log.exception('Failed to remove standby tasks in thread [%s]', self.name)
+
+"""
+ private class StreamsMetricsImpl implements StreamsMetrics {
+ final Metrics metrics;
+ final String metricGrpName;
+ final Map<String, String> metricTags;
+
+ final Sensor commitTimeSensor;
+ final Sensor pollTimeSensor;
+ final Sensor processTimeSensor;
+ final Sensor punctuateTimeSensor;
+ final Sensor taskCreationSensor;
+ final Sensor taskDestructionSensor;
+
+ public StreamsMetricsImpl(Metrics metrics) {
+
+ this.metrics = metrics;
+ this.metricGrpName = "stream-metrics";
+ this.metricTags = new LinkedHashMap<>();
+ this.metricTags.put("client-id", clientId + "-" + getName());
+
+ this.commitTimeSensor = metrics.sensor("commit-time");
+ this.commitTimeSensor.add(metrics.metricName("commit-time-avg", metricGrpName, "The average commit time in ms", metricTags), new Avg());
+ this.commitTimeSensor.add(metrics.metricName("commit-time-max", metricGrpName, "The maximum commit time in ms", metricTags), new Max());
+ this.commitTimeSensor.add(metrics.metricName("commit-calls-rate", metricGrpName, "The average per-second number of commit calls", metricTags), new Rate(new Count()));
+
+ this.pollTimeSensor = metrics.sensor("poll-time");
+ this.pollTimeSensor.add(metrics.metricName("poll-time-avg", metricGrpName, "The average poll time in ms", metricTags), new Avg());
+ this.pollTimeSensor.add(metrics.metricName("poll-time-max", metricGrpName, "The maximum poll time in ms", metricTags), new Max());
+ this.pollTimeSensor.add(metrics.metricName("poll-calls-rate", metricGrpName, "The average per-second number of record-poll calls", metricTags), new Rate(new Count()));
+
+ this.processTimeSensor = metrics.sensor("process-time");
+ this.processTimeSensor.add(metrics.metricName("process-time-avg-ms", metricGrpName, "The average process time in ms", metricTags), new Avg());
+ this.processTimeSensor.add(metrics.metricName("process-time-max-ms", metricGrpName, "The maximum process time in ms", metricTags), new Max());
+ this.processTimeSensor.add(metrics.metricName("process-calls-rate", metricGrpName, "The average per-second number of process calls", metricTags), new Rate(new Count()));
+
+ this.punctuateTimeSensor = metrics.sensor("punctuate-time");
+ this.punctuateTimeSensor.add(metrics.metricName("punctuate-time-avg", metricGrpName, "The average punctuate time in ms", metricTags), new Avg());
+ this.punctuateTimeSensor.add(metrics.metricName("punctuate-time-max", metricGrpName, "The maximum punctuate time in ms", metricTags), new Max());
+ this.punctuateTimeSensor.add(metrics.metricName("punctuate-calls-rate", metricGrpName, "The average per-second number of punctuate calls", metricTags), new Rate(new Count()));
+
+ this.taskCreationSensor = metrics.sensor("task-creation");
+ this.taskCreationSensor.add(metrics.metricName("task-creation-rate", metricGrpName, "The average per-second number of newly created tasks", metricTags), new Rate(new Count()));
+
+ this.taskDestructionSensor = metrics.sensor("task-destruction");
+ this.taskDestructionSensor.add(metrics.metricName("task-destruction-rate", metricGrpName, "The average per-second number of destructed tasks", metricTags), new Rate(new Count()));
+ }
+
+ @Override
+ public void recordLatency(Sensor sensor, long startNs, long endNs) {
+ sensor.record((endNs - startNs) / 1000000, endNs);
+ }
+
+ /**
+ * @throws IllegalArgumentException if tags is not constructed in key-value pairs
+ */
+ @Override
+ public Sensor addLatencySensor(String scopeName, String entityName, String operationName, String... tags) {
+ // extract the additional tags if there are any
+ Map<String, String> tagMap = new HashMap<>(this.metricTags);
+ if ((tags.length % 2) != 0)
+ throw new IllegalArgumentException("Tags needs to be specified in key-value pairs");
+
+ for (int i = 0; i < tags.length; i += 2)
+ tagMap.put(tags[i], tags[i + 1]);
+
+ String metricGroupName = "stream-" + scopeName + "-metrics";
+
+ // first add the global operation metrics if not yet, with the global tags only
+ Sensor parent = metrics.sensor(scopeName + "-" + operationName);
+ addLatencyMetrics(metricGroupName, parent, "all", operationName, this.metricTags);
+
+ // add the store operation metrics with additional tags
+ Sensor sensor = metrics.sensor(scopeName + "-" + entityName + "-" + operationName, parent);
+ addLatencyMetrics(metricGroupName, sensor, entityName, operationName, tagMap);
+
+ return sensor;
+ }
+
+ private void addLatencyMetrics(String metricGrpName, Sensor sensor, String entityName, String opName, Map<String, String> tags) {
+ maybeAddMetric(sensor, metrics.metricName(entityName + "-" + opName + "-avg-latency-ms", metricGrpName,
+ "The average latency in milliseconds of " + entityName + " " + opName + " operation.", tags), new Avg());
+ maybeAddMetric(sensor, metrics.metricName(entityName + "-" + opName + "-max-latency-ms", metricGrpName,
+ "The max latency in milliseconds of " + entityName + " " + opName + " operation.", tags), new Max());
+ maybeAddMetric(sensor, metrics.metricName(entityName + "-" + opName + "-qps", metricGrpName,
+ "The average number of occurrence of " + entityName + " " + opName + " operation per second.", tags), new Rate(new Count()));
+ }
+
+ private void maybeAddMetric(Sensor sensor, MetricName name, MeasurableStat stat) {
+ if (!metrics.metrics().containsKey(name))
+ sensor.add(name, stat);
+ }
+ }
+}
+ """
+
+
+class KStreamsConsumerRebalanceListener(ConsumerRebalanceListener):
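+    """Forwards consumer rebalance callbacks to the owning StreamThread."""
+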
+ def __init__(self, stream_thread):
+ self.stream_thread = stream_thread
+
+ def on_partitions_assigned(self, assignment):
+ try:
+ self.stream_thread._add_stream_tasks(assignment)
+ self.stream_thread._add_standby_tasks()
+ # start the cleaning cycle
+ self.stream_thread._last_clean_ms = time.time() * 1000
+ except Exception as e:
+ self.stream_thread._rebalance_exception = e
+ raise
+
+ def on_partitions_revoked(self, assignment):
+ try:
+ self.stream_thread._commit_all()
+ # stop the cleaning cycle until partitions are assigned
+ self.stream_thread._last_clean_ms = float('inf')
+ except Exception as e:
+ self.stream_thread._rebalance_exception = e
+ raise
+ finally:
+ # TODO: right now upon partition revocation, we always remove all
+ # the tasks; this behavior can be optimized to only remove affected
+ # tasks in the future
+ self.stream_thread._remove_stream_tasks()
+ self.stream_thread._remove_standby_tasks()
diff --git a/kafka/streams/processor/task.py b/kafka/streams/processor/task.py
new file mode 100644
index 0000000..67a276b
--- /dev/null
+++ b/kafka/streams/processor/task.py
@@ -0,0 +1,333 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import abc
+import logging
+import threading
+
+from kafka.consumer.fetcher import ConsumerRecord
+import kafka.errors as Errors
+from kafka.streams.errors import ProcessorStateError
+from kafka.structs import OffsetAndMetadata
+from .context import ProcessorContext
+from .partition_group import PartitionGroup, RecordInfo
+from .punctuation import PunctuationQueue
+from .record_collector import RecordCollector
+from .record_queue import RecordQueue
+
+log = logging.getLogger(__name__)
+
+NONEXIST_TOPIC = '__null_topic__'
+DUMMY_RECORD = ConsumerRecord(NONEXIST_TOPIC, -1, -1, -1, -1, None, None, -1, -1, -1)
+
+
+class AbstractTask(object):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, task_id, partitions, topology,
+ consumer, restore_consumer, is_standby, **config):
+ """Raises ProcessorStateError if the state manager cannot be created"""
+        self.config = config
+        self.id = task_id
+        self.application_id = config['application_id']
+ self.partitions = set(partitions)
+ self.topology = topology
+ self.consumer = consumer
+ self.processor_context = None
+
+ # create the processor state manager
+ """
+ try:
+            application_state_dir = StreamThread.make_state_dir(self.application_id, config['state_dir'])
+            state_file = os.path.join(application_state_dir, str(task_id))
+            # if partitions is None, this is a standby task
+            # (sketch only: the exact ProcessorStateManager signature is still TBD)
+            self.state_mgr = ProcessorStateManager(self.application_id, task_id.partition, partitions,
+                                                   state_file, restore_consumer, is_standby)
+ except Exception as e:
+ raise ProcessorStateError('Error while creating the state manager', e)
+ """
+
+ def initialize_state_stores(self):
+ # set initial offset limits
+ self.initialize_offset_limits()
+
+ for state_store_supplier in self.topology.state_store_suppliers():
+ store = state_store_supplier.get()
+ store.init(self.processor_context, store)
+
+ @abc.abstractmethod
+ def commit(self):
+ pass
+
+ def close(self):
+ """Close the task.
+
+ Raises ProcessorStateError if there is an error while closing the state manager
+ """
+ try:
+ self.state_mgr.close(self.record_collector_offsets())
+ except Exception as e:
+ raise ProcessorStateError('Error while closing the state manager', e)
+
+ def record_collector_offsets(self):
+ return {}
+
+ def initialize_offset_limits(self):
+ for partition in self.partitions:
+ metadata = self.consumer.committed(partition) # TODO: batch API?
+ self.state_mgr.put_offset_limit(partition, metadata.offset if metadata else 0)
+
+
+class StreamTask(AbstractTask):
+ """A StreamTask is associated with a PartitionGroup,
+ and is assigned to a StreamThread for processing."""
+
+ def __init__(self, task_id, partitions, topology, consumer, producer, restore_consumer, **config):
+ """Create StreamTask with its assigned partitions
+
+ Arguments:
+ task_id (str): the ID of this task
+ partitions (list of TopicPartition): the assigned partitions
+ topology (ProcessorTopology): the instance of ProcessorTopology
+ consumer (Consumer): the instance of Consumer
+ producer (Producer): the instance of Producer
+ restore_consumer (Consumer): the instance of Consumer used when
+ restoring state
+ """
+        super(StreamTask, self).__init__(task_id, partitions, topology,
+                                         consumer, restore_consumer, False, **config)
+        self.producer = producer
+        self._punctuation_queue = PunctuationQueue()
+        self._record_info = RecordInfo()
+
+ self.max_buffered_size = config['buffered_records_per_partition']
+ self._process_lock = threading.Lock()
+
+ self._commit_requested = False
+ self._commit_offset_needed = False
+ self._curr_record = None
+ self._curr_node = None
+ self.requires_poll = True
+
+ # create queues for each assigned partition and associate them
+ # to corresponding source nodes in the processor topology
+ partition_queues = {}
+
+ for partition in partitions:
+            source = self.topology.source(partition.topic)
+ queue = self._create_record_queue(partition, source)
+ partition_queues[partition] = queue
+
+ self.partition_group = PartitionGroup(partition_queues, self.config['timestamp_extractor_class'])
+
+ # initialize the consumed offset cache
+ self.consumed_offsets = {}
+
+ # create the RecordCollector that maintains the produced offsets
+ self.record_collector = RecordCollector(self.producer)
+
+ log.info('Creating restoration consumer client for stream task #%s', self.id)
+
+ # initialize the topology with its own context
+ self.processor_context = ProcessorContext(self.id, self, self.record_collector, self.state_mgr, **config)
+
+ # initialize the state stores
+ self.initialize_state_stores()
+
+ # initialize the task by initializing all its processor nodes in the topology
+ for node in self.topology.processors():
+ self._curr_node = node
+ try:
+ node.init(self.processor_context)
+ finally:
+ self._curr_node = None
+
+ self.processor_context.initialized()
+
+
+ def add_records(self, partition, records):
+ """Adds records to queues"""
+ queue_size = self.partition_group.add_raw_records(partition, records)
+
+ # if after adding these records, its partition queue's buffered size has
+ # been increased beyond the threshold, we can then pause the consumption
+ # for this partition
+ if queue_size > self.max_buffered_size:
+ self.consumer.pause(partition)
+
+ def process(self):
+ """Process one record
+
+ Returns:
+ number of records left in the buffer of this task's partition group after the processing is done
+ """
+ with self._process_lock:
+ # get the next record to process
+ record = self.partition_group.next_record(self._record_info)
+
+ # if there is no record to process, return immediately
+ if record is None:
+ self.requires_poll = True
+ return 0
+
+ self.requires_poll = False
+
+ try:
+ # process the record by passing to the source node of the topology
+ self._curr_record = record
+ self._curr_node = self._record_info.node()
+ partition = self._record_info.partition()
+
+ log.debug('Start processing one record [%s]', self._curr_record)
+
+ self._curr_node.process(self._curr_record.key, self._curr_record.value)
+
+ log.debug('Completed processing one record [%s]', self._curr_record)
+
+ # update the consumed offset map after processing is done
+ self.consumed_offsets[partition] = self._curr_record.offset
+ self._commit_offset_needed = True
+
+ # after processing this record, if its partition queue's
+ # buffered size has been decreased to the threshold, we can then
+ # resume the consumption on this partition
+ if self._record_info.queue().size() == self.max_buffered_size:
+ self.consumer.resume(partition)
+ self.requires_poll = True
+
+ if self.partition_group.top_queue_size() <= self.max_buffered_size:
+ self.requires_poll = True
+
+ finally:
+ self._curr_record = None
+ self._curr_node = None
+
+ return self.partition_group.num_buffered()
+
+ def maybe_punctuate(self):
+ """Possibly trigger registered punctuation functions if
+ current partition group timestamp has reached the defined stamp
+ """
+ timestamp = self.partition_group.timestamp()
+
+        # if the timestamp is not known yet, meaning there is not enough data
+        # accumulated to reason about the stream partition time, then skip.
+ if timestamp == -1:
+ return False
+ else:
+ return self._punctuation_queue.may_punctuate(timestamp, self)
+
+ def punctuate(self, node, timestamp):
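+        """Invoke the punctuate callback of the given node at the given timestamp."""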
+ if self._curr_node is not None:
+            raise Errors.IllegalStateError('Current node is not None')
+
+ self._curr_node = node
+ self._curr_record = (timestamp, DUMMY_RECORD)
+
+ try:
+ node.processor().punctuate(timestamp)
+ finally:
+ self._curr_node = None
+ self._curr_record = None
+
+ def record(self):
+ return self._curr_record
+
+ def node(self):
+ return self._curr_node
+
+ def commit(self):
+ """Commit the current task state"""
+ # 1) flush local state
+ self.state_mgr.flush()
+
+ # 2) flush produced records in the downstream and change logs of local states
+ self.record_collector.flush()
+
+ # 3) commit consumed offsets if it is dirty already
+ if self._commit_offset_needed:
+ consumed_offsets_and_metadata = {}
+ for partition, offset in self.consumed_offsets.items():
+                consumed_offsets_and_metadata[partition] = OffsetAndMetadata(offset + 1, '')
+ self.state_mgr.put_offset_limit(partition, offset + 1)
+ self.consumer.commit_sync(consumed_offsets_and_metadata)
+ self._commit_offset_needed = False
+
+ self._commit_requested = False
+
+ def commit_needed(self):
+ """Whether or not a request has been made to commit the current state"""
+ return self._commit_requested
+
+ def need_commit(self):
+ """Request committing the current task's state"""
+ self._commit_requested = True
+
+ def schedule(self, interval_ms):
+ """Schedules a punctuation for the processor
+
+ Arguments:
+ interval_ms (int): the interval in milliseconds
+
+        Raises: IllegalStateError if there is no current node
+        """
+        if self._curr_node is None:
+            raise Errors.IllegalStateError('Current node is None')
+
+ schedule = (0, self._curr_node, interval_ms)
+ self._punctuation_queue.schedule(schedule)
+
+ def close(self):
+ self.partition_group.close()
+ self.consumed_offsets.clear()
+
+ # close the processors
+ # make sure close() is called for each node even when there is a RuntimeException
+ exception = None
+ for node in self.topology.processors():
+ self._curr_node = node
+ try:
+ node.close()
+ except RuntimeError as e:
+ exception = e
+ finally:
+ self._curr_node = None
+
+ super(StreamTask, self).close()
+
+ if exception is not None:
+ raise exception
+
+ def record_collector_offsets(self):
+ return self.record_collector.offsets()
+
+ def _create_record_queue(self, partition, source):
+ return RecordQueue(partition, source)
+
+ def forward(self, key, value, child_index=None, child_name=None):
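+        """Forward a key/value pair to child nodes of the current node.
+
+        By default the record is forwarded to all children; child_index or
+        child_name restricts forwarding to a single child.
+        """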
+ this_node = self._curr_node
+ try:
+ children = this_node.children()
+
+ if child_index is not None:
+ children = [children[child_index]]
+ elif child_name is not None:
+ children = [child for child in children if child.name == child_name]
+
+ for child_node in children:
+ self._curr_node = child_node
+ child_node.process(key, value)
+ finally:
+ self._curr_node = this_node
diff --git a/kafka/streams/processor/topology.py b/kafka/streams/processor/topology.py
new file mode 100644
index 0000000..a2fa441
--- /dev/null
+++ b/kafka/streams/processor/topology.py
@@ -0,0 +1,21 @@
+class ProcessorTopology(object):
+
+    def __init__(self, processor_nodes, source_by_topics, state_store_suppliers):
+        self._processor_nodes = processor_nodes
+        self._source_by_topics = source_by_topics
+        self._state_store_suppliers = list(state_store_suppliers)
+
+    def source_topics(self):
+        return set(self._source_by_topics)
+
+    def source(self, topic):
+        return self._source_by_topics.get(topic)
+
+    def sources(self):
+        return set(self._source_by_topics.values())
+
+    def processors(self):
+        return self._processor_nodes
+
+    def state_store_suppliers(self):
+        return self._state_store_suppliers
diff --git a/kafka/streams/processor/topology_builder.py b/kafka/streams/processor/topology_builder.py
new file mode 100644
index 0000000..223cda9
--- /dev/null
+++ b/kafka/streams/processor/topology_builder.py
@@ -0,0 +1,642 @@
+"""
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+"""
+from __future__ import absolute_import
+
+import abc
+import re
+
+import kafka.streams.errors as Errors
+from .node import ProcessorNode, SourceNode, SinkNode
+from .processor_state_manager import STATE_CHANGELOG_TOPIC_SUFFIX
+from .quick_union import QuickUnion
+from .topology import ProcessorTopology
+
+
+class StateStoreFactory(object):
+ def __init__(self, is_internal, supplier):
+ self.users = set()
+ self.is_internal = is_internal
+ self.supplier = supplier
+
+class NodeFactory(object):
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, builder, name):
+ self.builder = builder
+ self.name = name
+
+ @abc.abstractmethod
+ def build(self, application_id):
+ pass
+
+class ProcessorNodeFactory(NodeFactory):
+ def __init__(self, builder, name, parents, supplier):
+ self.builder = builder
+ self.name = name
+ self.parents = list(parents)
+ self.supplier = supplier
+ self.state_store_names = set()
+
+ def add_state_store(self, state_store_name):
+ self.state_store_names.add(state_store_name)
+
+ def build(self, application_id):
+ return ProcessorNode(self.name, self.supplier(), self.state_store_names)
+
+class SourceNodeFactory(NodeFactory):
+ def __init__(self, builder, name, topics, pattern, key_deserializer, val_deserializer):
+ self.builder = builder
+ self.name = name
+ self.topics = list(topics) if topics else None
+ self.pattern = pattern
+ self.key_deserializer = key_deserializer
+ self.val_deserializer = val_deserializer
+
+    def get_topics(self, subscribed_topics=None):
+        # TODO: pattern subscriptions are not handled yet; the
+        # subscribed_topics argument is accepted but ignored for now
+        # (see the quoted sketch below).
+        return self.topics
+
+ """
+ def get_topics(self, subscribed_topics=None):
+ if not subscribed_topics:
+ return self.topics
+ matched_topics = []
+ for update in subscribed_topics:
+ if self.pattern == topicToPatterns.get(update)) {
+ matchedTopics.add(update);
+ //not same pattern instance,but still matches not allowed
+ } else if (topicToPatterns.containsKey(update) && isMatch(update)) {
+ throw new TopologyBuilderException("Topic " + update + " already matched check for overlapping regex patterns");
+ } else if (isMatch(update)) {
+ topicToPatterns.put(update, this.pattern);
+ matchedTopics.add(update);
+ }
+ }
+ return matchedTopics.toArray(new String[matchedTopics.size()]);
+ }
+ """
+
+ def build(self, application_id):
+ return SourceNode(self.name, self.key_deserializer, self.val_deserializer)
+
+ """
+ private boolean isMatch(String topic) {
+ return this.pattern.matcher(topic).matches();
+ """
+
+class SinkNodeFactory(NodeFactory):
+ def __init__(self, builder, name, parents, topic, key_serializer, val_serializer, partitioner):
+ self.builder = builder
+ self.name = name
+ self.parents = list(parents)
+ self.topic = topic
+ self.key_serializer = key_serializer
+ self.val_serializer = val_serializer
+ self.partitioner = partitioner
+
+ def build(self, application_id):
+ if self.topic in self.builder.internal_topics:
+ sink_name = application_id + '-' + self.topic
+ else:
+ sink_name = self.topic
+ return SinkNode(self.name, sink_name, self.key_serializer, self.val_serializer, self.partitioner)
+
+class TopicsInfo(object):
+ def __init__(self, builder, sink_topics, source_topics, inter_source_topics, state_changelog_topics):
+ self.sink_topics = set(sink_topics)
+ self.source_topics = set(source_topics)
+ self.inter_source_topics = set(inter_source_topics)
+ self.state_changelog_topics = set(state_changelog_topics)
+
+ def __eq__(self, other):
+ if isinstance(other, TopicsInfo):
+ return (other.source_topics == self.source_topics and
+ other.state_changelog_topics == self.state_changelog_topics)
+ else:
+ return False
+
+ """
+ @Override
+ public int hashCode() {
+ long n = ((long) sourceTopics.hashCode() << 32) | (long) stateChangelogTopics.hashCode();
+ return (int) (n % 0xFFFFFFFFL);
+ """
+
+
+class TopologyBuilder(object):
+ """TopologyBuilder is used to build a ProcessorTopology.
+
+ A topology contains an acyclic graph of sources, processors, and sinks.
+
+ A source is a node in the graph that consumes one or more Kafka topics
+ and forwards them to its child nodes.
+
+    A processor is a node in the graph that receives input records from
+    upstream nodes, processes the records, and optionally forwards new
+    records to one or all of its children.
+
+ A sink is a node in the graph that receives records from upstream nodes
+ and writes them to a Kafka topic.
+
+ This builder allows you to construct an acyclic graph of these nodes,
+ and the builder is then passed into a new KafkaStreams instance that will
+ then begin consuming, processing, and producing records.
+ """
+
+ def __init__(self):
+ """Create a new builder."""
+ # node factories in a topological order
+ self.node_factories = {}
+
+ # state factories
+ self.state_factories = {}
+
+ self.source_topic_names = set()
+ self.internal_topic_names = set()
+ self.node_grouper = QuickUnion()
+ self.copartition_source_groups = []
+ self.node_to_source_topics = {}
+ self.node_to_source_patterns = {}
+ self.topic_to_patterns = {}
+ self.node_to_sink_topic = {}
+ self.subscription_updates = set()
+ self.application_id = None
+
+ self._node_groups = None
+ self.topic_pattern = None
+
+ def add_source(self, name, *topics, **kwargs):
+ """Add a named source node that consumes records from kafka.
+
+ Source consumes named topics or topics that match a pattern and
+ forwards the records to child processor and/or sink nodes.
+ The source will use the specified key and value deserializers.
+
+ Arguments:
+ name (str): unique name of the source used to reference this node
+ when adding processor children
+ topics (*str): one or more Kafka topics to consume with this source
+
+ Keyword Arguments:
+ topic_pattern (str): pattern to match source topics
+ key_deserializer (callable): used when consuming records, if None,
+ uses the default key deserializer specified in the stream
+ configuration
+ val_deserializer (callable): the value deserializer used when
+ consuming records; if None, uses the default value
+ deserializer specified in the stream configuration.
+
+ Raises: TopologyBuilderError if processor is already added or if topics
+ have already been registered by another source
+
+ Returns: self, so methods can be chained together
+ """
+ topic_pattern = kwargs.get('topic_pattern', None)
+ key_deserializer = kwargs.get('key_deserializer', None)
+ val_deserializer = kwargs.get('val_deserializer', None)
+ if name in self.node_factories:
+ raise Errors.TopologyBuilderError("Processor " + name + " is already added.")
+
+ if topic_pattern:
+ if topics:
+ raise Errors.TopologyBuilderError('Cannot supply both topics and a topic_pattern')
+
+ for source_topic_name in self.source_topic_names:
+ if re.match(topic_pattern, source_topic_name):
+ raise Errors.TopologyBuilderError("Pattern " + topic_pattern + " will match a topic that has already been registered by another source.")
+
+ self.node_to_source_patterns[name] = topic_pattern
+ self.node_factories[name] = SourceNodeFactory(self, name, None, topic_pattern, key_deserializer, val_deserializer)
+ self.node_grouper.add(name)
+
+ return self
+
+ for topic in topics:
+ if topic in self.source_topic_names:
+ raise Errors.TopologyBuilderError("Topic " + topic + " has already been registered by another source.")
+
+ for pattern in self.node_to_source_patterns.values():
+ if re.match(pattern, topic):
+ raise Errors.TopologyBuilderError("Topic " + topic + " matches a Pattern already registered by another source.")
+
+ self.source_topic_names.add(topic)
+
+ self.node_factories[name] = SourceNodeFactory(self, name, topics, None, key_deserializer, val_deserializer)
+ self.node_to_source_topics[name] = list(topics)
+ self.node_grouper.add(name)
+
+ return self
+
+ def add_sink(self, name, topic, *parent_names, **kwargs):
+ """Add a named sink node that writes records to a named kafka topic.
+
+ The sink node forwards records from upstream parent processor and/or
+ source nodes to the named Kafka topic. The sink will use the specified
+ key and value serializers, and the supplied partitioner.
+
+        Arguments:
+            name (str): unique name of the sink node
+            topic (str): name of the output topic for the sink
+            parent_names (*str): one or more source or processor nodes whose
+                output records should be consumed by this sink and written to
+                the output topic
+
+        Keyword Arguments:
+            key_serializer (callable): the key serializer used when writing
+                records to the topic; if None, uses the default key serializer
+                specified in the stream configuration.
+            val_serializer (callable): the value serializer used when writing
+                records to the topic; if None, uses the default value
+                serializer specified in the stream configuration.
+ partitioner (callable): function used to determine the partition
+ for each record processed by the sink
+
+ Raises: TopologyBuilderError if parent processor is not added yet, or
+ if this processor's name is equal to the parent's name
+
+ Returns: self, so methods can be chained together
+ """
+ key_serializer = kwargs.get('key_serializer', None)
+ val_serializer = kwargs.get('val_serializer', None)
+ partitioner = kwargs.get('partitioner', None)
+ if name in self.node_factories:
+ raise Errors.TopologyBuilderError("Processor " + name + " is already added.")
+
+ for parent in parent_names:
+ if parent == name:
+ raise Errors.TopologyBuilderError("Processor " + name + " cannot be a parent of itself.")
+ if parent not in self.node_factories:
+ raise Errors.TopologyBuilderError("Parent processor " + parent + " is not added yet.")
+
+ self.node_factories[name] = SinkNodeFactory(self, name, parent_names, topic, key_serializer, val_serializer, partitioner)
+ self.node_to_sink_topic[name] = topic
+ self.node_grouper.add(name)
+ self.node_grouper.unite(name, parent_names)
+ return self
+
+ def add_processor(self, name, supplier, *parent_names):
+ """Add a node to process consumed messages from parent nodes.
+
+ A processor node receives and processes records output by one or more
+ parent source or processor nodes. Any new record output by this
+ processor will be forwarded to its child processor or sink nodes.
+
+ Arguments:
+ name (str): unique name of the processor node
+ supplier (callable): factory function that returns a Processor
+ parent_names (*str): the name of one or more source or processor
+ nodes whose output records this processor should receive
+ and process
+
+ Returns: self (so methods can be chained together)
+
+ Raises: TopologyBuilderError if parent processor is not added yet,
+ or if this processor's name is equal to the parent's name
+ """
+ if name in self.node_factories:
+ raise Errors.TopologyBuilderError("Processor " + name + " is already added.")
+
+ for parent in parent_names:
+ if parent == name:
+ raise Errors.TopologyBuilderError("Processor " + name + " cannot be a parent of itself.")
+ if not parent in self.node_factories:
+ raise Errors.TopologyBuilderError("Parent processor " + parent + " is not added yet.")
+
+ self.node_factories[name] = ProcessorNodeFactory(self, name, parent_names, supplier)
+ self.node_grouper.add(name)
+ self.node_grouper.unite(name, parent_names)
+ return self
+
+ def add_state_store(self, supplier, *processor_names, **kwargs):
+ """Adds a state store
+
+ @param supplier the supplier used to obtain this state store {@link StateStore} instance
+ @return this builder instance so methods can be chained together; never null
+ @throws TopologyBuilderException if state store supplier is already added
+ """
+ is_internal = kwargs.get('is_internal', True)
+ if supplier.name in self.state_factories:
+ raise Errors.TopologyBuilderError("StateStore " + supplier.name + " is already added.")
+
+ self.state_factories[supplier.name] = StateStoreFactory(is_internal, supplier)
+
+ for processor_name in processor_names:
+ self.connect_processor_and_state_store(processor_name, supplier.name)
+
+ return self
+
+ def connect_processor_and_state_stores(self, processor_name, *state_store_names):
+ """
+ Connects the processor and the state stores
+
+ @param processorName the name of the processor
+ @param stateStoreNames the names of state stores that the processor uses
+ @return this builder instance so methods can be chained together; never null
+ """
+ for state_store_name in state_store_names:
+ self.connect_processor_and_state_store(processor_name, state_store_name)
+
+ return self
+
+ def connect_processors(self, *processor_names):
+ """
+ Connects a list of processors.
+
+ NOTE this function would not needed by developers working with the processor APIs, but only used
+ for the high-level DSL parsing functionalities.
+
+ @param processorNames the name of the processors
+ @return this builder instance so methods can be chained together; never null
+ @throws TopologyBuilderException if less than two processors are specified, or if one of the processors is not added yet
+ """
+ if len(processor_names) < 2:
+ raise Errors.TopologyBuilderError("At least two processors need to participate in the connection.")
+
+ for processor_name in processor_names:
+ if processor_name not in self.node_factories:
+ raise Errors.TopologyBuilderError("Processor " + processor_name + " is not added yet.")
+
+ self.node_grouper.unite(processor_names[0], processor_names[1:])
+
+ return self
+
+ def add_internal_topic(self, topic_name):
+ """
+ Adds an internal topic
+
+ @param topicName the name of the topic
+ @return this builder instance so methods can be chained together; never null
+ """
+ self.internal_topic_names.add(topic_name)
+
+ return self
+
+ def connect_processor_and_state_store(self, processor_name, state_store_name):
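+        """Connect a single processor to a single state store, grouping it
+        with any other processors that already use that store."""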
+ if state_store_name not in self.state_factories:
+ raise Errors.TopologyBuilderError("StateStore " + state_store_name + " is not added yet.")
+ if processor_name not in self.node_factories:
+ raise Errors.TopologyBuilderError("Processor " + processor_name + " is not added yet.")
+
+ state_store_factory = self.state_factories.get(state_store_name)
+ for user in state_store_factory.users:
+ self.node_grouper.unite(user, [processor_name])
+ state_store_factory.users.add(processor_name)
+
+ node_factory = self.node_factories.get(processor_name)
+ if isinstance(node_factory, ProcessorNodeFactory):
+ node_factory.add_state_store(state_store_name)
+ else:
+ raise Errors.TopologyBuilderError("cannot connect a state store " + state_store_name + " to a source node or a sink node.")
+
+ def topic_groups(self):
+ """
+ Returns the map of topic groups keyed by the group id.
+ A topic group is a group of topics in the same task.
+
+ @return groups of topic names
+ """
+ topic_groups = {}
+
+ if self.subscription_updates:
+ for name in self.node_to_source_patterns:
+ source_node = self.node_factories[name]
+ # need to update node_to_source_topics with topics matched from given regex
+ self.node_to_source_topics[name] = source_node.get_topics(self.subscription_updates)
+
+ if self._node_groups is None:
+ self._node_groups = self.make_node_groups()
+
+ for group_id, nodes in self._node_groups.items():
+ sink_topics = set()
+ source_topics = set()
+ internal_source_topics = set()
+ state_changelog_topics = set()
+ for node in nodes:
+ # if the node is a source node, add to the source topics
+ topics = self.node_to_source_topics.get(node)
+ if topics:
+ # if some of the topics are internal, add them to the internal topics
+ for topic in topics:
+ if topic in self.internal_topic_names:
+ if self.application_id is None:
+ raise Errors.TopologyBuilderError("There are internal topics and"
+ " applicationId hasn't been "
+ "set. Call setApplicationId "
+ "first")
+ # prefix the internal topic name with the application id
+ internal_topic = self.application_id + "-" + topic
+ internal_source_topics.add(internal_topic)
+ source_topics.add(internal_topic)
+ else:
+ source_topics.add(topic)
+
+ # if the node is a sink node, add to the sink topics
+ topic = self.node_to_sink_topic.get(node)
+ if topic:
+ if topic in self.internal_topic_names:
+ # prefix the change log topic name with the application id
+ sink_topics.add(self.application_id + "-" + topic)
+ else:
+ sink_topics.add(topic)
+
+ # if the node is connected to a state, add to the state topics
+ for state_factory in self.state_factories.values():
+ if state_factory.is_internal and node in state_factory.users:
+ # prefix the change log topic name with the application id
+ state_changelog_topics.add(self.application_id + "-" + state_factory.supplier.name + STATE_CHANGELOG_TOPIC_SUFFIX)
+
+ topic_groups[group_id] = TopicsInfo(
+ self,
+ sink_topics,
+ source_topics,
+ internal_source_topics,
+ state_changelog_topics)
+
+ return topic_groups
+
+ def node_groups(self):
+ """
+ Returns the map of node groups keyed by the topic group id.
+
+ @return groups of node names
+ """
+ if self._node_groups is None:
+ self._node_groups = self.make_node_groups()
+
+ return self._node_groups
+
+ def make_node_groups(self):
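+        """Group connected nodes into node groups, keyed by an integer group id."""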
+ node_groups = {}
+ root_to_node_group = {}
+
+ node_group_id = 0
+
+ # Go through source nodes first. This makes the group id assignment easy to predict in tests
+ for node_name in sorted(self.node_to_source_topics):
+ root = self.node_grouper.root(node_name)
+ node_group = root_to_node_group.get(root)
+ if node_group is None:
+ node_group = set()
+ root_to_node_group[root] = node_group
+ node_group_id += 1
+ node_groups[node_group_id] = node_group
+ node_group.add(node_name)
+
+ # Go through non-source nodes
+ for node_name in sorted(self.node_factories):
+ if node_name not in self.node_to_source_topics:
+ root = self.node_grouper.root(node_name)
+ node_group = root_to_node_group.get(root)
+ if node_group is None:
+ node_group = set()
+ root_to_node_group[root] = node_group
+ node_group_id += 1
+ node_groups[node_group_id] = node_group
+ node_group.add(node_name)
+
+ return node_groups
+
+ def copartition_sources(self, source_nodes):
+ """
+ Asserts that the streams of the specified source nodes must be copartitioned.
+
+ @param sourceNodes a set of source node names
+ @return this builder instance so methods can be chained together; never null
+ """
+ self.copartition_source_groups.append(source_nodes)
+ return self
+
+ def copartition_groups(self):
+ """
+ Returns the copartition groups.
+ A copartition group is a group of source topics that are required to be copartitioned.
+
+ @return groups of topic names
+ """
+ groups = []
+ for node_names in self.copartition_source_groups:
+ copartition_group = set()
+ for node in node_names:
+ topics = self.node_to_source_topics.get(node)
+ if topics:
+                    copartition_group.update(self.convert_internal_topic_names(*topics))
+ groups.append(copartition_group)
+ return groups
+
+ def convert_internal_topic_names(self, *topics):
+ topic_names = []
+ for topic in topics:
+ if topic in self.internal_topic_names:
+ if self.application_id is None:
+ raise Errors.TopologyBuilderError("there are internal topics "
+ "and applicationId hasn't been set. Call "
+ "setApplicationId first")
+ topic_names.append(self.application_id + "-" + topic)
+ else:
+ topic_names.append(topic)
+ return topic_names
+
+ def build(self, application_id, topic_group_id=None, node_group=None):
+ """
+ Build the topology for the specified topic group. This is called automatically when passing this builder into the
+ {@link org.apache.kafka.streams.KafkaStreams#KafkaStreams(TopologyBuilder, org.apache.kafka.streams.StreamsConfig)} constructor.
+
+ @see org.apache.kafka.streams.KafkaStreams#KafkaStreams(TopologyBuilder, org.apache.kafka.streams.StreamsConfig)
+ """
+ if topic_group_id is not None:
+ node_group = None
+ if topic_group_id is not None:
+ node_group = self.node_groups().get(topic_group_id)
+ else:
+ # when nodeGroup is null, we build the full topology. this is used in some tests.
+ node_group = None
+
+ processor_nodes = []
+ processor_map = {}
+ topic_source_map = {}
+ state_store_map = {}
+
+ # create processor nodes in a topological order ("nodeFactories" is already topologically sorted)
+ for factory in self.node_factories.values():
+ if node_group is None or factory.name in node_group:
+ node = factory.build(application_id)
+ processor_nodes.append(node)
+ processor_map[node.name] = node
+
+ if isinstance(factory, ProcessorNodeFactory):
+ for parent in factory.parents:
+ processor_map[parent].add_child(node)
+ for state_store_name in factory.state_store_names:
+ if state_store_name not in state_store_map:
+ state_store_map[state_store_name] = self.state_factories[state_store_name].supplier
+ elif isinstance(factory, SourceNodeFactory):
+ if factory.pattern is not None:
+ topics = factory.get_topics(self.subscription_updates)
+ else:
+ topics = factory.get_topics()
+ for topic in topics:
+ if topic in self.internal_topic_names:
+ # prefix the internal topic name with the application id
+ topic_source_map[application_id + "-" + topic] = node
+ else:
+ topic_source_map[topic] = node
+ elif isinstance(factory, SinkNodeFactory):
+ for parent in factory.parents:
+ processor_map[parent].add_child(node)
+ else:
+ raise Errors.TopologyBuilderError("Unknown definition class: " + factory.__class__.__name__)
+
+ return ProcessorTopology(processor_nodes, topic_source_map, state_store_map.values())
+
+ def source_topics(self):
+ """
+ Get the names of topics that are to be consumed by the source nodes created by this builder.
+ @return the unmodifiable set of topic names used by source nodes, which changes as new sources are added; never null
+ """
+ topics = set()
+ for topic in self.source_topic_names:
+ if topic in self.internal_topic_names:
+ if self.application_id is None:
+ raise Errors.TopologyBuilderError("there are internal topics and "
+ "applicationId is null. Call "
+ "setApplicationId before sourceTopics")
+ topics.add(self.application_id + "-" + topic)
+ else:
+ topics.add(topic)
+ return topics
+
+ def source_topic_pattern(self):
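+        """Lazily build a regex pattern matching all source topics and source patterns."""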
+ if self.topic_pattern is None and self.node_to_source_patterns:
+ topic_pattern = ''
+ for pattern in self.node_to_source_patterns.values():
+ topic_pattern += pattern
+ topic_pattern += "|"
+ if self.node_to_source_topics:
+ for topics in self.node_to_source_topics.values():
+ for topic in topics:
+ topic_pattern += topic
+ topic_pattern += "|"
+ self.topic_pattern = topic_pattern[:-1]
+        return self.topic_pattern
+
+ def set_application_id(self, application_id):
+ """
+ Set the applicationId. This is required before calling
+ {@link #sourceTopics}, {@link #topicGroups} and {@link #copartitionSources}
+ @param applicationId the streams applicationId. Should be the same as set by
+ {@link org.apache.kafka.streams.StreamsConfig#APPLICATION_ID_CONFIG}
+ """
+ self.application_id = application_id
diff --git a/kafka/streams/utils.py b/kafka/streams/utils.py
new file mode 100644
index 0000000..b0161dc
--- /dev/null
+++ b/kafka/streams/utils.py
@@ -0,0 +1,20 @@
+import threading
+
+
+class AtomicInteger(object):
+ def __init__(self, val=0):
+ self._lock = threading.Lock()
+ self._val = val
+
+ def increment(self):
+ with self._lock:
+ self._val += 1
+ return self._val
+
+ def decrement(self):
+ with self._lock:
+ self._val -= 1
+ return self._val
+
+ def get(self):
+ return self._val
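+
+
+# Illustrative usage sketch (the thread-naming scheme below is hypothetical):
+#
+#   counter = AtomicInteger()
+#   thread_name = 'StreamThread-%d' % counter.increment()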