author: Yong Tang <yong.tang.github@outlook.com> 2018-07-02 07:41:42 -0700
committer: Derek Murray <derek.murray@gmail.com> 2018-07-02 07:41:42 -0700
commit a7b7aa856f34bf2e44fbeb91d817742c61483618 (patch)
tree c2fcf698127c4d330e7ba629ec4ae512efbd5b1a /tensorflow/contrib/kinesis
parent 7e8927e7af0c51ac20a63bd4eab6ff83df1a39ae (diff)
Add KinesisDataset support for tensorflow Dataset (#19712)
* Add KinesisDataset support for tensorflow Dataset

  This fix is an attempt to add Kinesis support for tensorflow's Dataset.
  Kinesis is a managed data streaming service provided by AWS. It is similar
  to Apache Kafka, and is often used where maintaining an independent Kafka
  cluster on AWS is not desirable or possible. Similar to the Kafka
  integration in tensorflow, KinesisDataset outputs tf.string for records.
  Test cases have also been added; they have to be invoked manually.

* Expose KinesisDataset in dataset_ops.cc
* Expose KinesisDataset in python wrapper
* Add test cases for KinesisDataset
* Update AWS library include files
* Add Bazel BUILD files
* Rename s3_crypto to aws_crypto
* Rename with_s3_support to with_aws_support
* Selectively add kinesis to tensorflow/contrib/BUILD
* Set different partition key and pylint fix.
* Add missing modules in cmake's python_modules.txt
* Address review feedback

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
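For reference, a minimal usage sketch of the API this change adds, adapted from the docstring below (the stream name is hypothetical; `read_indefinitely=False` makes the read stop once the shard is drained):

```
import tensorflow as tf

# "my_stream" is a hypothetical, already-populated single-shard stream.
dataset = tf.contrib.kinesis.KinesisDataset(
    "my_stream", read_indefinitely=False)
next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))  # each record is a scalar tf.string
    except tf.errors.OutOfRangeError:
      break
```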
Diffstat (limited to 'tensorflow/contrib/kinesis')
-rw-r--r-- tensorflow/contrib/kinesis/BUILD                               | 113
-rw-r--r-- tensorflow/contrib/kinesis/__init__.py                         |  32
-rw-r--r-- tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc      | 359
-rw-r--r-- tensorflow/contrib/kinesis/ops/dataset_ops.cc                  |  42
-rw-r--r-- tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py | 139
-rw-r--r-- tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py   |  96
-rw-r--r-- tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py     |  24
7 files changed, 805 insertions, 0 deletions
diff --git a/tensorflow/contrib/kinesis/BUILD b/tensorflow/contrib/kinesis/BUILD
new file mode 100644
index 0000000000..25443d0ad4
--- /dev/null
+++ b/tensorflow/contrib/kinesis/BUILD
@@ -0,0 +1,113 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_py_test",
+)
+
+py_library(
+ name = "kinesis",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = [
+ "kernels/kinesis_dataset_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core/platform/s3:aws_crypto",
+ "//third_party/eigen3",
+ "@aws",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/kinesis_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":kinesis_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "kinesis_op_loader",
+ srcs = ["python/ops/kinesis_op_loader.py"],
+ dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+tf_py_test(
+ name = "kinesis_test",
+ srcs = ["python/kernel_tests/kinesis_test.py"],
+ additional_deps = [
+ ":kinesis",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "no_windows",
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/kinesis/__init__.py b/tensorflow/contrib/kinesis/__init__.py
new file mode 100644
index 0000000000..3824b8ae75
--- /dev/null
+++ b/tensorflow/contrib/kinesis/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kinesis Dataset.
+
+@@KinesisDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "KinesisDataset",
+]
+
+remove_undocumented(__name__)
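With this `__init__.py` in place, the symbol is reachable through the contrib namespace; a one-line sketch (stream name hypothetical):

```
import tensorflow as tf

# KinesisDataset is re-exported above, so it is addressable via tf.contrib.
dataset = tf.contrib.kinesis.KinesisDataset("my_stream")
```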
diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
new file mode 100644
index 0000000000..3212279c4c
--- /dev/null
+++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
@@ -0,0 +1,359 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <aws/core/Aws.h>
+#include <aws/core/config/AWSProfileConfigLoader.h>
+#include <aws/core/utils/Outcome.h>
+#include <aws/kinesis/KinesisClient.h>
+#include <aws/kinesis/model/DescribeStreamRequest.h>
+#include <aws/kinesis/model/GetRecordsRequest.h>
+#include <aws/kinesis/model/GetShardIteratorRequest.h>
+#include <aws/kinesis/model/PutRecordsRequest.h>
+#include <aws/kinesis/model/ShardIteratorType.h>
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/platform/s3/aws_crypto.h"
+
+namespace tensorflow {
+namespace {
+
+Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
+ static Aws::Client::ClientConfiguration config;
+ const char* endpoint = getenv("KINESIS_ENDPOINT");
+ if (endpoint) {
+ config.endpointOverride = Aws::String(endpoint);
+ }
+ const char* region = getenv("AWS_REGION");
+ if (region) {
+ config.region = Aws::String(region);
+ } else {
+ // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
+ // is set with a truthy value.
+ const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
+ string load_config =
+ load_config_env ? str_util::Lowercase(load_config_env) : "";
+ if (load_config == "true" || load_config == "1") {
+ Aws::String config_file;
+ // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
+ const char* config_file_env = getenv("AWS_CONFIG_FILE");
+ if (config_file_env) {
+ config_file = config_file_env;
+ } else {
+ const char* home_env = getenv("HOME");
+ if (home_env) {
+ config_file = home_env;
+ config_file += "/.aws/config";
+ }
+ }
+ Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
+ // Load the configuration. If successful, get the region.
+ // If the load is not successful, then generate a warning.
+ if (loader.Load()) {
+ auto profiles = loader.GetProfiles();
+ if (!profiles["default"].GetRegion().empty()) {
+ config.region = profiles["default"].GetRegion();
+ }
+ } else {
+ LOG(WARNING) << "Failed to load the profile in " << config_file << ".";
+ }
+ }
+ }
+ const char* use_https = getenv("KINESIS_USE_HTTPS");
+ if (use_https) {
+ if (use_https[0] == '0') {
+ config.scheme = Aws::Http::Scheme::HTTP;
+ } else {
+ config.scheme = Aws::Http::Scheme::HTTPS;
+ }
+ }
+ const char* verify_ssl = getenv("KINESIS_VERIFY_SSL");
+ if (verify_ssl) {
+ if (verify_ssl[0] == '0') {
+ config.verifySSL = false;
+ } else {
+ config.verifySSL = true;
+ }
+ }
+ const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC");
+ if (connect_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(connect_timeout, &timeout)) {
+ config.connectTimeoutMs = timeout;
+ }
+ }
+ const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC");
+ if (request_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(request_timeout, &timeout)) {
+ config.requestTimeoutMs = timeout;
+ }
+ }
+
+ return &config;
+}
+
+Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
+ static Aws::Client::ClientConfiguration* config =
+ InitializeDefaultClientConfig();
+ return *config;
+}
+
+static mutex mu(LINKER_INITIALIZED);
+static unsigned count(0);
+void AwsInitAPI() {
+ mutex_lock lock(mu);
+ count++;
+ if (count == 1) {
+ Aws::SDKOptions options;
+ options.cryptoOptions.sha256Factory_create_fn = []() {
+ return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
+ };
+ options.cryptoOptions.sha256HMACFactory_create_fn = []() {
+ return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
+ };
+ Aws::InitAPI(options);
+ }
+}
+void AwsShutdownAPI() {
+ mutex_lock lock(mu);
+ count--;
+ if (count == 0) {
+ Aws::SDKOptions options;
+ Aws::ShutdownAPI(options);
+ }
+}
+void ShutdownClient(Aws::Kinesis::KinesisClient* client) {
+ if (client != nullptr) {
+ delete client;
+ AwsShutdownAPI();
+ }
+}
+}  // namespace
+class KinesisDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ std::string stream = "";
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<std::string>(ctx, "stream", &stream));
+ std::string shard = "";
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "shard", &shard));
+ bool read_indefinitely = true;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "read_indefinitely",
+ &read_indefinitely));
+ int64 interval = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "interval", &interval));
+ OP_REQUIRES(ctx, (interval > 0),
+ errors::InvalidArgument(
+ "Interval value should be large than 0, got ", interval));
+ *output = new Dataset(ctx, stream, shard, read_indefinitely, interval);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
+ const bool read_indefinitely, const int64 interval)
+ : GraphDatasetBase(ctx),
+ stream_(stream),
+ shard_(shard),
+ read_indefinitely_(read_indefinitely),
+ interval_(interval) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Kinesis")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "KinesisDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* stream = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream));
+ Node* shard = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard));
+ Node* read_indefinitely = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely));
+ Node* interval = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {stream, shard, read_indefinitely, interval}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ client_(nullptr, ShutdownClient) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (iterator_ == "") {
+ TF_RETURN_IF_ERROR(SetupStreamsLocked());
+ }
+ do {
+ Aws::Kinesis::Model::GetRecordsRequest request;
+ auto outcome = client_->GetRecords(
+ request.WithShardIterator(iterator_).WithLimit(1));
+ if (!outcome.IsSuccess()) {
+ return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
+ outcome.GetError().GetMessage());
+ }
+ if (outcome.GetResult().GetRecords().size() == 0) {
+ // If no records were returned then nothing is available at the
+ // moment.
+ if (!dataset()->read_indefinitely_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ // Continue the loop after a period of time.
+ ctx->env()->SleepForMicroseconds(dataset()->interval_);
+ continue;
+ }
+ if (outcome.GetResult().GetRecords().size() != 1) {
+ return errors::Unknown("invalid number of records ",
+ outcome.GetResult().GetRecords().size(),
+ " returned");
+ }
+
+ iterator_ = outcome.GetResult().GetNextShardIterator();
+
+ const auto& data = outcome.GetResult().GetRecords()[0].GetData();
+ StringPiece value(
+ reinterpret_cast<const char*>(data.GetUnderlyingData()),
+ data.GetLength());
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<std::string>()() = std::string(value);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+ *end_of_sequence = false;
+ return Status::OK();
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented("SaveInternal is currently not supported");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "RestoreInternal is currently not supported");
+ }
+
+ private:
+ // Sets up Kinesis streams to read from.
+ Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ AwsInitAPI();
+ client_.reset(
+ new Aws::Kinesis::KinesisClient(GetDefaultClientConfig()));
+
+ Aws::Kinesis::Model::DescribeStreamRequest request;
+ auto outcome = client_->DescribeStream(
+ request.WithStreamName(dataset()->stream_.c_str()));
+ if (!outcome.IsSuccess()) {
+ return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
+ outcome.GetError().GetMessage());
+ }
+ Aws::String shard;
+ Aws::String sequence;
+ if (dataset()->shard_ == "") {
+ if (outcome.GetResult().GetStreamDescription().GetShards().size() !=
+ 1) {
+ return errors::InvalidArgument(
+ "shard has to be provided unless the stream only have one "
+ "shard, there are ",
+ outcome.GetResult().GetStreamDescription().GetShards().size(),
+ " shards in stream ", dataset()->stream_);
+ }
+ shard = outcome.GetResult()
+ .GetStreamDescription()
+ .GetShards()[0]
+ .GetShardId();
+ sequence = outcome.GetResult()
+ .GetStreamDescription()
+ .GetShards()[0]
+ .GetSequenceNumberRange()
+ .GetStartingSequenceNumber();
+ } else {
+ for (const auto& entry :
+ outcome.GetResult().GetStreamDescription().GetShards()) {
+ if (entry.GetShardId() == dataset()->shard_.c_str()) {
+ shard = entry.GetShardId();
+ sequence =
+ entry.GetSequenceNumberRange().GetStartingSequenceNumber();
+ break;
+ }
+ }
+ if (shard == "") {
+ return errors::InvalidArgument("no shard ", dataset()->shard_,
+ " in stream ", dataset()->stream_);
+ }
+ }
+
+ Aws::Kinesis::Model::GetShardIteratorRequest iterator_request;
+ auto iterator_outcome = client_->GetShardIterator(
+ iterator_request.WithStreamName(dataset()->stream_.c_str())
+ .WithShardId(shard)
+ .WithShardIteratorType(
+ Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER)
+ .WithStartingSequenceNumber(sequence));
+ if (!iterator_outcome.IsSuccess()) {
+ return errors::Unknown(iterator_outcome.GetError().GetExceptionName(),
+ ": ",
+ iterator_outcome.GetError().GetMessage());
+ }
+ iterator_ = iterator_outcome.GetResult().GetShardIterator();
+ return Status::OK();
+ }
+
+ mutex mu_;
+ Aws::String iterator_ GUARDED_BY(mu_);
+ std::unique_ptr<Aws::Kinesis::KinesisClient, decltype(&ShutdownClient)>
+ client_ GUARDED_BY(mu_);
+ };
+
+ const std::string stream_;
+ const std::string shard_;
+ const bool read_indefinitely_;
+ const int64 interval_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU),
+ KinesisDatasetOp);
+
+} // namespace tensorflow
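The client configuration above is driven entirely by environment variables (KINESIS_ENDPOINT, AWS_REGION, AWS_SDK_LOAD_CONFIG, AWS_CONFIG_FILE, KINESIS_USE_HTTPS, KINESIS_VERIFY_SSL, KINESIS_CONNECT_TIMEOUT_MSEC, KINESIS_REQUEST_TIMEOUT_MSEC) and is built once and cached, so the variables must be set before the first Kinesis op executes. A sketch of pointing the kernel at a local endpoint; the emulator address and stream name are assumptions, not part of this change:

```
import os

# Assumption: a local Kinesis-compatible emulator is listening here.
os.environ["KINESIS_ENDPOINT"] = "localhost:4567"
os.environ["KINESIS_USE_HTTPS"] = "0"   # plain HTTP for the local emulator
os.environ["KINESIS_VERIFY_SSL"] = "0"  # skip certificate verification
os.environ["AWS_REGION"] = "us-east-1"

import tensorflow as tf

dataset = tf.contrib.kinesis.KinesisDataset(
    "my_stream", read_indefinitely=False)  # hypothetical stream name
```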
diff --git a/tensorflow/contrib/kinesis/ops/dataset_ops.cc b/tensorflow/contrib/kinesis/ops/dataset_ops.cc
new file mode 100644
index 0000000000..54204513cf
--- /dev/null
+++ b/tensorflow/contrib/kinesis/ops/dataset_ops.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KinesisDataset")
+ .Input("stream: string")
+ .Input("shard: string")
+ .Input("read_indefinitely: bool")
+ .Input("interval: int64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that emits the records of a Kinesis stream.
+
+stream: A `tf.string` tensor containing the name of the stream.
+shard: A `tf.string` tensor containing the id of the shard.
+read_indefinitely: If `True`, the Kinesis dataset will keep retrying
+ on `EOF` after the `interval` period. If `False`, then the dataset
+ will stop on `EOF`. The default value is `True`.
+interval: The interval (in microseconds) for the Kinesis Client to
+ wait before it tries to get records again.
+)doc");
+
+} // namespace tensorflow
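With `read_indefinitely=True` the dataset never reports end-of-sequence: whenever a `GetRecords` call returns nothing, the kernel sleeps for `interval` microseconds and polls again. A sketch of tailing a stream indefinitely (stream name hypothetical):

```
import tensorflow as tf

# Poll every 100 ms (100000 us, the default) when the shard is idle.
dataset = tf.contrib.kinesis.KinesisDataset(
    "my_stream", read_indefinitely=True, interval=100000)
next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  while True:
    record = sess.run(next_element)  # blocks until a record arrives
    print(record)
```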
diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
new file mode 100644
index 0000000000..7289b45c50
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for KinesisDataset.
+NOTE: boto3 is needed and the test has to be invoked manually:
+```
+$ bazel test -s --verbose_failures --config=opt \
+ --action_env=AWS_ACCESS_KEY_ID=XXXXXX \
+ --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \
+ //tensorflow/contrib/kinesis:kinesis_test
+```
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import boto3
+
+from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class KinesisDatasetTest(test.TestCase):
+
+ def testKinesisDatasetOneShard(self):
+ client = boto3.client('kinesis', region_name='us-east-1')
+
+ # Setup the Kinesis with 1 shard.
+ stream_name = "tf_kinesis_test_1"
+ client.create_stream(StreamName=stream_name, ShardCount=1)
+ # Wait until the stream exists (the waiter polls every 10 seconds, up to 18 times by default).
+ client.get_waiter('stream_exists').wait(StreamName=stream_name)
+ for i in range(10):
+ data = "D" + str(i)
+ client.put_record(
+ StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
+
+ stream = array_ops.placeholder(dtypes.string, shape=[])
+ num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = kinesis_dataset_ops.KinesisDataset(
+ stream, read_indefinitely=False).repeat(num_epochs)
+ batch_dataset = repeat_dataset.batch(batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_op = iterator.make_initializer(repeat_dataset)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ # Basic test: read from the only shard of the one-shard stream.
+ sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
+ for i in range(10):
+ self.assertEqual("D" + str(i), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ client.delete_stream(StreamName=stream_name)
+ # Wait until the stream is deleted (the waiter polls every 10 seconds, up to 18 times by default).
+ client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
+
+ def testKinesisDatasetTwoShards(self):
+ client = boto3.client('kinesis', region_name='us-east-1')
+
+ # Setup the Kinesis with 2 shards.
+ stream_name = "tf_kinesis_test_2"
+ client.create_stream(StreamName=stream_name, ShardCount=2)
+ # Wait until the stream exists (the waiter polls every 10 seconds, up to 18 times by default).
+ client.get_waiter('stream_exists').wait(StreamName=stream_name)
+
+ for i in range(10):
+ data = "D" + str(i)
+ client.put_record(
+ StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
+ response = client.describe_stream(StreamName=stream_name)
+ shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"]
+ shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"]
+
+ stream = array_ops.placeholder(dtypes.string, shape=[])
+ shard = array_ops.placeholder(dtypes.string, shape=[])
+ num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = kinesis_dataset_ops.KinesisDataset(
+ stream, shard, read_indefinitely=False).repeat(num_epochs)
+ batch_dataset = repeat_dataset.batch(batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_op = iterator.make_initializer(repeat_dataset)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ data = list()
+ with self.test_session() as sess:
+ # Basic test: read from shard 0 of the two-shard stream.
+ sess.run(
+ init_op, feed_dict={
+ stream: stream_name, shard: shard_id_0, num_epochs: 1})
+ with self.assertRaises(errors.OutOfRangeError):
+ # Use range(11) to guarantee the OutOfRangeError.
+ for i in range(11):
+ data.append(sess.run(get_next))
+
+ # Basic test: read from shard 1 of the two-shard stream.
+ sess.run(
+ init_op, feed_dict={
+ stream: stream_name, shard: shard_id_1, num_epochs: 1})
+ with self.assertRaises(errors.OutOfRangeError):
+ # Use range(11) to guarantee the OutOfRangeError.
+ for i in range(11):
+ data.append(sess.run(get_next))
+
+ data.sort()
+ self.assertEqual(data, ["D" + str(i) for i in range(10)])
+
+ client.delete_stream(StreamName=stream_name)
+ # Wait until the stream is deleted (the waiter polls every 10 seconds, up to 18 times by default).
+ client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
+
+
+if __name__ == "__main__":
+ test.main()
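The tests above run against real AWS and therefore need credentials. A sketch of the same stream setup against a local emulator instead; kinesalite and its port are assumptions, not part of this change:

```
import os

import boto3

# Assumption: a kinesalite (or similar) Kinesis emulator listens locally.
os.environ["KINESIS_ENDPOINT"] = "localhost:4567"  # for the TF kernel
os.environ["KINESIS_USE_HTTPS"] = "0"

client = boto3.client(
    "kinesis",
    region_name="us-east-1",
    endpoint_url="http://localhost:4567",  # point boto3 at the emulator too
    aws_access_key_id="test",              # dummy credentials for the emulator
    aws_secret_access_key="test")
client.create_stream(StreamName="tf_kinesis_test_local", ShardCount=1)
client.get_waiter("stream_exists").wait(StreamName="tf_kinesis_test_local")
```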
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
new file mode 100644
index 0000000000..ca2df95ba4
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kinesis Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class KinesisDataset(Dataset):
+ """A Kinesis Dataset that consumes the message.
+
+ Kinesis is a managed service provided by AWS for data streaming.
+ This dataset reads messages from Kinesis with each message presented
+ as a `tf.string`.
+
+ For example, we can construct and use the KinesisDataset as follows:
+ ```python
+ dataset = tf.contrib.kinesis.KinesisDataset(
+ "kinesis_stream_name", read_indefinitely=False)
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+
+ Since Kinesis is a data streaming service, data may not be available
+ at the time it is being read. The argument `read_indefinitely` is
+ used to control the behavior in this situation. If `read_indefinitely`
+ is `True`, then `KinesisDataset` will keep retrying to retrieve data
+ from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError`
+ is raised immediately instead.
+ """
+
+ def __init__(self,
+ stream,
+ shard="",
+ read_indefinitely=True,
+ interval=100000):
+ """Create a KinesisDataset.
+
+ Args:
+ stream: A `tf.string` tensor containing the name of the stream.
+ shard: A `tf.string` tensor containing the id of the shard.
+ read_indefinitely: If `True`, the Kinesis dataset will keep retrying
+ on `EOF` after the `interval` period. If `False`, then the dataset
+ will stop on `EOF`. The default value is `True`.
+ interval: The interval (in microseconds) for the Kinesis Client to
+ wait before it tries to get records again.
+ """
+ super(KinesisDataset, self).__init__()
+ self._stream = ops.convert_to_tensor(
+ stream, dtype=dtypes.string, name="stream")
+ self._shard = ops.convert_to_tensor(
+ shard, dtype=dtypes.string, name="shard")
+ self._read_indefinitely = ops.convert_to_tensor(
+ read_indefinitely, dtype=dtypes.bool, name="read_indefinitely")
+ self._interval = ops.convert_to_tensor(
+ interval, dtype=dtypes.int64, name="interval")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.kinesis_dataset(
+ self._stream, self._shard, self._read_indefinitely, self._interval)
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.string
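Since `KinesisDataset` is an ordinary `tf.data` dataset, the standard transformations compose with it, as the tests exercise. A minimal sketch with `repeat` and `batch` (stream name hypothetical):

```
import tensorflow as tf

dataset = tf.contrib.kinesis.KinesisDataset(
    "my_stream", read_indefinitely=False).repeat(2).batch(5)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  try:
    while True:
      print(sess.run(get_next))  # a batch of up to 5 tf.string records
  except tf.errors.OutOfRangeError:
    pass
```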
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py
new file mode 100644
index 0000000000..c9ce9f3646
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python helper for loading kinesis ops and kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
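For reference, this helper amounts to the standard custom-op loading path: resolve the shared object shipped alongside the Python package and register its op and kernel with the runtime. A rough equivalent with a hypothetical literal path:

```
import tensorflow as tf

# Hypothetical direct equivalent of the loader call above; the real path is
# resolved relative to the installed package.
_lib = tf.load_op_library(
    "/path/to/tensorflow/contrib/kinesis/_dataset_ops.so")
```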