diff options
author | Michael Case <mikecase@google.com> | 2018-02-07 14:36:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-07 14:39:49 -0800 |
commit | d90054e7c0f41f4bab81df0548577a73b939a87a (patch) | |
tree | a15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/contrib/kafka | |
parent | 8461760f9f6cde8ed97507484d2a879140141032 (diff) |
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/contrib/kafka')
-rw-r--r-- | tensorflow/contrib/kafka/BUILD | 105 | ||||
-rw-r--r-- | tensorflow/contrib/kafka/__init__.py | 32 | ||||
-rw-r--r-- | tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc | 321 | ||||
-rw-r--r-- | tensorflow/contrib/kafka/ops/kafka_ops.cc | 44 | ||||
-rw-r--r-- | tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py | 115 | ||||
-rw-r--r-- | tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh | 48 | ||||
-rw-r--r-- | tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py | 74 |
7 files changed, 739 insertions, 0 deletions
diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD new file mode 100644 index 0000000000..efb403462a --- /dev/null +++ b/tensorflow/contrib/kafka/BUILD @@ -0,0 +1,105 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +tf_kernel_library( + name = "kafka_kernels", + srcs = ["kernels/kafka_dataset_ops.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:bounds_check_lib", + "//tensorflow/core/kernels:dataset", + "//third_party/eigen3", + "@kafka", + ], +) + +tf_gen_op_libs( + op_lib_names = ["kafka_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_kafka_ops", + out = "python/ops/gen_kafka_ops.py", + require_shape_functions = True, + deps = [":kafka_ops_op_lib"], +) + +py_library( + name = "kafka", + srcs = [ + "__init__.py", + "python/ops/kafka_dataset_ops.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":gen_kafka_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + +# The Kafka server has to be setup before running the test. +# The Kafka server is setup through Docker so the Docker engine +# has to be installed. +# +# Once the Docker engine is ready: +# To setup the Kafka server: +# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh start kafka +# +# After the test is complete: +# To team down the Kafka server: +# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh stop kafka +tf_py_test( + name = "kafka_test", + srcs = ["python/kernel_tests/kafka_test.py"], + additional_deps = [ + ":kafka", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "notap", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/kafka/__init__.py b/tensorflow/contrib/kafka/__init__.py new file mode 100644 index 0000000000..4d755c4056 --- /dev/null +++ b/tensorflow/contrib/kafka/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kafka Dataset. + +@@KafkaDataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops.kafka_dataset_ops import KafkaDataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "KafkaDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc new file mode 100644 index 0000000000..88ef5f3571 --- /dev/null +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -0,0 +1,321 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/dataset.h" + +#include "tensorflow/core/framework/tensor.h" + +#include "src-cpp/rdkafkacpp.h" + +namespace tensorflow { + +class KafkaDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* topics_tensor; + OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor)); + OP_REQUIRES( + ctx, topics_tensor->dims() <= 1, + errors::InvalidArgument("`topics` must be a scalar or a vector.")); + + std::vector<string> topics; + topics.reserve(topics_tensor->NumElements()); + for (int i = 0; i < topics_tensor->NumElements(); ++i) { + topics.push_back(topics_tensor->flat<string>()(i)); + } + + std::string servers = ""; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<std::string>(ctx, "servers", &servers)); + std::string group = ""; + OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "group", &group)); + bool eof = false; + OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "eof", &eof)); + int64 timeout = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "timeout", &timeout)); + OP_REQUIRES(ctx, (timeout > 0), + errors::InvalidArgument( + "Timeout value should be large than 0, got ", timeout)); + *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector<string> topics, + const string& servers, const string& group, const bool eof, + const int64 timeout) + : GraphDatasetBase(ctx), + topics_(std::move(topics)), + servers_(servers), + group_(group), + eof_(eof), + timeout_(timeout) {} + + std::unique_ptr<IteratorBase> MakeIterator( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Kafka")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}}); + return *shapes; + } + + string DebugString() override { return "KafkaDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* topics = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics)); + Node* servers = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers)); + Node* group = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(group_, &group)); + Node* eof = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof)); + Node* timeout = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {topics, servers, group, eof, timeout}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + // We are currently processing a topic, so try to read the next line. + if (consumer_.get()) { + while (true) { + if (limit_ >= 0 && + (topic_partition_->offset() >= limit_ || offset_ >= limit_)) { + // EOF current topic + break; + } + std::unique_ptr<RdKafka::Message> message( + consumer_->consume(dataset()->timeout_)); + if (message->err() == RdKafka::ERR_NO_ERROR) { + // Produce the line as output. + Tensor line_tensor(cpu_allocator(), DT_STRING, {}); + line_tensor.scalar<string>()() = + std::string(static_cast<const char*>(message->payload()), + message->len()); + out_tensors->emplace_back(std::move(line_tensor)); + *end_of_sequence = false; + // Sync offset + offset_ = message->offset(); + return Status::OK(); + } + + if (message->err() == RdKafka::ERR__PARTITION_EOF && + dataset()->eof_) { + // EOF current topic + break; + } + if (message->err() != RdKafka::ERR__TIMED_OUT) { + return errors::Internal("Failed to consume:", + message->errstr()); + } + message.reset(nullptr); + consumer_->poll(0); + } + + // We have reached the end of the current topic, so maybe + // move on to next topic. + ResetStreamsLocked(); + ++current_topic_index_; + } + + // Iteration ends when there are no more topic to process. + if (current_topic_index_ == dataset()->topics_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"), + current_topic_index_)); + + // `consumer_` is empty if + // 1. GetNext has not been called even once. + // 2. All topics have been read and iterator has been exhausted. + if (consumer_.get()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("current_pos"), offset_)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + ResetStreamsLocked(); + int64 current_topic_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"), + ¤t_topic_index)); + current_topic_index_ = size_t(current_topic_index); + // The key "current_pos" is written only if the iterator was saved + // with an open topic. + if (reader->Contains(full_name("current_pos"))) { + int64 current_pos; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("current_pos"), ¤t_pos)); + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + topic_partition_->set_offset(current_pos); + if (topic_partition_->offset() != current_pos) { + return errors::Internal("Failed to restore to offset ", + current_pos); + } + offset_ = current_pos; + } + return Status::OK(); + } + + private: + // Sets up Kafka streams to read from the topic at + // `current_topic_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_topic_index_ >= dataset()->topics_.size()) { + return errors::InvalidArgument( + "current_topic_index_:", current_topic_index_, + " >= topics_.size():", dataset()->topics_.size()); + } + + // Actually move on to next topic. + string entry = dataset()->topics_[current_topic_index_]; + + std::vector<string> parts = str_util::Split(entry, ":"); + if (parts.size() < 1) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + string topic = parts[0]; + int32 partition = 0; + if (parts.size() > 1) { + if (!strings::safe_strto32(parts[1], &partition)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + int64 offset = 0; + if (parts.size() > 2) { + if (!strings::safe_strto64(parts[2], &offset)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + + topic_partition_.reset( + RdKafka::TopicPartition::create(topic, partition, offset)); + + offset_ = topic_partition_->offset(); + limit_ = -1; + if (parts.size() > 3) { + if (!strings::safe_strto64(parts[3], &limit_)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + + std::unique_ptr<RdKafka::Conf> conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); + std::unique_ptr<RdKafka::Conf> topic_conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); + + std::string errstr; + + RdKafka::Conf::ConfResult result = + conf->set("default_topic_conf", topic_conf.get(), errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set default_topic_conf:", errstr); + } + + result = conf->set("bootstrap.servers", dataset()->servers_, errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set bootstrap.servers ", + dataset()->servers_, ":", errstr); + } + result = conf->set("group.id", dataset()->group_, errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set group.id ", dataset()->group_, + ":", errstr); + } + + consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); + if (!consumer_.get()) { + return errors::Internal("Failed to create consumer:", errstr); + } + + std::vector<RdKafka::TopicPartition*> partitions; + partitions.emplace_back(topic_partition_.get()); + RdKafka::ErrorCode err = consumer_->assign(partitions); + if (err != RdKafka::ERR_NO_ERROR) { + return errors::Internal( + "Failed to assign partition [", topic_partition_->topic(), ", ", + topic_partition_->partition(), ", ", topic_partition_->offset(), + "]:", RdKafka::err2str(err)); + } + + return Status::OK(); + } + + // Resets all Kafka streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + consumer_->unassign(); + consumer_->close(); + consumer_.reset(nullptr); + } + + mutex mu_; + size_t current_topic_index_ GUARDED_BY(mu_) = 0; + int64 offset_ GUARDED_BY(mu_) = 0; + int64 limit_ GUARDED_BY(mu_) = -1; + std::unique_ptr<RdKafka::TopicPartition> topic_partition_ GUARDED_BY(mu_); + std::unique_ptr<RdKafka::KafkaConsumer> consumer_ GUARDED_BY(mu_); + }; + + const std::vector<string> topics_; + const std::string servers_; + const std::string group_; + const bool eof_; + const int64 timeout_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU), + KafkaDatasetOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc new file mode 100644 index 0000000000..8cdf16103b --- /dev/null +++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KafkaDataset") + .Input("topics: string") + .Input("servers: string") + .Input("group: string") + .Input("eof: bool") + .Input("timeout: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kafka topics. + +topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. +servers: A list of bootstrap servers. +group: The consumer group id. +eof: If True, the kafka reader will stop on EOF. +timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py new file mode 100644 index 0000000000..621911876f --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py @@ -0,0 +1,115 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for KafkaDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class KafkaDatasetTest(test.TestCase): + + def setUp(self): + # The Kafka server has to be setup before the test + # and tear down after the test manually. + # The docker engine has to be installed. + # + # To setup the Kafka server: + # $ bash kafka_test.sh start kafka + # + # To team down the Kafka server: + # $ bash kafka_test.sh stop kafka + pass + + def testKafkaDataset(self): + topics = array_ops.placeholder(dtypes.string, shape=[None]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_dataset_ops.KafkaDataset( + topics, group="test", eof=True).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.test_session() as sess: + # Basic test: read from topic 0. + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from topic 1. + sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i + 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from both topics. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 1 + }) + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both files. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10 + }) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both files. + sess.run( + init_batch_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10, + batch_size: 5 + }) + for _ in range(10): + self.assertAllEqual(["D" + str(i) for i in range(5)], + sess.run(get_next)) + self.assertAllEqual(["D" + str(i + 5) for i in range(5)], + sess.run(get_next)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh new file mode 100644 index 0000000000..adf027b8e7 --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +set -o pipefail + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 start|stop <kafka container name>" >&2 + exit 1 +fi + +container=$2 +if [ "$1" == "start" ]; then + docker run -d --rm --net=host --name=$container spotify/kafka + echo Wait 5 secs until kafka is up and running + sleep 5 + echo Create test topic + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic test' + echo Create test message + docker exec $container bash -c 'echo -e "D0\nD1\nD2\nD3\nD4\nD5\nD6\nD7\nD8\nD9" > /test' + echo Produce test message + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-console-producer.sh --topic test --broker-list 127.0.0.1:9092 < /test' + + echo Container $container started successfully +elif [ "$1" == "stop" ]; then + docker rm -f $container + + echo Container $container stopped successfully +else + echo "Usage: $0 start|stop <kafka container name>" >&2 + exit 1 +fi + + + diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py new file mode 100644 index 0000000000..8e51d27a34 --- /dev/null +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kafka Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import gen_kafka_ops +from tensorflow.python.data.ops.readers import Dataset +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class KafkaDataset(Dataset): + """A Kafka Dataset that consumes the message. + """ + + def __init__(self, + topics, + servers="localhost", + group="", + eof=False, + timeout=1000): + """Create a KafkaReader. + + Args: + topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. + servers: A list of bootstrap servers. + group: The consumer group id. + eof: If True, the kafka reader will stop on EOF. + timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). + """ + super(KafkaDataset, self).__init__() + self._topics = ops.convert_to_tensor( + topics, dtype=dtypes.string, name="topics") + self._servers = ops.convert_to_tensor( + servers, dtype=dtypes.string, name="servers") + self._group = ops.convert_to_tensor( + group, dtype=dtypes.string, name="group") + self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name="eof") + self._timeout = ops.convert_to_tensor( + timeout, dtype=dtypes.int64, name="timeout") + + def _as_variant_tensor(self): + return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, + self._eof, self._timeout) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.string |