-rw-r--r--  configure.py                                                      6
-rw-r--r--  tensorflow/BUILD                                                 16
-rw-r--r--  tensorflow/contrib/BUILD                                         18
-rw-r--r--  tensorflow/contrib/cmake/python_modules.txt                       2
-rw-r--r--  tensorflow/contrib/kinesis/BUILD                                113
-rw-r--r--  tensorflow/contrib/kinesis/__init__.py                           32
-rw-r--r--  tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc      359
-rw-r--r--  tensorflow/contrib/kinesis/ops/dataset_ops.cc                    42
-rw-r--r--  tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py 139
-rw-r--r--  tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py    96
-rw-r--r--  tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py      24
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl                8
-rw-r--r--  tensorflow/core/platform/s3/BUILD                               14
-rw-r--r--  tensorflow/core/platform/s3/aws_crypto.cc (renamed from tensorflow/core/platform/s3/s3_crypto.cc)  22
-rw-r--r--  tensorflow/core/platform/s3/aws_crypto.h (renamed from tensorflow/core/platform/s3/s3_crypto.h)     6
-rw-r--r--  tensorflow/core/platform/s3/s3_file_system.cc                     6
-rw-r--r--  third_party/aws.BUILD                                             3
17 files changed, 866 insertions(+), 40 deletions(-)
diff --git a/configure.py b/configure.py
index 5243e09b24..31a83b4a15 100644
--- a/configure.py
+++ b/configure.py
@@ -1449,7 +1449,7 @@ def main():
setup_python(environ_cp)
if is_windows():
- environ_cp['TF_NEED_S3'] = '0'
+ environ_cp['TF_NEED_AWS'] = '0'
environ_cp['TF_NEED_GCP'] = '0'
environ_cp['TF_NEED_HDFS'] = '0'
environ_cp['TF_NEED_JEMALLOC'] = '0'
@@ -1473,8 +1473,8 @@ def main():
'with_gcp_support', True, 'gcp')
set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
'with_hdfs_support', True, 'hdfs')
- set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
- 'with_s3_support', True, 's3')
+ set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
+ 'with_aws_support', True, 'aws')
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
'with_kafka_support', True, 'kafka')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index f362900387..51eea94847 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -216,8 +216,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support",
+ define_values = {"with_aws_support": "true"},
visibility = ["//visibility:public"],
)
@@ -244,8 +244,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support_windows_override",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support_windows_override",
+ define_values = {"with_aws_support": "true"},
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)
@@ -279,8 +279,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support_android_override",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support_android_override",
+ define_values = {"with_aws_support": "true"},
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
@@ -300,8 +300,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support_ios_override",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support_ios_override",
+ define_values = {"with_aws_support": "true"},
values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index e2c85f3995..fa69efa3f6 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -122,6 +122,12 @@ py_library(
"//tensorflow/contrib/kafka",
],
"//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/contrib/kinesis",
+ ],
+ "//conditions:default": [],
}) + if_not_windows_cuda([
"//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
]) + if_not_windows([
@@ -157,6 +163,12 @@ cc_library(
"//tensorflow/contrib/kafka:dataset_kernels",
],
"//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/contrib/kinesis:dataset_kernels",
+ ],
+ "//conditions:default": [],
}),
)
@@ -186,5 +198,11 @@ cc_library(
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
],
"//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ ],
+ "//conditions:default": [],
}),
)
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 8ff6ebedab..a5eba5a8c9 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -240,6 +240,8 @@ tensorflow/contrib/keras/api/keras/wrappers/scikit_learn
tensorflow/contrib/kernel_methods
tensorflow/contrib/kernel_methods/python
tensorflow/contrib/kernel_methods/python/mappers
+tensorflow/contrib/kinesis/python
+tensorflow/contrib/kinesis/python/ops
tensorflow/contrib/kfac
tensorflow/contrib/kfac/examples
tensorflow/contrib/kfac/python
diff --git a/tensorflow/contrib/kinesis/BUILD b/tensorflow/contrib/kinesis/BUILD
new file mode 100644
index 0000000000..25443d0ad4
--- /dev/null
+++ b/tensorflow/contrib/kinesis/BUILD
@@ -0,0 +1,113 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_py_test",
+)
+
+py_library(
+ name = "kinesis",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = [
+ "kernels/kinesis_dataset_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core/platform/s3:aws_crypto",
+ "//third_party/eigen3",
+ "@aws",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/kinesis_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":kinesis_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "kinesis_op_loader",
+ srcs = ["python/ops/kinesis_op_loader.py"],
+ dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+tf_py_test(
+ name = "kinesis_test",
+ srcs = ["python/kernel_tests/kinesis_test.py"],
+ additional_deps = [
+ ":kinesis",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
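+ # Requires AWS credentials and a live Kinesis stream, so the test is
+ # tagged to run manually only (see the note in kinesis_test.py).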
+ tags = [
+ "manual",
+ "no_windows",
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/kinesis/__init__.py b/tensorflow/contrib/kinesis/__init__.py
new file mode 100644
index 0000000000..3824b8ae75
--- /dev/null
+++ b/tensorflow/contrib/kinesis/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kinesis Dataset.
+
+@@KinesisDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "KinesisDataset",
+]
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
new file mode 100644
index 0000000000..3212279c4c
--- /dev/null
+++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
@@ -0,0 +1,359 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <aws/core/Aws.h>
+#include <aws/core/config/AWSProfileConfigLoader.h>
+#include <aws/core/utils/Outcome.h>
+#include <aws/kinesis/KinesisClient.h>
+#include <aws/kinesis/model/DescribeStreamRequest.h>
+#include <aws/kinesis/model/GetRecordsRequest.h>
+#include <aws/kinesis/model/GetShardIteratorRequest.h>
+#include <aws/kinesis/model/PutRecordsRequest.h>
+#include <aws/kinesis/model/ShardIteratorType.h>
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/platform/s3/aws_crypto.h"
+
+namespace tensorflow {
+namespace {
+
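+// Builds the shared AWS client configuration from environment variables:
+// KINESIS_ENDPOINT and AWS_REGION override the endpoint and region,
+// KINESIS_USE_HTTPS and KINESIS_VERIFY_SSL control the scheme and SSL
+// verification, and KINESIS_CONNECT_TIMEOUT_MSEC/KINESIS_REQUEST_TIMEOUT_MSEC
+// adjust the timeouts. Returns a pointer to a function-local static.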
+Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
+ static Aws::Client::ClientConfiguration config;
+ const char* endpoint = getenv("KINESIS_ENDPOINT");
+ if (endpoint) {
+ config.endpointOverride = Aws::String(endpoint);
+ }
+ const char* region = getenv("AWS_REGION");
+ if (region) {
+ config.region = Aws::String(region);
+ } else {
+ // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
+ // is set with a truthy value.
+ const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
+ string load_config =
+ load_config_env ? str_util::Lowercase(load_config_env) : "";
+ if (load_config == "true" || load_config == "1") {
+ Aws::String config_file;
+ // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
+ const char* config_file_env = getenv("AWS_CONFIG_FILE");
+ if (config_file_env) {
+ config_file = config_file_env;
+ } else {
+ const char* home_env = getenv("HOME");
+ if (home_env) {
+ config_file = home_env;
+ config_file += "/.aws/config";
+ }
+ }
+ Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
+ // Load the configuration. If successful, get the region.
+ // If the load is not successful, then generate a warning.
+ if (loader.Load()) {
+ auto profiles = loader.GetProfiles();
+ if (!profiles["default"].GetRegion().empty()) {
+ config.region = profiles["default"].GetRegion();
+ }
+ } else {
+ LOG(WARNING) << "Failed to load the profile in " << config_file << ".";
+ }
+ }
+ }
+ const char* use_https = getenv("KINESIS_USE_HTTPS");
+ if (use_https) {
+ if (use_https[0] == '0') {
+ config.scheme = Aws::Http::Scheme::HTTP;
+ } else {
+ config.scheme = Aws::Http::Scheme::HTTPS;
+ }
+ }
+ const char* verify_ssl = getenv("KINESIS_VERIFY_SSL");
+ if (verify_ssl) {
+ if (verify_ssl[0] == '0') {
+ config.verifySSL = false;
+ } else {
+ config.verifySSL = true;
+ }
+ }
+ const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC");
+ if (connect_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(connect_timeout, &timeout)) {
+ config.connectTimeoutMs = timeout;
+ }
+ }
+ const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC");
+ if (request_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(request_timeout, &timeout)) {
+ config.requestTimeoutMs = timeout;
+ }
+ }
+
+ return &config;
+}
+
+Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
+ static Aws::Client::ClientConfiguration* config =
+ InitializeDefaultClientConfig();
+ return *config;
+}
+
+static mutex mu(LINKER_INITIALIZED);
+static unsigned count(0);
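+// Reference-counted wrappers around Aws::InitAPI()/Aws::ShutdownAPI() so
+// that concurrent users share a single SDK lifetime. The SHA256 factories
+// are overridden with the OpenSSL-backed implementations from
+// aws_crypto.h.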
+void AwsInitAPI() {
+ mutex_lock lock(mu);
+ count++;
+ if (count == 1) {
+ Aws::SDKOptions options;
+ options.cryptoOptions.sha256Factory_create_fn = []() {
+ return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
+ };
+ options.cryptoOptions.sha256HMACFactory_create_fn = []() {
+ return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
+ };
+ Aws::InitAPI(options);
+ }
+}
+void AwsShutdownAPI() {
+ mutex_lock lock(mu);
+ count--;
+ if (count == 0) {
+ Aws::SDKOptions options;
+ Aws::ShutdownAPI(options);
+ }
+}
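+// Deleter for the KinesisClient smart pointer: destroys the client and
+// releases the SDK reference taken by the matching AwsInitAPI() call.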
+void ShutdownClient(Aws::Kinesis::KinesisClient* client) {
+ if (client != nullptr) {
+ delete client;
+ AwsShutdownAPI();
+ }
+}
+} // namespace
+
+class KinesisDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ std::string stream = "";
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<std::string>(ctx, "stream", &stream));
+ std::string shard = "";
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "shard", &shard));
+ bool read_indefinitely = true;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "read_indefinitely",
+ &read_indefinitely));
+ int64 interval = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "interval", &interval));
+ OP_REQUIRES(ctx, (interval > 0),
+ errors::InvalidArgument(
+ "Interval value should be large than 0, got ", interval));
+ *output = new Dataset(ctx, stream, shard, read_indefinitely, interval);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
+ const bool read_indefinitely, const int64 interval)
+ : GraphDatasetBase(ctx),
+ stream_(stream),
+ shard_(shard),
+ read_indefinitely_(read_indefinitely),
+ interval_(interval) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Kinesis")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "KinesisDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* stream = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream));
+ Node* shard = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard));
+ Node* read_indefinitely = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely));
+ Node* interval = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {stream, shard, read_indefinitely, interval}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ client_(nullptr, ShutdownClient) {}
+
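+ // Fetches one record per call (GetRecords with a limit of 1). When the
+ // shard has no data and read_indefinitely is set, sleeps for `interval`
+ // microseconds and retries; otherwise reports end of sequence.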
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (iterator_ == "") {
+ TF_RETURN_IF_ERROR(SetupStreamsLocked());
+ }
+ do {
+ Aws::Kinesis::Model::GetRecordsRequest request;
+ auto outcome = client_->GetRecords(
+ request.WithShardIterator(iterator_).WithLimit(1));
+ if (!outcome.IsSuccess()) {
+ return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
+ outcome.GetError().GetMessage());
+ }
+ if (outcome.GetResult().GetRecords().size() == 0) {
+ // If no records were returned then nothing is available at the
+ // moment.
+ if (!dataset()->read_indefinitely_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ // Continue the loop after a period of time.
+ ctx->env()->SleepForMicroseconds(dataset()->interval_);
+ continue;
+ }
+ if (outcome.GetResult().GetRecords().size() != 1) {
+ return errors::Unknown("invalid number of records ",
+ outcome.GetResult().GetRecords().size(),
+ " returned");
+ }
+
+ iterator_ = outcome.GetResult().GetNextShardIterator();
+
+ const auto& data = outcome.GetResult().GetRecords()[0].GetData();
+ StringPiece value(
+ reinterpret_cast<const char*>(data.GetUnderlyingData()),
+ data.GetLength());
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<std::string>()() = std::string(value);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+ *end_of_sequence = false;
+ return Status::OK();
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented("SaveInternal is currently not supported");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "RestoreInternal is currently not supported");
+ }
+
+ private:
+ // Initializes the AWS SDK, creates the Kinesis client, resolves the
+ // shard to read from (auto-selected when the stream has exactly one
+ // shard), and obtains the initial shard iterator.
+ Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ AwsInitAPI();
+ client_.reset(
+ new Aws::Kinesis::KinesisClient(GetDefaultClientConfig()));
+
+ Aws::Kinesis::Model::DescribeStreamRequest request;
+ auto outcome = client_->DescribeStream(
+ request.WithStreamName(dataset()->stream_.c_str()));
+ if (!outcome.IsSuccess()) {
+ return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
+ outcome.GetError().GetMessage());
+ }
+ Aws::String shard;
+ Aws::String sequence;
+ if (dataset()->shard_ == "") {
+ if (outcome.GetResult().GetStreamDescription().GetShards().size() !=
+ 1) {
+ return errors::InvalidArgument(
+ "shard has to be provided unless the stream only have one "
+ "shard, there are ",
+ outcome.GetResult().GetStreamDescription().GetShards().size(),
+ " shards in stream ", dataset()->stream_);
+ }
+ shard = outcome.GetResult()
+ .GetStreamDescription()
+ .GetShards()[0]
+ .GetShardId();
+ sequence = outcome.GetResult()
+ .GetStreamDescription()
+ .GetShards()[0]
+ .GetSequenceNumberRange()
+ .GetStartingSequenceNumber();
+ } else {
+ for (const auto& entry :
+ outcome.GetResult().GetStreamDescription().GetShards()) {
+ if (entry.GetShardId() == dataset()->shard_.c_str()) {
+ shard = entry.GetShardId();
+ sequence =
+ entry.GetSequenceNumberRange().GetStartingSequenceNumber();
+ break;
+ }
+ }
+ if (shard == "") {
+ return errors::InvalidArgument("no shard ", dataset()->shard_,
+ " in stream ", dataset()->stream_);
+ }
+ }
+
+ Aws::Kinesis::Model::GetShardIteratorRequest iterator_request;
+ auto iterator_outcome = client_->GetShardIterator(
+ iterator_request.WithStreamName(dataset()->stream_.c_str())
+ .WithShardId(shard)
+ .WithShardIteratorType(
+ Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER)
+ .WithStartingSequenceNumber(sequence));
+ if (!iterator_outcome.IsSuccess()) {
+ return errors::Unknown(iterator_outcome.GetError().GetExceptionName(),
+ ": ",
+ iterator_outcome.GetError().GetMessage());
+ }
+ iterator_ = iterator_outcome.GetResult().GetShardIterator();
+ return Status::OK();
+ }
+
+ mutex mu_;
+ Aws::String iterator_ GUARDED_BY(mu_);
+ std::unique_ptr<Aws::Kinesis::KinesisClient, decltype(&ShutdownClient)>
+ client_ GUARDED_BY(mu_);
+ };
+
+ const std::string stream_;
+ const std::string shard_;
+ const bool read_indefinitely_;
+ const int64 interval_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU),
+ KinesisDatasetOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/kinesis/ops/dataset_ops.cc b/tensorflow/contrib/kinesis/ops/dataset_ops.cc
new file mode 100644
index 0000000000..54204513cf
--- /dev/null
+++ b/tensorflow/contrib/kinesis/ops/dataset_ops.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KinesisDataset")
+ .Input("stream: string")
+ .Input("shard: string")
+ .Input("read_indefinitely: bool")
+ .Input("interval: int64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that emits the records of a Kinesis stream.
+
+stream: A `tf.string` tensor containing the name of the stream.
+shard: A `tf.string` tensor containing the id of the shard.
+read_indefinitely: If `True`, the Kinesis dataset will keep retrying
+ on `EOF` after the `interval` period. If `False`, then the
+ dataset will stop on `EOF`. The default value is `True`.
+interval: The interval, in microseconds, for the Kinesis client to wait
+ before it tries to get records again.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
new file mode 100644
index 0000000000..7289b45c50
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for KinesisDataset.
+
+NOTE: boto3 is needed and the test has to be invoked manually:
+```
+$ bazel test -s --verbose_failures --config=opt \
+ --action_env=AWS_ACCESS_KEY_ID=XXXXXX \
+ --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \
+ //tensorflow/contrib/kinesis:kinesis_test
+```
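+
+The kernel also honors the KINESIS_ENDPOINT, AWS_REGION, KINESIS_USE_HTTPS,
+and KINESIS_VERIFY_SSL environment variables (see kinesis_dataset_ops.cc),
+which can point the test at a local Kinesis-compatible endpoint.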
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import boto3
+
+from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class KinesisDatasetTest(test.TestCase):
+
+ def testKinesisDatasetOneShard(self):
+ client = boto3.client('kinesis', region_name='us-east-1')
+
+ # Set up a Kinesis stream with 1 shard.
+ stream_name = "tf_kinesis_test_1"
+ client.create_stream(StreamName=stream_name, ShardCount=1)
+ # Wait until stream exists, default is 10 * 18 seconds.
+ client.get_waiter('stream_exists').wait(StreamName=stream_name)
+ for i in range(10):
+ data = "D" + str(i)
+ client.put_record(
+ StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
+
+ stream = array_ops.placeholder(dtypes.string, shape=[])
+ num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = kinesis_dataset_ops.KinesisDataset(
+ stream, read_indefinitely=False).repeat(num_epochs)
+ batch_dataset = repeat_dataset.batch(batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_op = iterator.make_initializer(repeat_dataset)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ # Basic test: read from shard 0 of stream 1.
+ sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
+ for i in range(10):
+ self.assertEqual("D" + str(i), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ client.delete_stream(StreamName=stream_name)
+ # Wait until stream deleted, default is 10 * 18 seconds.
+ client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
+
+ def testKinesisDatasetTwoShards(self):
+ client = boto3.client('kinesis', region_name='us-east-1')
+
+ # Set up a Kinesis stream with 2 shards.
+ stream_name = "tf_kinesis_test_2"
+ client.create_stream(StreamName=stream_name, ShardCount=2)
+ # Wait until stream exists, default is 10 * 18 seconds.
+ client.get_waiter('stream_exists').wait(StreamName=stream_name)
+
+ for i in range(10):
+ data = "D" + str(i)
+ client.put_record(
+ StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
+ response = client.describe_stream(StreamName=stream_name)
+ shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"]
+ shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"]
+
+ stream = array_ops.placeholder(dtypes.string, shape=[])
+ shard = array_ops.placeholder(dtypes.string, shape=[])
+ num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = kinesis_dataset_ops.KinesisDataset(
+ stream, shard, read_indefinitely=False).repeat(num_epochs)
+ batch_dataset = repeat_dataset.batch(batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_op = iterator.make_initializer(repeat_dataset)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ data = list()
+ with self.test_session() as sess:
+ # Basic test: read from shard 0 of stream 2.
+ sess.run(
+ init_op, feed_dict={
+ stream: stream_name, shard: shard_id_0, num_epochs: 1})
+ with self.assertRaises(errors.OutOfRangeError):
+ # Use range(11) to guarantee the OutOfRangeError.
+ for i in range(11):
+ data.append(sess.run(get_next))
+
+ # Basic test: read from shard 1 of stream 2.
+ sess.run(
+ init_op, feed_dict={
+ stream: stream_name, shard: shard_id_1, num_epochs: 1})
+ with self.assertRaises(errors.OutOfRangeError):
+ # Use range(11) to guarantee the OutOfRangeError.
+ for i in range(11):
+ data.append(sess.run(get_next))
+
+ data.sort()
+ self.assertEqual(data, ["D" + str(i) for i in range(10)])
+
+ client.delete_stream(StreamName=stream_name)
+ # Wait until stream deleted, default is 10 * 18 seconds.
+ client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
new file mode 100644
index 0000000000..ca2df95ba4
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kinesis Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class KinesisDataset(Dataset):
+ """A Kinesis Dataset that consumes the message.
+
+ Kinesis is a managed service provided by AWS for data streaming.
+ This dataset reads messages from Kinesis with each message presented
+ as a `tf.string`.
+
+ For example, we can construct and use the KinesisDataset as follows:
+ ```python
+ dataset = tf.contrib.kinesis.KinesisDataset(
+ "kinesis_stream_name", read_indefinitely=False)
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+
+ Since Kinesis is a data streaming service, data may not be available
+ at the time it is being read. The argument `read_indefinitely` is
+ used to control the behavior in this situation. If `read_indefinitely`
+ is `True`, then `KinesisDataset` will keep retrying to retrieve data
+ from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError`
+ is returned immediately instead.
+ """
+
+ def __init__(self,
+ stream,
+ shard="",
+ read_indefinitely=True,
+ interval=100000):
+ """Create a KinesisDataset.
+
+ Args:
+ stream: A `tf.string` tensor containing the name of the stream.
+ shard: A `tf.string` tensor containing the id of the shard.
+ read_indefinitely: If `True`, the Kinesis dataset will keep retrying
+ on `EOF` after the `interval` period. If `False`, then the
+ dataset will stop on `EOF`. The default value is `True`.
+ interval: The interval, in microseconds, for the Kinesis client to
+ wait before it tries to get records again.
+ """
+ super(KinesisDataset, self).__init__()
+ self._stream = ops.convert_to_tensor(
+ stream, dtype=dtypes.string, name="stream")
+ self._shard = ops.convert_to_tensor(
+ shard, dtype=dtypes.string, name="shard")
+ self._read_indefinitely = ops.convert_to_tensor(
+ read_indefinitely, dtype=dtypes.bool, name="read_indefinitely")
+ self._interval = ops.convert_to_tensor(
+ interval, dtype=dtypes.int64, name="interval")
+
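+ # Forwards the scalar configuration tensors to the `KinesisDataset` op
+ # registered in ops/dataset_ops.cc.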
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.kinesis_dataset(
+ self._stream, self._shard, self._read_indefinitely, self._interval)
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.string
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py
new file mode 100644
index 0000000000..c9ce9f3646
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python helper for loading kinesis ops and kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
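+# Loads the custom op shared object (//tensorflow/contrib/kinesis:_dataset_ops.so);
+# the relative path is resolved against this module's location in the package.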
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 66ccd81e41..28891320c4 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -620,10 +620,10 @@ def tf_additional_core_deps():
],
"//conditions:default": [],
}) + select({
- "//tensorflow:with_s3_support_windows_override": [],
- "//tensorflow:with_s3_support_android_override": [],
- "//tensorflow:with_s3_support_ios_override": [],
- "//tensorflow:with_s3_support": [
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support_android_override": [],
+ "//tensorflow:with_aws_support_ios_override": [],
+ "//tensorflow:with_aws_support": [
"//tensorflow/core/platform/s3:s3_file_system",
],
"//conditions:default": [],
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index 21038cfeb1..41184b6fd9 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -16,10 +16,10 @@ load(
tf_cc_binary(
name = "s3_file_system.so",
srcs = [
+ "aws_crypto.cc",
+ "aws_crypto.h",
"aws_logging.cc",
"aws_logging.h",
- "s3_crypto.cc",
- "s3_crypto.h",
"s3_file_system.cc",
"s3_file_system.h",
],
@@ -40,16 +40,14 @@ tf_cc_binary(
)
cc_library(
- name = "s3_crypto",
+ name = "aws_crypto",
srcs = [
- "s3_crypto.cc",
+ "aws_crypto.cc",
],
hdrs = [
- "s3_crypto.h",
+ "aws_crypto.h",
],
deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
"@aws",
"@boringssl//:crypto",
],
@@ -81,8 +79,8 @@ cc_library(
"s3_file_system.h",
],
deps = [
+ ":aws_crypto",
":aws_logging",
- ":s3_crypto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@aws",
diff --git a/tensorflow/core/platform/s3/s3_crypto.cc b/tensorflow/core/platform/s3/aws_crypto.cc
index d7062a59d2..90e46d6c1d 100644
--- a/tensorflow/core/platform/s3/s3_crypto.cc
+++ b/tensorflow/core/platform/s3/aws_crypto.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/platform/s3/s3_crypto.h"
+#include "tensorflow/core/platform/s3/aws_crypto.h"
#include <openssl/hmac.h>
#include <openssl/sha.h>
@@ -21,11 +21,11 @@ limitations under the License.
namespace tensorflow {
-class S3Sha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
+class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
public:
- S3Sha256HMACOpenSSLImpl() {}
+ AWSSha256HMACOpenSSLImpl() {}
- virtual ~S3Sha256HMACOpenSSLImpl() = default;
+ virtual ~AWSSha256HMACOpenSSLImpl() = default;
virtual Aws::Utils::Crypto::HashResult Calculate(
const Aws::Utils::ByteBuffer& toSign,
@@ -47,11 +47,11 @@ class S3Sha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
}
};
-class S3Sha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
+class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
public:
- S3Sha256OpenSSLImpl() {}
+ AWSSha256OpenSSLImpl() {}
- virtual ~S3Sha256OpenSSLImpl() = default;
+ virtual ~AWSSha256OpenSSLImpl() = default;
virtual Aws::Utils::Crypto::HashResult Calculate(
const Aws::String& str) override {
@@ -101,13 +101,13 @@ class S3Sha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
};
std::shared_ptr<Aws::Utils::Crypto::Hash>
-S3SHA256Factory::CreateImplementation() const {
- return Aws::MakeShared<S3Sha256OpenSSLImpl>(S3CryptoAllocationTag);
+AWSSHA256Factory::CreateImplementation() const {
+ return Aws::MakeShared<AWSSha256OpenSSLImpl>(AWSCryptoAllocationTag);
}
std::shared_ptr<Aws::Utils::Crypto::HMAC>
-S3SHA256HmacFactory::CreateImplementation() const {
- return Aws::MakeShared<S3Sha256HMACOpenSSLImpl>(S3CryptoAllocationTag);
+AWSSHA256HmacFactory::CreateImplementation() const {
+ return Aws::MakeShared<AWSSha256HMACOpenSSLImpl>(AWSCryptoAllocationTag);
}
} // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/s3_crypto.h b/tensorflow/core/platform/s3/aws_crypto.h
index e376b8b0c0..f05771b904 100644
--- a/tensorflow/core/platform/s3/s3_crypto.h
+++ b/tensorflow/core/platform/s3/aws_crypto.h
@@ -18,15 +18,15 @@ limitations under the License.
#include <aws/core/utils/crypto/Hash.h>
namespace tensorflow {
-static const char* S3CryptoAllocationTag = "S3CryptoAllocation";
+static const char* AWSCryptoAllocationTag = "AWSCryptoAllocation";
-class S3SHA256Factory : public Aws::Utils::Crypto::HashFactory {
+class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
const override;
};
-class S3SHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
+class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
const override;
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 6da679dc75..bdc8f808df 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/s3/aws_crypto.h"
#include "tensorflow/core/platform/s3/aws_logging.h"
-#include "tensorflow/core/platform/s3/s3_crypto.h"
#include <aws/core/Aws.h>
#include <aws/core/config/AWSProfileConfigLoader.h>
@@ -300,10 +300,10 @@ std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
Aws::SDKOptions options;
options.cryptoOptions.sha256Factory_create_fn = []() {
- return Aws::MakeShared<S3SHA256Factory>(S3CryptoAllocationTag);
+ return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
};
options.cryptoOptions.sha256HMACFactory_create_fn = []() {
- return Aws::MakeShared<S3SHA256HmacFactory>(S3CryptoAllocationTag);
+ return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
};
Aws::InitAPI(options);
diff --git a/third_party/aws.BUILD b/third_party/aws.BUILD
index 2dc921933c..5426f79e46 100644
--- a/third_party/aws.BUILD
+++ b/third_party/aws.BUILD
@@ -46,6 +46,8 @@ cc_library(
"aws-cpp-sdk-core/source/utils/xml/**/*.cpp",
"aws-cpp-sdk-core/source/utils/crypto/*.cpp",
"aws-cpp-sdk-core/source/utils/crypto/factory/**/*.cpp",
+ "aws-cpp-sdk-kinesis/include/**/*.h",
+ "aws-cpp-sdk-kinesis/source/**/*.cpp",
"aws-cpp-sdk-s3/include/**/*.h",
"aws-cpp-sdk-s3/source/**/*.cpp",
]),
@@ -72,6 +74,7 @@ cc_library(
}),
includes = [
"aws-cpp-sdk-core/include/",
+ "aws-cpp-sdk-kinesis/include/",
"aws-cpp-sdk-s3/include/",
],
deps = [