diff options
Diffstat (limited to 'tensorflow/contrib/kafka/python')
3 files changed, 237 insertions, 0 deletions
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py new file mode 100644 index 0000000000..621911876f --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py @@ -0,0 +1,115 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for KafkaDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class KafkaDatasetTest(test.TestCase): + + def setUp(self): + # The Kafka server has to be setup before the test + # and tear down after the test manually. + # The docker engine has to be installed. + # + # To setup the Kafka server: + # $ bash kafka_test.sh start kafka + # + # To team down the Kafka server: + # $ bash kafka_test.sh stop kafka + pass + + def testKafkaDataset(self): + topics = array_ops.placeholder(dtypes.string, shape=[None]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_dataset_ops.KafkaDataset( + topics, group="test", eof=True).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.test_session() as sess: + # Basic test: read from topic 0. + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from topic 1. + sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i + 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from both topics. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 1 + }) + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both files. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10 + }) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both files. + sess.run( + init_batch_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10, + batch_size: 5 + }) + for _ in range(10): + self.assertAllEqual(["D" + str(i) for i in range(5)], + sess.run(get_next)) + self.assertAllEqual(["D" + str(i + 5) for i in range(5)], + sess.run(get_next)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh new file mode 100644 index 0000000000..adf027b8e7 --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +set -o pipefail + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 start|stop <kafka container name>" >&2 + exit 1 +fi + +container=$2 +if [ "$1" == "start" ]; then + docker run -d --rm --net=host --name=$container spotify/kafka + echo Wait 5 secs until kafka is up and running + sleep 5 + echo Create test topic + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic test' + echo Create test message + docker exec $container bash -c 'echo -e "D0\nD1\nD2\nD3\nD4\nD5\nD6\nD7\nD8\nD9" > /test' + echo Produce test message + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-console-producer.sh --topic test --broker-list 127.0.0.1:9092 < /test' + + echo Container $container started successfully +elif [ "$1" == "stop" ]; then + docker rm -f $container + + echo Container $container stopped successfully +else + echo "Usage: $0 start|stop <kafka container name>" >&2 + exit 1 +fi + + + diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py new file mode 100644 index 0000000000..8e51d27a34 --- /dev/null +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kafka Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import gen_kafka_ops +from tensorflow.python.data.ops.readers import Dataset +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class KafkaDataset(Dataset): + """A Kafka Dataset that consumes the message. + """ + + def __init__(self, + topics, + servers="localhost", + group="", + eof=False, + timeout=1000): + """Create a KafkaReader. + + Args: + topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. + servers: A list of bootstrap servers. + group: The consumer group id. + eof: If True, the kafka reader will stop on EOF. + timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). + """ + super(KafkaDataset, self).__init__() + self._topics = ops.convert_to_tensor( + topics, dtype=dtypes.string, name="topics") + self._servers = ops.convert_to_tensor( + servers, dtype=dtypes.string, name="servers") + self._group = ops.convert_to_tensor( + group, dtype=dtypes.string, name="group") + self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name="eof") + self._timeout = ops.convert_to_tensor( + timeout, dtype=dtypes.int64, name="timeout") + + def _as_variant_tensor(self): + return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, + self._eof, self._timeout) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.string |