author     Derek Murray <mrry@google.com>                    2018-09-28 08:38:53 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-28 08:46:34 -0700
commit     c7bb3c3d65e4e064d53630d4b524522eed6f3f44 (patch)
tree       1fd0b73ab916093c80dcd289154035bba5fb393d /tensorflow/core/kernels
parent     e06783e7bb80f664c7ec9be90680ac6ddcbd598f (diff)
[tf.data] Move `tf.contrib.data` C++ code to a core "experimental" directory.
NOTE: All ops and kernels previously defined in tensorflow/contrib/data have had their names prefixed with "Experimental" to indicate that they are not (yet) stable, and thus not subject to backwards or forwards compatibility guarantees.
PiperOrigin-RevId: 214940819
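For illustration, the rename only affects the registered op name; the kernel classes themselves are unchanged. A minimal before/after sketch (the former contrib registration name is an assumption for illustration, not taken from this diff):

// Assumed former registration in tensorflow/contrib/data/kernels:
// REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
//                         AssertNextDatasetOp);
// Registration after this change (see assert_next_dataset_op.cc below):
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
    AssertNextDatasetOp);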
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/data/BUILD                                          |   1
-rw-r--r--  tensorflow/core/kernels/data/experimental/BUILD                             | 139
-rw-r--r--  tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc         | 156
-rw-r--r--  tensorflow/core/kernels/data/experimental/csv_dataset_op.cc                 | 860
-rw-r--r--  tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc | 281
-rw-r--r--  tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc       | 156
-rw-r--r--  tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc       | 141
-rw-r--r--  tensorflow/core/kernels/data/experimental/indexed_dataset.cc                | 375
-rw-r--r--  tensorflow/core/kernels/data/experimental/indexed_dataset.h                 | 119
-rw-r--r--  tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc                | 218
-rw-r--r--  tensorflow/core/kernels/data/experimental/prefetching_kernels.cc            | 482
-rw-r--r--  tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc          | 220
-rw-r--r--  tensorflow/core/kernels/data/experimental/unique_dataset_op.cc              | 224
13 files changed, 3372 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 87efdff789..6333853cdf 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -765,6 +765,7 @@ tf_kernel_library(
":window_dataset_op",
":writer_ops",
":zip_dataset_op",
+ "//tensorflow/core/kernels/data/experimental:dataset_kernels",
],
)
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
new file mode 100644
index 0000000000..43406db3ed
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -0,0 +1,139 @@
+# Description:
+# Contains experimental kernels for datasets and iterators.
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+)
+
+cc_library(
+ name = "indexed_dataset_headers",
+ hdrs = ["indexed_dataset.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "indexed_dataset",
+ srcs = [
+ "identity_indexed_dataset.cc",
+ "indexed_dataset.cc",
+ ],
+ deps = [
+ ":indexed_dataset_headers",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "prefetching_kernels",
+ srcs = ["prefetching_kernels.cc"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
+ name = "directed_interleave_dataset_op",
+ srcs = ["directed_interleave_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "csv_dataset_op",
+ srcs = ["csv_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
+ name = "ignore_errors_dataset_op",
+ srcs = ["ignore_errors_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "lmdb_dataset_op",
+ srcs = ["lmdb_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ "@lmdb",
+ ],
+)
+
+tf_kernel_library(
+ name = "threadpool_dataset_op",
+ srcs = ["threadpool_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "unique_dataset_op",
+ srcs = ["unique_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "assert_next_dataset_op",
+ srcs = ["assert_next_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "dataset_kernels",
+ deps = [
+ ":assert_next_dataset_op",
+ ":csv_dataset_op",
+ ":directed_interleave_dataset_op",
+ ":ignore_errors_dataset_op",
+ ":indexed_dataset",
+ ":lmdb_dataset_op",
+ ":prefetching_kernels",
+ ":threadpool_dataset_op",
+ ":unique_dataset_op",
+ ],
+)
diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
new file mode 100644
index 0000000000..3511cca0f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -0,0 +1,156 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class AssertNextDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::vector<string> transformations;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
+ &transformations));
+ *output =
+ new Dataset(ctx, input, transformations, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const std::vector<string>& transformations,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ transformations_(transformations),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Assert")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "AssertNextDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* transformations_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, transformations_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ std::vector<string> tokens =
+ str_util::Split(prefix(), ':', str_util::SkipEmpty());
+ if (dataset()->transformations_.size() > tokens.size() - 2) {
+ return errors::InvalidArgument(
+ "Asserted next ", dataset()->transformations_.size(),
+ " transformations but encountered only ", tokens.size() - 2, ".");
+ }
+ int n = tokens.size();
+ for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
+ if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
+ return errors::InvalidArgument(
+ "Asserted ", dataset()->transformations_[i],
+ " transformation at offset ", i, " but encountered ",
+ tokens[n - 2 - i], " transformation instead.");
+ }
+ }
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* input_;
+ const std::vector<string> transformations_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
new file mode 100644
index 0000000000..7451ca4cb1
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -0,0 +1,860 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/parsing_ops.cc.
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class CSVDatasetOp : public DatasetOpKernel {
+ public:
+ explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ string compression_type;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
+ &compression_type));
+
+ OpInputList record_defaults_list;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("record_defaults", &record_defaults_list));
+ for (int i = 0; i < record_defaults_list.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
+ OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
+ errors::InvalidArgument(
+ "There should only be 1 default per field but field ", i,
+ " has ", record_defaults_list[i].NumElements()));
+ }
+
+ const Tensor* select_cols_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor));
+ OP_REQUIRES(ctx, select_cols_tensor->dims() == 1,
+ errors::InvalidArgument("`select_cols` must be a vector."));
+
+ int64 buffer_size;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
+ OP_REQUIRES(ctx, buffer_size > 0,
+ errors::InvalidArgument("buffer_size should be positive"));
+
+ string delim;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "field_delim", &delim));
+ OP_REQUIRES(ctx, delim.size() == 1,
+ errors::InvalidArgument("field_delim should be only 1 char"));
+
+ bool header;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header));
+
+ bool use_quote_delim;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim",
+ &use_quote_delim));
+ string na_value;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "na_value", &na_value));
+
+ std::vector<Tensor> record_defaults;
+ record_defaults.reserve(record_defaults_list.size());
+ for (const Tensor& t : record_defaults_list) {
+ record_defaults.push_back(t);
+ }
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ io::ZlibCompressionOptions zlib_compression_options =
+ io::ZlibCompressionOptions::DEFAULT();
+ if (compression_type == "ZLIB") {
+ zlib_compression_options = io::ZlibCompressionOptions::DEFAULT();
+ } else if (compression_type == "GZIP") {
+ zlib_compression_options = io::ZlibCompressionOptions::GZIP();
+ } else {
+ OP_REQUIRES(ctx, compression_type.empty(),
+ errors::InvalidArgument(
+ "Unsupported compression_type: ", compression_type, "."));
+ }
+ zlib_compression_options.input_buffer_size = buffer_size;
+
+ std::vector<int64> select_cols;
+ select_cols.reserve(select_cols_tensor->NumElements());
+ for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
+ select_cols.push_back(select_cols_tensor->flat<int64>()(i));
+ }
+ OP_REQUIRES(
+ ctx, output_types_.size() == select_cols.size() || select_cols.empty(),
+ errors::InvalidArgument("select_cols should match output size"));
+ for (int i = 1; i < select_cols.size(); i++) {
+ OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i],
+ errors::InvalidArgument(
+ "select_cols should be strictly increasing indices"));
+ }
+ OP_REQUIRES(
+ ctx, select_cols.empty() || select_cols.front() >= 0,
+ errors::InvalidArgument("select_cols should be non-negative indices"));
+
+ *output = new Dataset(ctx, std::move(filenames), header,
+ std::move(compression_type), zlib_compression_options,
+ output_types_, output_shapes_,
+ std::move(record_defaults), std::move(select_cols),
+ use_quote_delim, delim[0], std::move(na_value));
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
+ string compression_type, io::ZlibCompressionOptions options,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
+ bool use_quote_delim, char delim, string na_value)
+ : DatasetBase(DatasetContext(ctx)),
+ filenames_(std::move(filenames)),
+ header_(header),
+ out_type_(output_types),
+ output_shapes_(output_shapes),
+ record_defaults_(std::move(record_defaults)),
+ select_cols_(std::move(select_cols)),
+ use_quote_delim_(use_quote_delim),
+ delim_(delim),
+ na_value_(std::move(na_value)),
+ use_compression_(!compression_type.empty()),
+ compression_type_(std::move(compression_type)),
+ options_(options) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::CSV")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override { return out_type_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override { return "CSVDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ Node* compression_type = nullptr;
+ Node* buffer_size = nullptr;
+ Node* header = nullptr;
+ Node* delim = nullptr;
+ Node* use_quote_delim = nullptr;
+ Node* na_value = nullptr;
+ Node* select_cols = nullptr;
+
+ std::vector<Node*> record_defaults;
+ record_defaults.reserve(record_defaults_.size());
+ for (const Tensor& t : record_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ record_defaults.emplace_back(node);
+ }
+
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(options_.input_buffer_size, &buffer_size));
+ TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+ string delim_string(1, delim_);
+ TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+ TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+ std::make_pair(2, buffer_size), std::make_pair(3, header),
+ std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+ std::make_pair(6, na_value),
+ std::make_pair(7, select_cols)}, // Single tensor inputs
+ {std::make_pair(8, record_defaults)}, // Tensor list inputs
+ {}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ bool select_all = dataset()->select_cols_.empty();
+ do {
+ // We are currently processing a file, so try to read the next record
+ if (input_stream_) {
+ Status s = ReadRecord(ctx, out_tensors, select_all,
+ dataset()->select_cols_);
+ if (s.ok()) {
+ // Validate output
+ if (out_tensors->size() != dataset()->out_type_.size()) {
+ return errors::InvalidArgument(
+ "Expect ", dataset()->out_type_.size(), " fields but have ",
+ out_tensors->size(), " in record");
+ }
+
+ *end_of_sequence = false;
+ return s;
+ }
+ if (!errors::IsOutOfRange(s)) {
+ // Not at the end of file, return OK or non-EOF errors to caller.
+ *end_of_sequence = false;
+ return s;
+ }
+ // We have reached the end of the current file, so maybe
+ // move on to next file.
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ // Iteration ends when there are no more files to process.
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+ current_file_index_));
+ // `input_stream_` is empty if
+ // 1. GetNext has not been called even once.
+ // 2. All files have been read and the iterator has been exhausted.
+ if (input_stream_ && num_buffer_reads_ > 0) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+ // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+ num_buffer_reads_));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ ResetStreamsLocked();
+ int64 current_file_index;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+ &current_file_index));
+ current_file_index_ = size_t(current_file_index);
+ // The keys "pos" and "num_buffer_reads" are written only if
+ // the iterator was saved with an open, partially read file.
+ if (reader->Contains(full_name("pos"))) {
+ int64 pos, num_buffer_reads;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+ &num_buffer_reads));
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+ num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+ // Restores the most recently held buffer
+ Status s = input_stream_->SkipNBytes(
+ num_buffer_reads_ * dataset()->options_.input_buffer_size);
+ if (!s.ok() && !errors::IsOutOfRange(s)) {
+ // We might get out of range error here if the size of the file
+ // is not an exact multiple of the buffer size, and the last buffer
+ // read is < buffer_size. This is valid and we do not surface the
+ // error.
+ return s;
+ }
+
+ Status s2 = FillBuffer(&buffer_);
+ if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+ return s2;
+ }
+ pos_ = size_t(pos);
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Reads an entire CSV row from the input stream, either from the
+ // existing buffer or by filling the buffer as needed. Converts extracted
+ // fields to output tensors as we go.
+ //
+ // When this function is called, pos_ should be the index of the first
+ // character of the record in buffer_, or past the end of the buffer.
+ // Note: ctx and out_tensors are only used in this function
+ // when fields are included in the record.
+ Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool select_all, const std::vector<int64>& selected)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ // At the end of the file, this will return errors::OutOfRange
+ TF_RETURN_IF_ERROR(FillBuffer(&buffer_));
+ pos_ = 0;
+ }
+
+ // The first character may be \n if this is the continuation of a
+ // \r\n linebreak between this and the previous record. If so, skip it.
+
+ bool end_of_record = false; // Keep track of when we find \n, \r or EOF
+ size_t num_parsed = 0;
+ size_t num_selected_parsed = 0;
+
+ Status result;
+
+ while (!end_of_record) { // Read till we reach \n, \r or EOF
+ bool include =
+ select_all || (num_selected_parsed < selected.size() &&
+ selected[num_selected_parsed] == num_parsed);
+
+ // Don't fail fast, so that the next call to GetNext may still return
+ // a valid record
+ result.Update(
+ ParseOneField(ctx, out_tensors, &end_of_record, include));
+
+ num_parsed++;
+ if (include) num_selected_parsed++;
+ }
+
+ return result;
+ }
+
+ // Parses one field from position pos_ in the buffer. Fields are
+ // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of
+ // the next field.
+ Status ParseOneField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ // If we get here, this means the previous field's end coincided
+ // with the end of the buffer, so it is safe to refill the buffer here.
+ Status s = FillBuffer(&buffer_);
+
+ if (errors::IsOutOfRange(s)) {
+ // Reached EOF, and last field is empty
+ *end_of_record = true;
+ if (include) {
+ return FieldToOutput(ctx, StringPiece(), out_tensors);
+ } else {
+ return Status::OK();
+ }
+ } else if (!s.ok()) {
+ return s; // Surface other errors back to caller
+ }
+
+ pos_ = 0;
+ }
+
+ if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') {
+ return ParseQuotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ return ParseUnquotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ // For keeping track of relevant parts of a field from a previous buffer
+ struct Piece {
+ size_t start;
+ size_t len;
+ string buffer;
+
+ Piece(string buffer, size_t start, size_t len)
+ : start(start), len(len), buffer(std::move(buffer)) {}
+ };
+
+ // Given that pos_ exceeds the buffer, saves the relevant part of the
+ // current buffer (if necessary), fills the buffer, and resets indices to
+ // 0.
+ Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
+ size_t* start, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ string temp_buffer;
+
+ buffer_.swap(temp_buffer);
+ if (include && pos_ > *start) {
+ earlier_pieces->push_back(
+ Piece(std::move(temp_buffer), *start, pos_ - *start));
+ }
+ pos_ = 0;
+ *start = 0;
+ return FillBuffer(&buffer_);
+ }
+
+ // Parses a quoted field from position pos_ in the buffer. Continually
+ // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+ // Advances pos_ to keep track of our position in the buffer as we go,
+ // stopping at the first character of the next field.
+ Status ParseQuotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ pos_++; // Starting quotation mark
+
+ Status parse_result;
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ return errors::InvalidArgument(
+ "Reached end of file without closing quoted field in "
+ "record");
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
+ }
+ }
+
+ char ch = buffer_[pos_];
+ if (ch == '"') {
+ // When we encounter a quote, we look ahead to the next character to
+ // decide what to do
+ pos_++;
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ // This was the last field. We are done
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(), out_tensors, earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s;
+ }
+ }
+
+ char next = buffer_[pos_];
+ pos_++;
+ if (next == dataset()->delim_) {
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ return parse_result;
+
+ } else if (next == '\n' || next == '\r') {
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ if (next == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ } else if (next != '"') {
+ // Take note of the error, but keep going to end of field.
+ include = false; // So we don't get funky errors when trying to
+ // unescape the quotes.
+ parse_result.Update(errors::InvalidArgument(
+ "Quote inside a string has to be escaped by another quote"));
+ }
+
+ } else {
+ pos_++;
+ }
+ }
+ }
+
+ // Converts quoted field to an output tensor, removing the starting
+ // and ending quotes from it and unescaping double quotations if
+ // necessary.
+ Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ if (field.find('\"', 1) == field.size() - 1) {
+ // `field` contains no escaped quotation marks.
+ // Exclude framing quotation marks
+ field.remove_prefix(1);
+ field.remove_suffix(1);
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+ }
+ string field_complete;
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ field_complete.reserve(str_len);
+
+ // This bool flips every time we see a quote, so that we skip the second
+ // quote of every pair of adjacent quotes in the field. We need to track
+ // this across iterations of the for loop because adjacent double quotes
+ // may be in different buffers. Initialize to true because we also skip
+ // the opening quotation mark of the quoted field.
+ bool skip_next_quote = true;
+ for (const Piece& p : earlier_pieces) {
+ AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len),
+ &field_complete, &skip_next_quote);
+ }
+ AppendUnescapedPiece(field, &field_complete, &skip_next_quote);
+ StringPiece result = StringPiece(field_complete);
+ result.remove_suffix(1); // Skip final quote
+
+ return FieldToOutput(ctx, result, out_tensors);
+ }
+
+ void AppendUnescapedPiece(StringPiece piece, string* field_complete,
+ bool* skip_next_quote) {
+ size_t from = 0;
+ size_t found = piece.find('\"', from);
+ while (found != string::npos) {
+ if (!*skip_next_quote) {
+ // This is the first quote in a pair of adjacent double quotes
+ field_complete->append(piece.data() + from, found + 1 - from);
+ }
+ *skip_next_quote = !*skip_next_quote;
+ from = found + 1;
+ found = piece.find('\"', from);
+ }
+ // Include the chunk after the last quotation mark in the string
+ if (from < piece.size()) {
+ field_complete->append(piece.data() + from, piece.size() - from);
+ }
+ }
+
+ // Parses unquoted field from position pos_ in the buffer. Continually
+ // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+ // Advances pos_ to keep track of our position in the buffer as we go,
+ // stopping at the first character of the next field.
+ Status ParseUnquotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ Status parse_result;
+
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ // Handle errors
+ if (errors::IsOutOfRange(s)) {
+ // Whatever we have is the last field of the last record
+ *end_of_record = true;
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
+ }
+ }
+
+ char ch = buffer_[pos_];
+
+ if (ch == dataset()->delim_) {
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ pos_++;
+ return parse_result;
+ }
+ if (ch == '\n' || ch == '\r') {
+ // need special case to skip over first \n of record if the line
+ // breaks are \r\n
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ *end_of_record = true;
+ pos_++;
+ if (ch == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ }
+ if (dataset()->use_quote_delim_ && ch == '"') {
+ // Take note of the error, but keep going to end of field.
+ parse_result.Update(errors::InvalidArgument(
+ "Unquoted fields cannot have quotes inside"));
+ }
+ // Otherwise, go to next character
+ pos_++;
+ }
+ }
+
+ Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ result->clear();
+ ++num_buffer_reads_;
+ Status s = input_stream_->ReadNBytes(
+ dataset()->options_.input_buffer_size, result);
+
+ if (errors::IsOutOfRange(s) && !result->empty()) {
+ // Ignore OutOfRange error when ReadNBytes read < N bytes.
+ return Status::OK();
+ }
+ return s;
+ }
+
+ // Given a field, converts it to the right output tensor type
+ Status FieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors) {
+ size_t output_idx = out_tensors->size();
+ if (output_idx >= dataset()->out_type_.size()) {
+ // We can get here if we're selecting all columns, but the number of
+ // fields exceeds the number of defaults provided
+ return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+ " fields but have more in record");
+ }
+ const DataType& dtype = dataset()->out_type_[output_idx];
+ Tensor component(ctx->allocator({}), dtype, {});
+ if ((field.empty() || field == dataset()->na_value_) &&
+ dataset()->record_defaults_[output_idx].NumElements() != 1) {
+ // If the field is empty or NA value, and default is not given,
+ // report error.
+ return errors::InvalidArgument("Field ", output_idx,
+ " is required but missing in record!");
+ }
+
+ switch (dtype) {
+ // For each case, if the field is empty, we use the default.
+ // Otherwise, we convert it to the right type.
+ case DT_INT32: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int32>()() =
+ dataset()->record_defaults_[output_idx].flat<int32>()(0);
+ } else {
+ int32 value;
+ if (!strings::safe_strto32(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int32: ", field);
+ }
+ component.scalar<int32>()() = value;
+ }
+ break;
+ }
+ case DT_INT64: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int64>()() =
+ dataset()->record_defaults_[output_idx].flat<int64>()(0);
+ } else {
+ int64 value;
+ if (!strings::safe_strto64(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int64: ", field);
+ }
+ component.scalar<int64>()() = value;
+ }
+ break;
+ }
+ case DT_FLOAT: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<float>()() =
+ dataset()->record_defaults_[output_idx].flat<float>()(0);
+ } else {
+ float value;
+ if (!strings::safe_strtof(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid float: ", field);
+ }
+ component.scalar<float>()() = value;
+ }
+ break;
+ }
+ case DT_DOUBLE: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<double>()() =
+ dataset()->record_defaults_[output_idx].flat<double>()(0);
+ } else {
+ double value;
+ if (!strings::safe_strtod(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid double: ", field);
+ }
+ component.scalar<double>()() = value;
+ }
+ break;
+ }
+ case DT_STRING: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<string>()() =
+ dataset()->record_defaults_[output_idx].flat<string>()(0);
+ } else {
+ component.scalar<string>()() = string(field);
+ }
+ break;
+ }
+ default:
+ return errors::InvalidArgument("csv: data type ", dtype,
+ " not supported in field ",
+ output_idx);
+ }
+ out_tensors->push_back(std::move(component));
+ return Status::OK();
+ }
+
+ // Records can be delimited by "\r\n" line breaks. When we encounter a
+ // '\r', we have to check the next character to see if it is part of the
+ // linebreak, and ignore it if so.
+ void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ Status s = FillBuffer(&buffer_);
+ pos_ = 0;
+ // If we failed to fill buffer, it doesn't matter because we're done
+ // with the record
+ if (!s.ok()) return;
+ }
+ if (buffer_[pos_] == '\n') {
+ pos_++;
+ }
+ }
+
+ // Given an unquoted field (possibly spanning earlier buffers), assembles
+ // the complete field and converts it to a Tensor of the right type,
+ // adding it to the out_tensors vector via FieldToOutput.
+ Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ string field_complete;
+ field_complete.reserve(str_len);
+
+ for (const Piece& p : earlier_pieces) {
+ field_complete.append(p.buffer, p.start, p.len);
+ }
+
+ field_complete.append(field.data(), field.size());
+ return FieldToOutput(ctx, field_complete, out_tensors);
+ }
+
+ // Sets up reader streams to read from the file at `current_file_index_`.
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+
+ // Actually move on to next file.
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
+ dataset()->filenames_[current_file_index_], &file_));
+ random_access_input_stream_ =
+ std::make_shared<io::RandomAccessInputStream>(file_.get(), false);
+
+ if (dataset()->use_compression_) {
+ input_stream_ = std::make_shared<io::ZlibInputStream>(
+ random_access_input_stream_.get(),
+ dataset()->options_.input_buffer_size,
+ dataset()->options_.input_buffer_size, dataset()->options_);
+ } else {
+ input_stream_ = random_access_input_stream_;
+ }
+ buffer_.clear();
+ pos_ = 0;
+ num_buffer_reads_ = 0;
+ if (dataset()->header_) {
+ // Read one line, but don't include it. Pass nullptrs as dummy
+ // pointers to objects that shouldn't be invoked anyway
+ // We need to process this as a record here instead of just finding
+ // the first newline because it might contain quoted fields with
+ // newlines in the header as well
+ std::vector<int64> empty;
+ Status s = ReadRecord(nullptr, nullptr, false, empty);
+ if (!s.ok()) {
+ return errors::InvalidArgument("Can't read header of file");
+ }
+ }
+ return Status::OK();
+ }
+
+ // Resets all reader streams.
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ input_stream_.reset();
+ file_.reset();
+ }
+
+ mutex mu_;
+ string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
+ size_t pos_ GUARDED_BY(
+ mu_); // Index into the buffer must be maintained between iters
+ size_t num_buffer_reads_ GUARDED_BY(mu_);
+ std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
+ GUARDED_BY(mu_);
+ std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<RandomAccessFile> file_
+ GUARDED_BY(mu_); // must outlive input_stream_
+ }; // class Iterator
+
+ const std::vector<string> filenames_;
+ const bool header_;
+ const DataTypeVector out_type_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<Tensor> record_defaults_;
+ const std::vector<int64> select_cols_;
+ const bool use_quote_delim_;
+ const char delim_;
+ const string na_value_;
+ const bool use_compression_;
+ const string compression_type_;
+ const io::ZlibCompressionOptions options_;
+ }; // class Dataset
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+}; // class CSVDatasetOp
+
+// Register the kernel implementation for CSVDataset.
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+ CSVDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
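The quoted-field handling above (ParseQuotedField, QuotedFieldToOutput, AppendUnescapedPiece) implements the usual CSV rule that a doubled quote inside a quoted field stands for a literal quote, with the framing quotes dropped. A standalone sketch of that unescaping rule, written independently of the TensorFlow classes in this file (the function name and single-string interface are assumptions for illustration only):

#include <iostream>
#include <string>

// Unescapes a quoted CSV field. `field` includes the framing quotes,
// e.g. the raw CSV text "say ""hi""" (with its surrounding quotes).
std::string UnescapeQuotedField(const std::string& field) {
  std::string out;
  bool skip_next_quote = true;  // Skip the opening framing quote.
  for (size_t i = 0; i + 1 < field.size(); ++i) {  // Drop the closing quote.
    if (field[i] == '"') {
      // Keep only the first quote of each adjacent pair.
      if (!skip_next_quote) out.push_back('"');
      skip_next_quote = !skip_next_quote;
    } else {
      out.push_back(field[i]);
    }
  }
  return out;
}

int main() {
  // Prints: say "hi"
  std::cout << UnescapeQuotedField("\"say \"\"hi\"\"\"") << std::endl;
}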
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
new file mode 100644
index 0000000000..c47a9099c4
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class DirectedInterleaveDatasetOp : public DatasetOpKernel {
+ public:
+ explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ DatasetBase* selector_input;
+ OP_REQUIRES_OK(ctx,
+ GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
+
+ OP_REQUIRES(
+ ctx,
+ selector_input->output_dtypes().size() == 1 &&
+ selector_input->output_dtypes()[0] == DT_INT64 &&
+ selector_input->output_shapes().size() == 1 &&
+ selector_input->output_shapes()[0].IsCompatibleWith(
+ PartialTensorShape({})),
+ errors::InvalidArgument(
+ "The selector input must be a dataset of scalar int64 elements."));
+
+ std::vector<DatasetBase*> data_inputs;
+ for (size_t i = 1; i < ctx->num_inputs(); ++i) {
+ DatasetBase* input;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
+ data_inputs.push_back(input);
+
+ OP_REQUIRES(
+ ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
+ errors::InvalidArgument(
+ "All inputs must have the same output_dtypes. First input "
+ "has types ",
+ DataTypeVectorString(data_inputs[0]->output_dtypes()),
+ ", and input ", i - 1, " has types ",
+ DataTypeVectorString(input->output_dtypes())));
+ }
+ *output = new Dataset(ctx, selector_input, std::move(data_inputs));
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
+ std::vector<DatasetBase*> data_inputs)
+ : DatasetBase(DatasetContext(ctx)),
+ selector_input_(selector_input),
+ data_inputs_(std::move(data_inputs)) {
+ selector_input_->Ref();
+
+ output_shapes_ = data_inputs_[0]->output_shapes();
+ data_inputs_[0]->Ref();
+ for (size_t i = 1; i < data_inputs_.size(); ++i) {
+ const DatasetBase* data_input = data_inputs_[i];
+ data_input->Ref();
+ for (size_t j = 0; j < output_shapes_.size(); ++j) {
+ output_shapes_[j] = MostSpecificCompatibleShape(
+ output_shapes_[j], data_input->output_shapes()[j]);
+ }
+ }
+ }
+
+ ~Dataset() override {
+ selector_input_->Unref();
+ for (DatasetBase* data_input : data_inputs_) {
+ data_input->Unref();
+ }
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::DirectedInterleave")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return data_inputs_[0]->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* selector_input_node;
+ TF_RETURN_IF_ERROR(
+ b->AddInputDataset(ctx, selector_input_, &selector_input_node));
+ std::vector<Node*> data_input_nodes(data_inputs_.size());
+ for (size_t i = 0; i < data_inputs_.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
+ }
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
+ {{1, data_input_nodes}}, {}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ num_active_inputs_(params.dataset->data_inputs_.size()) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
+ ctx, strings::StrCat(prefix(), ".selector"),
+ &selector_input_impl_));
+ data_input_impls_.resize(dataset()->data_inputs_.size());
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const DatasetBase* data_input = dataset()->data_inputs_[i];
+ TF_RETURN_IF_ERROR(data_input->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[", i, "]"),
+ &data_input_impls_[i]));
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (!selector_input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ while (true) {
+ std::vector<Tensor> selector_result;
+ *end_of_sequence = false;
+ TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
+ ctx, &selector_result, end_of_sequence));
+ if (*end_of_sequence) {
+ selector_input_impl_.reset();
+ for (auto& data_input_impl : data_input_impls_) {
+ data_input_impl.reset();
+ }
+ return Status::OK();
+ }
+
+ int64 selected_input = selector_result[0].scalar<int64>()();
+ if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
+ return errors::InvalidArgument(
+ "Selector index out of range: ", selected_input,
+ " >= ", data_input_impls_.size());
+ }
+
+ if (data_input_impls_[selected_input]) {
+ bool end_of_selected_input = false;
+ TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
+ ctx, out_tensors, &end_of_selected_input));
+
+ if (!end_of_selected_input) {
+ return Status::OK();
+ }
+
+ data_input_impls_[selected_input].reset();
+ --num_active_inputs_;
+
+ if (num_active_inputs_ == 0) {
+ selector_input_impl_.reset();
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+
+ LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
+ << selected_input;
+ }
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (selector_input_impl_) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const auto& data_input_impl = data_input_impls_[i];
+ if (data_input_impl) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl));
+ } else {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
+ ""));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
+ } else {
+ selector_input_impl_.reset();
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ if (!reader->Contains(full_name(
+ strings::StrCat("data_input_impl_empty[", i, "]")))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
+ } else {
+ data_input_impls_[i].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+ GUARDED_BY(mu_);
+ int64 num_active_inputs_ GUARDED_BY(mu_);
+ };
+
+ static PartialTensorShape MostSpecificCompatibleShape(
+ const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+ PartialTensorShape output_tensorshape;
+ if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+ return output_tensorshape;
+ auto dims1 = ts1.dim_sizes();
+ auto dims2 = ts2.dim_sizes();
+ for (int d = 0; d < ts1.dims(); d++) {
+ if (dims1[d] == dims2[d])
+ output_tensorshape.Concatenate(dims1[d]);
+ else
+ output_tensorshape.Concatenate(-1);
+ }
+ return output_tensorshape;
+ }
+
+ const DatasetBase* const selector_input_;
+ const std::vector<DatasetBase*> data_inputs_;
+ std::vector<PartialTensorShape> output_shapes_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
new file mode 100644
index 0000000000..2141f118ca
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -0,0 +1,156 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
+ public:
+ using IndexedDatasetOpKernel::IndexedDatasetOpKernel;
+
+ void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) override {
+ uint64 size = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size));
+ OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0"));
+ *output = new Dataset(ctx, size);
+ }
+
+ class Dataset : public IndexedDataset {
+ public:
+ Dataset(OpKernelContext* ctx, uint64 size)
+ : IndexedDataset(DatasetContext(ctx)), size_(size) {}
+
+ Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
+ materialized->reset(new Materialized(this));
+ return Status::OK();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::IdentityIndexedDataset")}));
+ }
+
+ string DebugString() const override {
+ return "IdentityIndexedDataset::Dataset";
+ }
+
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** node) const override {
+ return errors::Unimplemented(
+ "identity_indexed_dataset.AsGraphDefInternal");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (cur_ < dataset()->size_) {
+ Tensor result_tensor(ctx->allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = cur_++;
+ out_tensors->emplace_back(std::move(result_tensor));
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ uint64 cur_ GUARDED_BY(mu_) = 0;
+ };
+
+ class Materialized : public MaterializedIndexedDataset {
+ public:
+ explicit Materialized(Dataset* dataset) : dataset_(dataset) {
+ dataset->Ref();
+ }
+
+ ~Materialized() override {
+ // TODO(saeta): Pull this into MaterializedIndexedDataset
+ dataset_->Unref();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return dataset_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return dataset_->output_shapes();
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const override {
+ LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index
+ << ")";
+ if (index >= dataset_->size_) {
+ // Note: use InvalidArgument instead of OutOfRange error because many
+ // things consider OutOfRange to be a "clean termination" error.
+ return errors::InvalidArgument(
+ "Index ", index,
+ " is out of range for this dataset. (Size is: ", dataset_->size_,
+ ".)");
+ }
+ Tensor result_tensor(ctx.allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = index;
+ out_tensors->emplace_back(std::move(result_tensor));
+ return Status::OK();
+ }
+
+ Status Size(uint64* size) const override {
+ *size = dataset_->size_;
+ return Status::OK();
+ }
+
+ private:
+ const Dataset* const dataset_; // Not owned.
+ };
+
+ const uint64 size_;
+ std::shared_ptr<Materialized> materialized_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
new file mode 100644
index 0000000000..b34377c642
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "IgnoreErrorsDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ {
+ tf_shared_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ while (!s.ok()) {
+ out_tensors->clear();
+ s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+ }
+ if (*end_of_sequence) {
+ mutex_lock l(mu_);
+ input_impl_.reset();
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_)
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ else
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impls_empty"), ""));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (reader->Contains(full_name("input_impls_empty")))
+ input_impl_.reset();
+ else
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+ IgnoreErrorsDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
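GetNextInternal() above re-polls the upstream iterator whenever it returns a non-OK status, clearing any partially produced tensors, so failing elements are silently skipped until a good element or end-of-sequence arrives. A standalone sketch of that skip-on-error loop (plain C++, hypothetical Upstream type, not the TensorFlow API):

#include <iostream>

// Hypothetical upstream producer: returns false when producing an element
// fails, true otherwise; sets *end when the sequence is exhausted.
struct Upstream {
  int i = 0;
  bool GetNext(int* out, bool* end) {
    if (i >= 6) { *end = true; return true; }
    int v = i++;
    if (v % 2 == 1) return false;  // Odd elements "fail".
    *out = v;
    *end = false;
    return true;
  }
};

// Mirrors IgnoreErrorsDataset: retry until a successful element or the end.
bool GetNextIgnoringErrors(Upstream* up, int* out, bool* end) {
  bool ok = up->GetNext(out, end);
  while (!ok) {
    ok = up->GetNext(out, end);  // Drop the failed element and try again.
  }
  return ok;
}

int main() {
  Upstream up;
  bool end = false;
  int v = 0;
  while (GetNextIgnoringErrors(&up, &v, &end) && !end) {
    std::cout << v << "\n";  // Prints 0, 2, 4; the odd elements are skipped.
  }
  return 0;
}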
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
new file mode 100644
index 0000000000..75ea462f40
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -0,0 +1,375 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
+class MaterializedDatasetResource : public ResourceBase {
+ public:
+ MaterializedDatasetResource(
+ const DataTypeVector& output_dtypes,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
+
+ string DebugString() override {
+ return "Materialized IndexedDataset resource";
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) {
+ std::shared_ptr<MaterializedIndexedDataset> captured(materialized_);
+ if (captured) {
+ return captured->Get(std::move(ctx), index, out_tensors);
+ } else {
+ return errors::FailedPrecondition(
+ "Get() failed because the MaterializedIndexedDataset has not been "
+ "initialized. Ensure that you have run the materialization operation "
+ "for this MaterializedIndexedDataset before retrieving elements.");
+ }
+ }
+
+ // TODO(saeta): Implement Save and Restore
+
+ const DataTypeVector& output_dtypes() const { return output_dtypes_; }
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ Status set_materialized_dataset(
+ const std::shared_ptr<MaterializedIndexedDataset>& dataset) {
+ if (dataset) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, dataset->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, dataset->output_shapes()));
+ }
+ materialized_ = dataset;
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<MaterializedIndexedDataset> materialized_;
+ const DataTypeVector output_dtypes_;
+ const std::vector<PartialTensorShape> output_shapes_;
+};
+
+// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT
+// tensor. Objects of the wrapper class own a reference on an instance of an
+// `IndexedDataset`, and the wrapper's copy constructor and destructor take
+// care of managing the reference count.
+//
+// NOTE: This is not a feature-complete implementation of the DT_VARIANT
+// specification. In particular, we cannot currently serialize an arbitrary
+// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not
+// implemented.
+//
+// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just
+// use `tensorflow::DatasetVariantWrapper`.
+class IndexedDatasetVariantWrapper {
+ public:
+ IndexedDatasetVariantWrapper() : dataset_(nullptr) {}
+
+ // Transfers ownership of `dataset` to `*this`.
+ explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset)
+ : dataset_(dataset) {}
+
+ IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other)
+ : dataset_(other.dataset_) {
+ if (dataset_) dataset_->Ref();
+ }
+
+ ~IndexedDatasetVariantWrapper() {
+ if (dataset_) dataset_->Unref();
+ }
+
+ IndexedDataset* get() const { return dataset_; }
+
+ string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; }
+ string DebugString() const {
+ if (dataset_) {
+ return dataset_->DebugString();
+ } else {
+ return "<Uninitialized IndexedDatasetVariantWrapper>";
+ }
+ }
+
+ void Encode(VariantTensorData* data) const {
+ LOG(ERROR) << "The Encode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ LOG(ERROR) << "The Decode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ return false;
+ }
+
+ private:
+ IndexedDataset* const dataset_; // Owns one reference.
+};
+
+} // namespace
+
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset) {
+  if (!(tensor.dtype() == DT_VARIANT &&
+        TensorShapeUtils::IsScalar(tensor.shape()))) {
+ return errors::InvalidArgument(
+ "IndexedDataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ const Variant& variant = tensor.scalar<Variant>()();
+ const IndexedDatasetVariantWrapper* wrapper =
+ variant.get<IndexedDatasetVariantWrapper>();
+ if (wrapper == nullptr) {
+ return errors::InvalidArgument("Tensor must be an IndexedDataset object.");
+ }
+ *out_dataset = wrapper->get();
+ if (*out_dataset == nullptr) {
+ return errors::Internal("Read uninitialized IndexedDataset variant.");
+ }
+ return Status::OK();
+}
+
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor) {
+  if (!(tensor->dtype() == DT_VARIANT &&
+        TensorShapeUtils::IsScalar(tensor->shape()))) {
+ return errors::InvalidArgument(
+ "Dataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset);
+ return Status::OK();
+}
+
+void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) {
+ IndexedDataset* dataset = nullptr;
+ MakeIndexedDataset(ctx, &dataset);
+
+ if (ctx->status().ok()) {
+ OP_REQUIRES(ctx, dataset != nullptr,
+ errors::Internal("MakeIndexedDataset did not correctly "
+ "construct the IndexedDataset"));
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+ OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output));
+ }
+}
+
+namespace {
+
+class MaterializedHandleOp : public OpKernel {
+ public:
+ explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ ~MaterializedHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name())
+ .ok()) {
+          // Do nothing; the resource may have been deleted by session resets.
+ // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MaterializedDatasetResource* resource;
+ OP_REQUIRES_OK(context,
+ mgr->LookupOrCreate<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this](MaterializedDatasetResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MaterializedDatasetResource(
+ output_dtypes_, output_shapes_);
+ return Status::OK();
+ }));
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MaterializedDatasetResource>()));
+ }
+
+ private:
+  // During the first Compute(), the resource is either created or looked up
+  // using shared_name. In the latter case, the resource found should be
+  // verified to be compatible with this op's configuration. The verification
+  // may fail, for example, if two graphs ask for resources with the same
+  // shared name but with inconsistent output types or shapes.
+ Status VerifyResource(MaterializedDatasetResource* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+// TODO(saeta): Make async.
+class MaterializeDatasetOp : public OpKernel {
+ public:
+ explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ IndexedDataset* dataset;
+ OP_REQUIRES_OK(ctx,
+ GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &materialized_resource));
+ core::ScopedUnref unref(materialized_resource);
+ std::shared_ptr<MaterializedIndexedDataset> materialized;
+ OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized));
+ OP_REQUIRES_OK(
+ ctx, materialized_resource->set_materialized_dataset(materialized));
+ }
+};
+
+// TODO(saeta): Make async
+class IndexedDatasetGet : public OpKernel {
+ public:
+ explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
+ &materialized_resource));
+ auto cleanup = gtl::MakeCleanup([materialized_resource] {
+ materialized_resource->Unref(); // Note: can't use core::ScopedUnref.
+ });
+
+ const Tensor* index_t;
+ OP_REQUIRES_OK(ctx, ctx->input("index", &index_t));
+ // TODO(saeta): Support batch reads (indexes should be non-scalar!)
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()),
+ errors::InvalidArgument("index must be a scalar"));
+ const uint64 index = index_t->scalar<uint64>()();
+
+ std::vector<Tensor> out_tensors;
+ Status s =
+ materialized_resource->Get(IteratorContext(ctx), index, &out_tensors);
+
+ // Note: Unref materialized_resource to avoid destruction races. (Important
+ // in a [future] async op implementation.)
+ cleanup.release()();
+
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else {
+ auto expected_shapes = materialized_resource->output_shapes();
+ auto expected_types = materialized_resource->output_dtypes();
+ for (size_t i = 0; i < out_tensors.size(); ++i) {
+ OP_REQUIRES(
+ ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()),
+ errors::Internal(
+ "Materialized dataset output at index ", i,
+ " is incompatible with the expected shape. (Expected: ",
+ expected_shapes[i], ", got: ", out_tensors[i].shape(), ")"));
+ OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i],
+ errors::Internal("Materialized dataset output at index ", i,
+ " was not the expected dtype. (Expected: ",
+ expected_types[i],
+ ", got: ", out_tensors[i].dtype(), ")"));
+ ctx->set_output(i, out_tensors[i]);
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ MaterializedHandleOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
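IndexedDatasetVariantWrapper above is the usual trick for parking a reference-counted object inside a copyable value type: the wrapper acquires one reference per copy and releases it in its destructor, so the underlying dataset stays alive exactly as long as some copy of the wrapper does. A standalone sketch of that idiom (plain C++; RefCounted/Wrapper are illustrative stand-ins, roughly mirroring TensorFlow's Ref()/Unref() convention, not the actual classes):

#include <iostream>

// Minimal intrusive reference counting.
class RefCounted {
 public:
  void Ref() { ++refs_; }
  void Unref() {
    if (--refs_ == 0) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  int refs_ = 1;  // Starts owned by the creator.
};

class Dataset : public RefCounted {
 public:
  ~Dataset() override { std::cout << "Dataset destroyed\n"; }
};

// Copyable wrapper that owns one reference, as the variant wrapper does.
class Wrapper {
 public:
  explicit Wrapper(Dataset* d) : d_(d) {}  // Takes over one reference.
  Wrapper(const Wrapper& o) : d_(o.d_) {
    if (d_) d_->Ref();                     // Each copy adds a reference.
  }
  ~Wrapper() {
    if (d_) d_->Unref();                   // Each copy drops one.
  }
  Dataset* get() const { return d_; }

 private:
  Dataset* const d_;
};

int main() {
  Wrapper w1(new Dataset);  // Refcount 1, owned by w1.
  {
    Wrapper w2 = w1;        // Refcount 2.
  }                         // Back to 1.
  return 0;                 // w1 goes out of scope -> "Dataset destroyed".
}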
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
new file mode 100644
index 0000000000..27a8360cbc
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace data {
+
+// TODO(saeta): Urgh, this is ugly.
+class MaterializedIndexedDataset {
+ public:
+ virtual ~MaterializedIndexedDataset() = default;
+
+ // Retrieve the element at a given index. The output tensors are stored in
+ // out_tensors.
+ //
+  // If `index` is greater than or equal to `Size()`, an error status is
+  // returned (e.g. tensorflow::errors::OutOfRange or InvalidArgument).
+ //
+ // Get is thread-safe.
+ virtual Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const = 0;
+
+  // Size sets `*size` to the number of elements in this IndexedDataset.
+ //
+ // Size is thread-safe.
+ virtual Status Size(uint64* size) const = 0;
+
+ // Returns a vector of DataType values, representing the respective
+ // element types of each tuple component in the outputs of this dataset.
+ virtual const DataTypeVector& output_dtypes() const = 0;
+
+ // Returns a vector of tensor shapes, representing the respective
+ // (and possibly partially defined) shapes of each tuple component
+ // in the outputs of this dataset.
+ virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+};
+
+// IndexedDataset represents a dataset that supports random access in addition
+// to iterator-based sequential access.
+//
+// Note: IndexedDatasets are HIGHLY experimental at this time. Expect
+// significant (backwards incompatible) changes!
+class IndexedDataset : public DatasetBase {
+ public:
+ IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {}
+
+ // Materialize (if necessary) the dataset, and return a pointer.
+ // TODO(saeta): Add in `IteratorContext* ctx` when materializing.
+ virtual Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0;
+};
+
+// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the
+// rest of the TensorFlow runtime.
+//
+// Most IndexedDatasets will be private members of classes inheriting from this
+// class.
+class IndexedDatasetOpKernel : public OpKernel {
+ public:
+ IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) final;
+
+ protected:
+ // Subclasses should implement this method. It will be called during Compute
+ // execution.
+ virtual void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) = 0;
+
+ template <typename T>
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+};
+
+// Validates and extracts an `IndexedDataset` object from `tensor`.
+//
+// `tensor` must have been written by a call to
+// `StoreIndexedDatasetInVariantTensor`
+//
+// The retrieved pointer is a borrowed reference to the dataset, which is owned
+// by the tensor. The consumer must either acquire its own reference to the
+// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
+// destroyed or mutated while the retrieved pointer is in use.
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset);
+
+// Stores an `IndexedDataset` object in `tensor`.
+//
+// The ownership of `dataset` is transferred to `tensor`.
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
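The header above splits the two roles: an IndexedDataset only promises that it can be materialized, and the resulting MaterializedIndexedDataset answers Get(index) and Size() queries. A plain-C++ analogue of that two-stage protocol with a toy range dataset (illustrative names only, not the TensorFlow interfaces):

#include <cstdint>
#include <iostream>
#include <memory>

// Analogue of MaterializedIndexedDataset: random access after materialization.
class Materialized {
 public:
  virtual ~Materialized() = default;
  virtual bool Get(uint64_t index, int64_t* out) const = 0;  // false if out of range.
  virtual uint64_t Size() const = 0;
};

// Analogue of IndexedDataset: only promises that it can be materialized.
class Indexed {
 public:
  virtual ~Indexed() = default;
  virtual std::shared_ptr<Materialized> Materialize() const = 0;
};

// Toy implementation: the dataset [start, start + n).
class RangeIndexed : public Indexed {
 public:
  RangeIndexed(int64_t start, uint64_t n) : start_(start), n_(n) {}

  std::shared_ptr<Materialized> Materialize() const override {
    class Impl : public Materialized {
     public:
      Impl(int64_t start, uint64_t n) : start_(start), n_(n) {}
      bool Get(uint64_t index, int64_t* out) const override {
        if (index >= n_) return false;
        *out = start_ + static_cast<int64_t>(index);
        return true;
      }
      uint64_t Size() const override { return n_; }

     private:
      const int64_t start_;
      const uint64_t n_;
    };
    return std::make_shared<Impl>(start_, n_);
  }

 private:
  const int64_t start_;
  const uint64_t n_;
};

int main() {
  RangeIndexed dataset(100, 5);
  auto materialized = dataset.Materialize();
  int64_t v = 0;
  if (materialized->Get(3, &v)) std::cout << v << "\n";  // Prints 103.
  std::cout << materialized->Size() << "\n";             // Prints 5.
  return 0;
}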
diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
new file mode 100644
index 0000000000..8a88d32f0c
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -0,0 +1,218 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sys/stat.h>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/file_system.h"
+
+#include "lmdb.h" // NOLINT(build/include)
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class LMDBDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ *output = new Dataset(ctx, filenames);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
+ : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ do {
+ if (mdb_cursor_) {
+ Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
+ key_tensor.scalar<string>()() = string(
+ static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+ out_tensors->emplace_back(std::move(key_tensor));
+
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<string>()() =
+ string(static_cast<const char*>(mdb_value_.mv_data),
+ mdb_value_.mv_size);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+ int val;
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ private:
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+ const string& filename = dataset()->filenames_[current_file_index_];
+
+ int val = mdb_env_create(&mdb_env_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+
+ struct stat source_stat;
+ if (stat(filename.c_str(), &source_stat) == 0 &&
+ (source_stat.st_mode & S_IFREG)) {
+ flags |= MDB_NOSUBDIR;
+ }
+ val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ }
+ return Status::OK();
+ }
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (mdb_env_ != nullptr) {
+ if (mdb_cursor_) {
+ mdb_cursor_close(mdb_cursor_);
+ mdb_cursor_ = nullptr;
+ }
+ mdb_dbi_close(mdb_env_, mdb_dbi_);
+ mdb_txn_abort(mdb_txn_);
+ mdb_env_close(mdb_env_);
+ mdb_txn_ = nullptr;
+ mdb_dbi_ = 0;
+ mdb_env_ = nullptr;
+ }
+ }
+ mutex mu_;
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
+ MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
+ MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
+ MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
+
+ MDB_val mdb_key_ GUARDED_BY(mu_);
+ MDB_val mdb_value_ GUARDED_BY(mu_);
+ };
+
+ const std::vector<string> filenames_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
+ LMDBDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
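SetupStreamsLocked()/GetNextInternal() above follow the standard read-only LMDB cursor walk: create an environment, open it (adding MDB_NOSUBDIR when the path is a plain file rather than a directory), begin a read-only transaction, open the default database and a cursor, then step with MDB_FIRST/MDB_NEXT until MDB_NOTFOUND. A standalone sketch of the same sequence against the LMDB C API (error handling abbreviated; the database path is a placeholder):

#include <cstdio>
#include <sys/stat.h>

#include "lmdb.h"

int main() {
  const char* path = "/tmp/example.mdb";  // Placeholder path.

  MDB_env* env = nullptr;
  if (mdb_env_create(&env) != MDB_SUCCESS) return 1;

  unsigned int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
  struct stat st;
  if (stat(path, &st) == 0 && (st.st_mode & S_IFREG)) {
    flags |= MDB_NOSUBDIR;  // `path` is a single file, not a directory.
  }
  if (mdb_env_open(env, path, flags, 0664) != MDB_SUCCESS) return 1;

  MDB_txn* txn = nullptr;
  MDB_dbi dbi = 0;
  MDB_cursor* cursor = nullptr;
  if (mdb_txn_begin(env, nullptr, MDB_RDONLY, &txn) != MDB_SUCCESS) return 1;
  if (mdb_dbi_open(txn, nullptr, 0, &dbi) != MDB_SUCCESS) return 1;
  if (mdb_cursor_open(txn, dbi, &cursor) != MDB_SUCCESS) return 1;

  // Walk every (key, value) pair, as the dataset iterator does one element
  // per GetNext() call.
  MDB_val key, value;
  int rc = mdb_cursor_get(cursor, &key, &value, MDB_FIRST);
  while (rc == MDB_SUCCESS) {
    std::printf("%.*s -> %zu value bytes\n", static_cast<int>(key.mv_size),
                static_cast<const char*>(key.mv_data), value.mv_size);
    rc = mdb_cursor_get(cursor, &key, &value, MDB_NEXT);
  }
  if (rc != MDB_NOTFOUND) std::fprintf(stderr, "%s\n", mdb_strerror(rc));

  mdb_cursor_close(cursor);
  mdb_dbi_close(env, dbi);
  mdb_txn_abort(txn);
  mdb_env_close(env);
  return 0;
}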
diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
new file mode 100644
index 0000000000..2c6179d9f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -0,0 +1,482 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct BufferElement {
+ // The producer sets `status` if getting the input element fails.
+ Status status;
+ // The buffered data element.
+ std::vector<Tensor> value;
+};
+
+using FunctionBufferCallback = std::function<void(const BufferElement&)>;
+
+class FunctionBufferingResource : public ResourceBase {
+ public:
+ FunctionBufferingResource(FunctionLibraryRuntime* lib,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ const NameAttrList& func, int64 buffer_size,
+ const string& source_device,
+ const string& target_device,
+ const std::vector<Tensor>& func_args,
+ const DataTypeVector& output_types)
+ : lib_(lib),
+ pflr_(std::move(pflr)),
+ func_(func),
+ buffer_size_(buffer_size),
+ source_device_(source_device),
+ target_device_(target_device),
+ func_args_(func_args),
+ output_types_(output_types),
+ handle_(kInvalidHandle),
+ is_buffering_(false),
+ end_of_sequence_(false),
+ cancelled_(false) {}
+
+ ~FunctionBufferingResource() override {
+ Cancel();
+ }
+
+ string DebugString() override {
+ return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_,
+ "; target_device: ", target_device_);
+ }
+
+ // Instantiates the function the first time it's called. After that it caches
+ // the handle.
+ Status Instantiate() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ // Re-use existing handle if it's been set, effectively caching it.
+ if (handle_ != kInvalidHandle) {
+ return Status::OK();
+ }
+ AttrValueMap attr_values = func_.attr();
+ FunctionLibraryRuntime::InstantiateOptions opts;
+ opts.target = target_device_;
+ return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts,
+ &handle_);
+ }
+
+ // Returns true if we've got to the end of the sequence and exhausted the
+ // buffer.
+ bool Finished() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return end_of_sequence_ && buffer_.empty();
+ }
+
+ // Cancels any buffering / prefetching going on.
+ void Cancel() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ while (is_buffering_) {
+ cond_var_.wait(l);
+ }
+ }
+
+ // Cancels all pending operations and then clears out the state.
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ Cancel();
+ mutex_lock l(mu_);
+ buffer_.clear();
+ requests_.clear();
+ is_buffering_ = false;
+ end_of_sequence_ = false;
+ cancelled_ = false;
+ }
+
+  // If the buffer has anything, runs `callback` on the first element in the
+  // buffer; otherwise enqueues `callback` to be run once an element becomes
+  // available. Also starts buffering if no buffering is currently in progress.
+ void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) {
+ bool start_buffering = false;
+ bool produced_output = false;
+ BufferElement buffer_element;
+ {
+ mutex_lock l(mu_);
+ if (!is_buffering_ && !end_of_sequence_) {
+ start_buffering = true;
+ }
+ if (!buffer_.empty()) {
+ produced_output = true;
+ std::swap(buffer_element, buffer_.front());
+ buffer_.pop_front();
+ } else {
+ produced_output = false;
+ requests_.push_back(std::move(callback));
+ }
+ }
+ if (produced_output) {
+ callback(buffer_element);
+ }
+ if (start_buffering) {
+ FillBuffer();
+ }
+ }
+
+ private:
+ void FillBuffer() LOCKS_EXCLUDED(mu_) {
+ FunctionLibraryRuntime::Handle handle;
+ std::vector<FunctionBufferCallback> cancellation_callbacks;
+ std::vector<BufferElement> cancellation_buffer_elements;
+ bool cancelled = false;
+ {
+ mutex_lock l(mu_);
+ handle = handle_;
+ if (cancelled_) {
+ cancelled = true;
+ // Run through and fulfill all pending requests, if possible.
+ while (!requests_.empty()) {
+ if (!buffer_.empty()) {
+ cancellation_buffer_elements.push_back(std::move(buffer_.front()));
+ buffer_.pop_front();
+ cancellation_callbacks.push_back(std::move(requests_.front()));
+ requests_.pop_front();
+ } else {
+ LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: "
+ << requests_.size() << " requests";
+ break;
+ }
+ }
+ is_buffering_ = false;
+ } else {
+ is_buffering_ = true;
+ }
+ }
+ if (cancelled) {
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_buffer_elements[i]);
+ }
+ cond_var_.notify_all();
+ return;
+ }
+ FunctionLibraryRuntime::Options opts;
+ // Copied from CapturedFunction::generate_step_id();
+ opts.step_id = -std::abs(static_cast<int64>(random::New64()));
+ opts.source_device = source_device_;
+ AllocatorAttributes arg_alloc_attr;
+ arg_alloc_attr.set_on_host(true);
+ opts.args_alloc_attrs.push_back(arg_alloc_attr);
+ for (const auto& dtype : output_types_) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ ret_alloc_attrs.set_on_host(true);
+ }
+ opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
+ }
+ if (opts.source_device != target_device_) {
+ opts.remote_execution = true;
+ }
+ opts.create_rendezvous = true;
+ auto* rets = new std::vector<Tensor>;
+ lib_->Run(opts, handle, func_args_, rets,
+ [this, rets](const Status& status) {
+ FunctionBufferCallback callback = nullptr;
+ BufferElement buffer_front;
+ bool restart_buffering = false;
+ {
+ mutex_lock l(mu_);
+ BufferElement buffer_element;
+ buffer_element.status = status;
+ if (status.ok()) {
+ buffer_element.value.swap(*rets);
+ } else {
+ end_of_sequence_ = true;
+ is_buffering_ = false;
+ }
+ buffer_.push_back(std::move(buffer_element));
+ if (!requests_.empty()) {
+ buffer_front = std::move(buffer_.front());
+ buffer_.pop_front();
+ callback = std::move(requests_.front());
+ requests_.pop_front();
+ }
+ if (buffer_.size() < buffer_size_ && !end_of_sequence_) {
+ restart_buffering = true;
+ } else {
+                    // When the buffer is full, we don't want to call
+                    // FillBuffer() again, unless we're in the cancellation
+                    // phase, in which case FillBuffer() performs the final
+                    // cleanup after cancellation.
+ if (cancelled_) {
+ restart_buffering = true;
+ }
+ is_buffering_ = false;
+ }
+ }
+ if (callback != nullptr) {
+ callback(buffer_front);
+ }
+ if (restart_buffering) {
+ FillBuffer();
+ }
+ });
+ }
+
+ mutex mu_;
+ FunctionLibraryRuntime* lib_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ NameAttrList func_;
+ const int64 buffer_size_;
+ const string source_device_;
+ const string target_device_;
+ const std::vector<Tensor> func_args_;
+ const DataTypeVector output_types_;
+ FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
+ std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
+ std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
+ bool is_buffering_ GUARDED_BY(mu_);
+ bool end_of_sequence_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_);
+ condition_variable cond_var_;
+};
+
+class FunctionBufferResourceHandleOp : public OpKernel {
+ public:
+ explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), flib_def_(nullptr) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ }
+
+ ~FunctionBufferResourceHandleOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<FunctionBufferingResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+        // Do nothing; the resource may have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* string_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg));
+ std::vector<Tensor> func_args;
+ func_args.push_back(*string_arg);
+
+ const string& source_device = ctx->device()->name();
+
+ // Obtain and canonicalize target_device.
+ const Tensor* target_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
+ string target_device;
+ OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
+ target_arg->scalar<string>()(), source_device,
+ &target_device));
+
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."));
+
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
+ FunctionLibraryRuntime* clone_lib;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr;
+ OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib));
+ // Create the resource.
+ FunctionBufferingResource* buffer;
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>(
+ cinfo_.container(), cinfo_.name(), &buffer,
+ [clone_lib, &pflr, &source_device, &target_device, func_args,
+ this](FunctionBufferingResource** ptr) {
+ *ptr = new FunctionBufferingResource(
+ clone_lib, std::move(pflr), func_, buffer_size_,
+ source_device, target_device, func_args, output_types_);
+ return Status::OK();
+ }));
+ core::ScopedUnref s(buffer);
+ OP_REQUIRES_OK(ctx, buffer->Instantiate());
+ initialized_ = true;
+ }
+
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<FunctionBufferingResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ NameAttrList func_;
+ int64 buffer_size_;
+ string container_;
+ string name_;
+ DataTypeVector output_types_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_CPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_SYCL)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Returns the next element from a FunctionBufferingResource's buffer, waiting
+// asynchronously for one to be prefetched if the buffer is currently empty.
+class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
+ public:
+ explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx) {}
+
+ ~FunctionBufferingResourceGetNextOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done);
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer),
+ done);
+
+ if (buffer->Finished()) {
+ buffer->Unref();
+ ctx->SetStatus(errors::OutOfRange("end_of_sequence"));
+ done();
+ return;
+ }
+
+ FunctionBufferCallback callback =
+ [ctx, buffer, done](const BufferElement& buffer_element) {
+ Status s = buffer_element.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ buffer->Unref();
+ done();
+ return;
+ }
+ for (size_t i = 0; i < buffer_element.value.size(); ++i) {
+ ctx->set_output(i, buffer_element.value[i]);
+ }
+ buffer->Unref();
+ done();
+ };
+ buffer->MaybeGet(std::move(callback));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Resets the FunctionBufferingResource, cancelling all pending requests and
+// clearing out the buffer.
+class FunctionBufferingResourceResetOp : public OpKernel {
+ public:
+ explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ ~FunctionBufferingResourceResetOp() override {}
+
+ void Compute(OpKernelContext* ctx) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(ctx,
+ HandleFromInput(ctx, "function_buffer_resource", &handle));
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer));
+ core::ScopedUnref s(buffer);
+
+ buffer->Reset();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#endif // TENSORFLOW_USE_SYCL
+
+class IteratorGetDeviceOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* ctx) override {
+    // NOTE(mrry): We do not currently validate that the handle
+ // corresponds to a real IteratorResource, because that symbol is
+ // not exposed from the framework library.
+ Tensor* device_name_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &device_name_t));
+ // NOTE(mrry): Since the operation's input is a resource, we must be
+ // colocated with it, and so we can simply return the current device's
+ // name without looking at the input.
+ device_name_t->scalar<string>()() = ctx->device()->name();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
+ IteratorGetDeviceOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
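The heart of FunctionBufferingResource is MaybeGet(): if the buffer already holds an element, pop it and run the callback immediately; otherwise park the callback in a request queue, and the producer side (FillBuffer's completion handler) pairs queued callbacks with newly produced elements until `buffer_size` elements are pending. A standalone sketch of that pop-or-enqueue core with a mutex-guarded deque (plain C++ with hypothetical names; the real kernel produces elements by running a TensorFlow function rather than calling Produce() directly):

#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>

using Callback = std::function<void(const std::string&)>;

class Buffer {
 public:
  // Consumer side: deliver an element now if one is buffered, otherwise queue
  // the callback until the producer supplies one.
  void MaybeGet(Callback callback) {
    std::string element;
    bool ready = false;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (!buffer_.empty()) {
        element = std::move(buffer_.front());
        buffer_.pop_front();
        ready = true;
      } else {
        requests_.push_back(std::move(callback));
      }
    }
    if (ready) callback(element);
  }

  // Producer side: either satisfy a waiting request directly or grow the
  // buffer (the real resource stops producing once buffer_size_ is reached).
  void Produce(std::string element) {
    Callback waiter = nullptr;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (!requests_.empty()) {
        waiter = std::move(requests_.front());
        requests_.pop_front();
      } else {
        buffer_.push_back(std::move(element));
        return;
      }
    }
    waiter(element);
  }

 private:
  std::mutex mu_;
  std::deque<std::string> buffer_;
  std::deque<Callback> requests_;
};

int main() {
  Buffer b;
  b.Produce("a");                                                     // Buffered.
  b.MaybeGet([](const std::string& e) { std::cout << e << "\n"; });   // Prints "a".
  b.MaybeGet([](const std::string& e) { std::cout << e << "\n"; });   // Queued.
  b.Produce("b");                                                     // Fires "b".
  return 0;
}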
diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
new file mode 100644
index 0000000000..c80493d3a1
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ThreadPoolResource : public ResourceBase {
+ public:
+ ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
+ const string& name, int num_threads, bool low_latency_hint,
+ int max_intra_op_parallelism)
+ : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
+ max_intra_op_parallelism_(max_intra_op_parallelism) {}
+
+ // Schedules fn() for execution in the pool of threads.
+ void Schedule(std::function<void()> fn) {
+ if (max_intra_op_parallelism_ < 0) {
+ thread_pool_.Schedule(std::move(fn));
+ } else {
+ thread_pool_.Schedule(std::bind(
+ [this](std::function<void()> bound_fn) {
+ // TODO(mrry): Consider moving this thread-local configuration to
+ // the threads themselves.
+ ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
+ bound_fn();
+ },
+ std::move(fn)));
+ }
+ }
+
+ string DebugString() override { return "ThreadPoolResource"; }
+
+ private:
+ thread::ThreadPool thread_pool_;
+ const int max_intra_op_parallelism_;
+};
+
+// Creates a handle to a ThreadPool resource. Note that we don't use
+// ResourceOpKernel here because the ThreadPoolResource constructor requires
+// access to `OpKernelContext::env()`, which isn't provided by
+// `ResourceOpKernel<T>::CreateResource()`.
+class ThreadPoolHandleOp : public OpKernel {
+ public:
+ explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
+ &max_intra_op_parallelism_));
+ OP_REQUIRES(
+ ctx, num_threads_ > 0,
+ errors::InvalidArgument("`num_threads` must be greater than zero."));
+ }
+
+  // The resource is deleted from the resource manager only when it is private
+  // to the kernel. Ideally the resource should be deleted when it is no longer
+  // held by anyone, but doing so would break backward compatibility.
+ ~ThreadPoolHandleOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name())
+ .ok()) {
+        // Do nothing; the resource may have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+ ThreadPoolResource* resource;
+ OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, ctx](ThreadPoolResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                          *ret = new ThreadPoolResource(
+                              ctx->env(), {}, display_name_, num_threads_,
+                              /* low_latency_hint = */ false,
+                              max_intra_op_parallelism_);
+ return Status::OK();
+ }));
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<ThreadPoolResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+ string display_name_;
+ int num_threads_;
+ int max_intra_op_parallelism_;
+};
+
+class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ ThreadPoolResource* threadpool_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &threadpool_resource));
+ core::ScopedUnref unref_iterator(threadpool_resource);
+
+ *output = new Dataset(ctx, input, threadpool_resource);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ ThreadPoolResource* threadpool)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ threadpool_(threadpool) {
+ input_->Ref();
+ threadpool_->Ref();
+ }
+
+ ~Dataset() override {
+ input_->Unref();
+ threadpool_->Unref();
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "ThreadPoolDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ ThreadPoolResource* pool = dataset()->threadpool_;
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = [pool](std::function<void()> c) {
+ pool->Schedule(std::move(c));
+ };
+ params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.lib = ctx->lib();
+ params.function_library = ctx->function_library();
+ params.allocator_getter = ctx->allocator_getter();
+ IteratorContext threadpool_ctx(params);
+ return input_impl_->GetNext(&threadpool_ctx, out_tensors,
+ end_of_sequence);
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* const input_;
+ ThreadPoolResource* const threadpool_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
+ ThreadPoolHandleOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
+ ThreadPoolDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
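ThreadPoolDatasetOp doesn't touch the data itself; its iterator rebuilds the IteratorContext so that `params.runner` forwards closures to the custom pool, which is how everything downstream of this dataset ends up executing on the user-supplied threads. A standalone sketch of swapping in such a runner (plain C++; the single-worker MiniPool is a deliberately tiny stand-in for a real thread pool):

#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// A tiny single-worker "pool": Schedule() enqueues work for the worker thread,
// and the destructor drains remaining tasks before joining.
class MiniPool {
 public:
  MiniPool() : worker_([this] { Loop(); }) {}
  ~MiniPool() {
    {
      std::lock_guard<std::mutex> l(mu_);
      done_ = true;
    }
    cv_.notify_one();
    worker_.join();
  }
  void Schedule(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> l(mu_);
      tasks_.push_back(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  void Loop() {
    for (;;) {
      std::function<void()> fn;
      {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return done_ || !tasks_.empty(); });
        if (tasks_.empty()) return;  // done_ and nothing left to run.
        fn = std::move(tasks_.front());
        tasks_.pop_front();
      }
      fn();
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::function<void()>> tasks_;
  bool done_ = false;
  std::thread worker_;  // Declared last so the guarded state exists first.
};

// The runner indirection used by IteratorContext: code that needs to run a
// closure goes through this function instead of spawning work itself.
using Runner = std::function<void(std::function<void()>)>;

int main() {
  MiniPool pool;
  // Rebind the runner so closures are scheduled on the pool, as the iterator
  // above does for its child iterator.
  Runner runner = [&pool](std::function<void()> c) { pool.Schedule(std::move(c)); };

  std::mutex print_mu;
  for (int i = 0; i < 3; ++i) {
    runner([i, &print_mu] {
      std::lock_guard<std::mutex> l(print_mu);
      std::cout << "task " << i << " on worker thread\n";
    });
  }
  return 0;  // ~MiniPool drains remaining tasks before joining.
}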
diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
new file mode 100644
index 0000000000..cd612e0eb2
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -0,0 +1,224 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class UniqueDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit UniqueDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ OP_REQUIRES(ctx, input->output_dtypes().size() == 1,
+ errors::InvalidArgument("UniqueDataset only supports "
+ "inputs with a single component."));
+
+ DataType input_dtype = input->output_dtypes()[0];
+ OP_REQUIRES(ctx,
+ input_dtype == DT_INT32 || input_dtype == DT_INT64 ||
+ input_dtype == DT_STRING,
+ errors::InvalidArgument(
+ "UniqueDataset only supports inputs with a single "
+ "`tf.int32`, `tf.int64`, or `tf.string` component."));
+
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Unique")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+      return "UniqueDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const typename Iterator::Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ bool saw_new_value;
+ do {
+ saw_new_value = false;
+ out_tensors->clear();
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ break;
+ }
+ DCHECK_EQ(1, out_tensors->size());
+ saw_new_value = unique_elements_.insert((*out_tensors)[0]).second;
+ } while (!saw_new_value);
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("unique_elements_size"), unique_elements_.size()));
+ size_t i = 0;
+ for (const Tensor& t : unique_elements_) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("unique_elements[", i++, "]")), t));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ int64 num_unique_elements;
+ unique_elements_.clear();
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
+ &num_unique_elements));
+ for (int64 i = 0; i < num_unique_elements; ++i) {
+ Tensor unique_element;
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("unique_elements[", i, "]")),
+ &unique_element));
+ auto insert_result = unique_elements_.insert(unique_element);
+ if (!insert_result.second) {
+ return errors::InvalidArgument(
+ "Checkpoint contained two unique elements with the same "
+ "value.");
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ struct TensorHash {
+ size_t operator()(const Tensor& t) const {
+ if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
+ return Hash64(t.tensor_data().data(), t.tensor_data().size());
+ } else {
+ DCHECK_EQ(DT_STRING, t.dtype());
+ auto flat_t = t.flat<string>();
+ uint64 hash = 0;
+ for (int64 i = 0; i < t.NumElements(); ++i) {
+ hash = Hash64Combine(hash, Hash64(flat_t(i)));
+ }
+ return static_cast<size_t>(hash);
+ }
+ }
+ };
+
+ struct TensorKeyEqual {
+ bool operator()(const Tensor& lhs, const Tensor& rhs) const {
+ if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
+ return false;
+ }
+ switch (lhs.dtype()) {
+#define HANDLE_TYPE(T) \
+ case T: \
+ do { \
+ auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
+ auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
+ for (int64 i = 0; i < lhs.NumElements(); ++i) { \
+ if (lhs_flat(i) != rhs_flat(i)) { \
+ return false; \
+ } \
+ } \
+ return true; \
+ } while (0)
+
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_INT64);
+ HANDLE_TYPE(DT_STRING);
+ default:
+ DCHECK(false) << "UniqueDataset unhandled data type: "
+ << DataTypeString(lhs.dtype());
+ return false;
+ }
+ }
+ };
+
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
+ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
+ UniqueDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
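UniqueDataset's iterator is a dedup filter: it pulls elements until one is not already in `unique_elements_`, using custom hash and equality functors because Tensor has no standard-library defaults. A standalone sketch of the same pattern over a user-defined key type (plain C++, illustrative only):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// A toy element type with no std::hash specialization, so we supply the
// functors explicitly, as UniqueDatasetOp does for Tensor.
struct Element {
  std::string value;
};

struct ElementHash {
  size_t operator()(const Element& e) const {
    return std::hash<std::string>()(e.value);
  }
};

struct ElementEqual {
  bool operator()(const Element& a, const Element& b) const {
    return a.value == b.value;
  }
};

// Keeps only the first occurrence of each element, in input order.
std::vector<Element> UniqueFilter(const std::vector<Element>& input) {
  std::unordered_set<Element, ElementHash, ElementEqual> seen;
  std::vector<Element> out;
  for (const Element& e : input) {
    if (seen.insert(e).second) {  // .second is false for duplicates.
      out.push_back(e);
    }
  }
  return out;
}

int main() {
  std::vector<Element> in = {{"a"}, {"b"}, {"a"}, {"c"}, {"b"}};
  for (const Element& e : UniqueFilter(in)) {
    std::cout << e.value << "\n";  // Prints a, b, c.
  }
  return 0;
}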