aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kafka/python
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/kafka/python')
-rw-r--r--tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py115
-rw-r--r--tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh48
-rw-r--r--tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py74
3 files changed, 237 insertions, 0 deletions
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
new file mode 100644
index 0000000000..621911876f
--- /dev/null
+++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for KafkaDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class KafkaDatasetTest(test.TestCase):
+
+ def setUp(self):
+ # The Kafka server has to be setup before the test
+ # and tear down after the test manually.
+ # The docker engine has to be installed.
+ #
+ # To setup the Kafka server:
+ # $ bash kafka_test.sh start kafka
+ #
+ # To team down the Kafka server:
+ # $ bash kafka_test.sh stop kafka
+ pass
+
+ def testKafkaDataset(self):
+ topics = array_ops.placeholder(dtypes.string, shape=[None])
+ num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = kafka_dataset_ops.KafkaDataset(
+ topics, group="test", eof=True).repeat(num_epochs)
+ batch_dataset = repeat_dataset.batch(batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_op = iterator.make_initializer(repeat_dataset)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ # Basic test: read from topic 0.
+ sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
+ for i in range(5):
+ self.assertEqual("D" + str(i), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Basic test: read from topic 1.
+ sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1})
+ for i in range(5):
+ self.assertEqual("D" + str(i + 5), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Basic test: read from both topics.
+ sess.run(
+ init_op,
+ feed_dict={
+ topics: ["test:0:0:4", "test:0:5:-1"],
+ num_epochs: 1
+ })
+ for j in range(2):
+ for i in range(5):
+ self.assertEqual("D" + str(i + j * 5), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test repeated iteration through both files.
+ sess.run(
+ init_op,
+ feed_dict={
+ topics: ["test:0:0:4", "test:0:5:-1"],
+ num_epochs: 10
+ })
+ for _ in range(10):
+ for j in range(2):
+ for i in range(5):
+ self.assertEqual("D" + str(i + j * 5), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test batched and repeated iteration through both files.
+ sess.run(
+ init_batch_op,
+ feed_dict={
+ topics: ["test:0:0:4", "test:0:5:-1"],
+ num_epochs: 10,
+ batch_size: 5
+ })
+ for _ in range(10):
+ self.assertAllEqual(["D" + str(i) for i in range(5)],
+ sess.run(get_next))
+ self.assertAllEqual(["D" + str(i + 5) for i in range(5)],
+ sess.run(get_next))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh
new file mode 100644
index 0000000000..adf027b8e7
--- /dev/null
+++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+set -o pipefail
+
+if [ "$#" -ne 2 ]; then
+ echo "Usage: $0 start|stop <kafka container name>" >&2
+ exit 1
+fi
+
+container=$2
+if [ "$1" == "start" ]; then
+ docker run -d --rm --net=host --name=$container spotify/kafka
+ echo Wait 5 secs until kafka is up and running
+ sleep 5
+ echo Create test topic
+ docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic test'
+ echo Create test message
+ docker exec $container bash -c 'echo -e "D0\nD1\nD2\nD3\nD4\nD5\nD6\nD7\nD8\nD9" > /test'
+ echo Produce test message
+ docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-console-producer.sh --topic test --broker-list 127.0.0.1:9092 < /test'
+
+ echo Container $container started successfully
+elif [ "$1" == "stop" ]; then
+ docker rm -f $container
+
+ echo Container $container stopped successfully
+else
+ echo "Usage: $0 start|stop <kafka container name>" >&2
+ exit 1
+fi
+
+
+
diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
new file mode 100644
index 0000000000..8e51d27a34
--- /dev/null
+++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
@@ -0,0 +1,74 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kafka Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kafka.python.ops import gen_kafka_ops
+from tensorflow.python.data.ops.readers import Dataset
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class KafkaDataset(Dataset):
+ """A Kafka Dataset that consumes the message.
+ """
+
+ def __init__(self,
+ topics,
+ servers="localhost",
+ group="",
+ eof=False,
+ timeout=1000):
+ """Create a KafkaReader.
+
+ Args:
+ topics: A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset:length],
+ by default length is -1 for unlimited.
+ servers: A list of bootstrap servers.
+ group: The consumer group id.
+ eof: If True, the kafka reader will stop on EOF.
+ timeout: The timeout value for the Kafka Consumer to wait
+ (in millisecond).
+ """
+ super(KafkaDataset, self).__init__()
+ self._topics = ops.convert_to_tensor(
+ topics, dtype=dtypes.string, name="topics")
+ self._servers = ops.convert_to_tensor(
+ servers, dtype=dtypes.string, name="servers")
+ self._group = ops.convert_to_tensor(
+ group, dtype=dtypes.string, name="group")
+ self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name="eof")
+ self._timeout = ops.convert_to_tensor(
+ timeout, dtype=dtypes.int64, name="timeout")
+
+ def _as_variant_tensor(self):
+ return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group,
+ self._eof, self._timeout)
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.string