author      2018-09-28 08:38:53 -0700
committer   2018-09-28 08:46:34 -0700
commit      c7bb3c3d65e4e064d53630d4b524522eed6f3f44 (patch)
tree        1fd0b73ab916093c80dcd289154035bba5fb393d /tensorflow/core/kernels
parent      e06783e7bb80f664c7ec9be90680ac6ddcbd598f (diff)
[tf.data] Move `tf.contrib.data` C++ code to a core "experimental" directory.
NOTE: All ops and kernels previously defined in
tensorflow/contrib/data have had their names prefixed with
"Experimental" to indicate that they are not (yet) stable, and thus
not subject to backwards or forwards compatibility guarantees.
PiperOrigin-RevId: 214940819
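A note for readers: the rename described above is purely a change to the registered op name; the kernel classes move verbatim. A minimal sketch of the pattern, where the pre-move name "AssertNextDataset" is an assumption based on the usual contrib naming convention (only the "Experimental"-prefixed registration appears in this diff):

    // Before the move (tensorflow/contrib/data; assumed registration):
    REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
                            AssertNextDatasetOp);

    // After the move, as registered in this commit:
    REGISTER_KERNEL_BUILDER(
        Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
        AssertNextDatasetOp);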
Diffstat (limited to 'tensorflow/core/kernels')
13 files changed, 3372 insertions, 0 deletions
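Each kernel target below depends on //tensorflow/core:experimental_dataset_ops_op_lib, which holds the op definitions for the renamed ops. Those definitions are not part of this diff; as a rough, hedged sketch of what one looks like (attrs inferred from the kernel's GetAttr calls, so treat the exact attr list as an assumption):

    REGISTER_OP("ExperimentalUniqueDataset")
        .Input("input_dataset: variant")
        .Output("handle: variant")
        .Attr("output_types: list(type) >= 1")
        .Attr("output_shapes: list(shape) >= 1")
        .SetShapeFn(shape_inference::ScalarShape);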
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 87efdff789..6333853cdf 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -765,6 +765,7 @@ tf_kernel_library(
         ":window_dataset_op",
         ":writer_ops",
         ":zip_dataset_op",
+        "//tensorflow/core/kernels/data/experimental:dataset_kernels",
     ],
 )
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
new file mode 100644
index 0000000000..43406db3ed
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -0,0 +1,139 @@
+# Description:
+#   Contains experimental kernels for datasets and iterators.
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_kernel_library",
+)
+
+cc_library(
+    name = "indexed_dataset_headers",
+    hdrs = ["indexed_dataset.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "indexed_dataset",
+    srcs = [
+        "identity_indexed_dataset.cc",
+        "indexed_dataset.cc",
+    ],
+    deps = [
+        ":indexed_dataset_headers",
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "prefetching_kernels",
+    srcs = ["prefetching_kernels.cc"],
+    deps = [
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "directed_interleave_dataset_op",
+    srcs = ["directed_interleave_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "csv_dataset_op",
+    srcs = ["csv_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
+    name = "ignore_errors_dataset_op",
+    srcs = ["ignore_errors_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "lmdb_dataset_op",
+    srcs = ["lmdb_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+        "@lmdb",
+    ],
+)
+
+tf_kernel_library(
+    name = "threadpool_dataset_op",
+    srcs = ["threadpool_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "unique_dataset_op",
+    srcs = ["unique_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "assert_next_dataset_op",
+    srcs = ["assert_next_dataset_op.cc"],
+    deps = [
+        "//tensorflow/core:experimental_dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "dataset_kernels",
+    deps = [
+        ":assert_next_dataset_op",
+        ":csv_dataset_op",
+        ":directed_interleave_dataset_op",
+        ":ignore_errors_dataset_op",
+        ":indexed_dataset",
+        ":lmdb_dataset_op",
+        ":prefetching_kernels",
+        ":threadpool_dataset_op",
+        ":unique_dataset_op",
+    ],
+)
diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
new file mode 100644
index 0000000000..3511cca0f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -0,0 +1,156 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class AssertNextDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    std::vector<string> transformations;
+    OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
+                                                    &transformations));
+    *output =
+        new Dataset(ctx, input, transformations, output_types_, output_shapes_);
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, const DatasetBase* input,
+            const std::vector<string>& transformations,
+            const DataTypeVector& output_types,
+            const std::vector<PartialTensorShape>& output_shapes)
+        : DatasetBase(DatasetContext(ctx)),
+          input_(input),
+          transformations_(transformations),
+          output_types_(output_types),
+          output_shapes_(output_shapes) {
+      input_->Ref();
+    }
+
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::Assert")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return output_types_;
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() const override {
+      return "AssertNextDatasetOp::Dataset";
+    }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+      Node* transformations_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this, {input_graph_node, transformations_node}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        std::vector<string> tokens =
+            str_util::Split(prefix(), ':', str_util::SkipEmpty());
+        if (dataset()->transformations_.size() > tokens.size() - 2) {
+          return errors::InvalidArgument(
+              "Asserted next ", dataset()->transformations_.size(),
+              " transformations but encountered only ", tokens.size() - 2, ".");
+        }
+        int n = tokens.size();
+        for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
+          if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
+            return errors::InvalidArgument(
+                "Asserted ", dataset()->transformations_[i],
+                " transformation at offset ", i, " but encountered ",
+                tokens[n - 2 - i], " transformation instead.");
+          }
+        }
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
+     private:
+      std::unique_ptr<IteratorBase> input_impl_;
+    };
+
+    const DatasetBase* input_;
+    const std::vector<string> transformations_;
+    const DataTypeVector output_types_;
+    const std::vector<PartialTensorShape> output_shapes_;
+  };
+
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+    AssertNextDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
new file mode 100644
index 0000000000..7451ca4cb1
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -0,0 +1,860 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/parsing_ops.cc.
+
+#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputstream.h" + +namespace tensorflow { +namespace data { +namespace { + +class CSVDatasetOp : public DatasetOpKernel { + public: + explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + string compression_type; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type", + &compression_type)); + + OpInputList record_defaults_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1, + errors::InvalidArgument( + "Each record default should be at most rank 1")); + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + } + + const Tensor* select_cols_tensor; + OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); + OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, + errors::InvalidArgument("`select_cols` must be a vector.")); + + int64 buffer_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size should be positive")); + + string delim; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "field_delim", &delim)); + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + bool header; + OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header)); + + bool use_quote_delim; + OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim", + &use_quote_delim)); + string na_value; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "na_value", &na_value)); + + std::vector<Tensor> record_defaults; + record_defaults.reserve(record_defaults_list.size()); + for (const Tensor& t : record_defaults_list) { + record_defaults.push_back(t); + } + + std::vector<string> filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat<string>()(i)); + } + + io::ZlibCompressionOptions zlib_compression_options = + io::ZlibCompressionOptions::DEFAULT(); + if (compression_type == "ZLIB") { + zlib_compression_options = io::ZlibCompressionOptions::DEFAULT(); + } else if (compression_type == "GZIP") { + zlib_compression_options = io::ZlibCompressionOptions::GZIP(); + } else { + OP_REQUIRES(ctx, compression_type.empty(), + errors::InvalidArgument( + "Unsupported compression_type: ", compression_type, ".")); + } + 
+    zlib_compression_options.input_buffer_size = buffer_size;
+
+    std::vector<int64> select_cols;
+    select_cols.reserve(select_cols_tensor->NumElements());
+    for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
+      select_cols.push_back(select_cols_tensor->flat<int64>()(i));
+    }
+    OP_REQUIRES(
+        ctx, output_types_.size() == select_cols.size() || select_cols.empty(),
+        errors::InvalidArgument("select_cols should match output size"));
+    for (int i = 1; i < select_cols.size(); i++) {
+      OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i],
+                  errors::InvalidArgument(
+                      "select_cols should be strictly increasing indices"));
+    }
+    OP_REQUIRES(
+        ctx, select_cols.empty() || select_cols.front() >= 0,
+        errors::InvalidArgument("select_cols should be non-negative indices"));
+
+    *output = new Dataset(ctx, std::move(filenames), header,
+                          std::move(compression_type), zlib_compression_options,
+                          output_types_, output_shapes_,
+                          std::move(record_defaults), std::move(select_cols),
+                          use_quote_delim, delim[0], std::move(na_value));
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
+            string compression_type, io::ZlibCompressionOptions options,
+            const DataTypeVector& output_types,
+            const std::vector<PartialTensorShape>& output_shapes,
+            std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
+            bool use_quote_delim, char delim, string na_value)
+        : DatasetBase(DatasetContext(ctx)),
+          filenames_(std::move(filenames)),
+          header_(header),
+          out_type_(output_types),
+          output_shapes_(output_shapes),
+          record_defaults_(std::move(record_defaults)),
+          select_cols_(std::move(select_cols)),
+          use_quote_delim_(use_quote_delim),
+          delim_(delim),
+          na_value_(std::move(na_value)),
+          use_compression_(!compression_type.empty()),
+          compression_type_(std::move(compression_type)),
+          options_(options) {}
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::CSV")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override { return out_type_; }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() const override { return "CSVDatasetOp::Dataset"; }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* filenames = nullptr;
+      Node* compression_type = nullptr;
+      Node* buffer_size = nullptr;
+      Node* header = nullptr;
+      Node* delim = nullptr;
+      Node* use_quote_delim = nullptr;
+      Node* na_value = nullptr;
+      Node* select_cols = nullptr;
+
+      std::vector<Node*> record_defaults;
+      record_defaults.reserve(record_defaults_.size());
+      for (const Tensor& t : record_defaults_) {
+        Node* node;
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        record_defaults.emplace_back(node);
+      }
+
+      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+      TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+      TF_RETURN_IF_ERROR(
+          b->AddScalar(options_.input_buffer_size, &buffer_size));
+      TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+      string delim_string(1, delim_);
+      TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+      TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this,
+          {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+           std::make_pair(2, buffer_size), std::make_pair(3, header),
+           std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+           std::make_pair(6, na_value),
+           std::make_pair(7, select_cols)},      // Single tensor inputs
+          {std::make_pair(8, record_defaults)},  // Tensor list inputs
+          {}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params) {}
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        bool select_all = dataset()->select_cols_.empty();
+        do {
+          // We are currently processing a file, so try to read the next record
+          if (input_stream_) {
+            Status s = ReadRecord(ctx, out_tensors, select_all,
+                                  dataset()->select_cols_);
+            if (s.ok()) {
+              // Validate output
+              if (out_tensors->size() != dataset()->out_type_.size()) {
+                return errors::InvalidArgument(
+                    "Expect ", dataset()->out_type_.size(), " fields but have ",
+                    out_tensors->size(), " in record");
+              }
+
+              *end_of_sequence = false;
+              return s;
+            }
+            if (!errors::IsOutOfRange(s)) {
+              // Not at the end of file, return OK or non-EOF errors to caller.
+              *end_of_sequence = false;
+              return s;
+            }
+            // We have reached the end of the current file, so maybe
+            // move on to next file.
+            ResetStreamsLocked();
+            ++current_file_index_;
+          }
+          // Iteration ends when there are no more files to process.
+          if (current_file_index_ == dataset()->filenames_.size()) {
+            *end_of_sequence = true;
+            return Status::OK();
+          }
+          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+        } while (true);
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+                                               current_file_index_));
+        // `input_stream_` is empty if
+        // 1. GetNext has not been called even once.
+        // 2. All files have been read and the iterator has been exhausted.
+        if (input_stream_ && num_buffer_reads_ > 0) {
+          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+          // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+                                                 num_buffer_reads_));
+        }
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        ResetStreamsLocked();
+        int64 current_file_index;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+                                              &current_file_index));
+        current_file_index_ = size_t(current_file_index);
+        // The keys "pos" and "num_buffer_reads" are written only if
+        // the iterator was saved with an open, partially read file.
+        if (reader->Contains(full_name("pos"))) {
+          int64 pos, num_buffer_reads;
+          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+                                                &num_buffer_reads));
+
+          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+          num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+          // Restores the most recently held buffer
+          Status s = input_stream_->SkipNBytes(
+              num_buffer_reads_ * dataset()->options_.input_buffer_size);
+          if (!s.ok() && !errors::IsOutOfRange(s)) {
+            // We might get out of range error here if the size of the file
+            // is not an exact multiple of the buffer size, and the last buffer
+            // read is < buffer_size. This is valid and we do not surface the
+            // error.
+            return s;
+          }
+
+          Status s2 = FillBuffer(&buffer_);
+          if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+            return s2;
+          }
+          pos_ = size_t(pos);
+        }
+        return Status::OK();
+      }
+
+     private:
+      // Reads an entire CSV row from the input stream, either from the
+      // existing buffer or by filling the buffer as needed. Converts extracted
+      // fields to output tensors as we go.
+      //
+      // When this function is called, pos_ should be the index of the first
+      // character of the record in buffer_, or past the end of the buffer.
+      // Note: ctx and out_tensors are only used in this function
+      // when fields are included in the record.
+      Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+                        bool select_all, const std::vector<int64>& selected)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (pos_ >= buffer_.size()) {
+          // At the end of the file, this will return errors::OutOfRange
+          TF_RETURN_IF_ERROR(FillBuffer(&buffer_));
+          pos_ = 0;
+        }
+
+        // The first character may be \n if this is the continuation of a
+        // \r\n linebreak between this and the previous record. If so, skip it.
+
+        bool end_of_record = false;  // Keep track of when we find \n, \r or EOF
+        size_t num_parsed = 0;
+        size_t num_selected_parsed = 0;
+
+        Status result;
+
+        while (!end_of_record) {  // Read till we reach \n, \r or EOF
+          bool include =
+              select_all || (num_selected_parsed < selected.size() &&
+                             selected[num_selected_parsed] == num_parsed);
+
+          // Don't fail fast, so that the next call to GetNext may still return
+          // a valid record
+          result.Update(
+              ParseOneField(ctx, out_tensors, &end_of_record, include));
+
+          num_parsed++;
+          if (include) num_selected_parsed++;
+        }
+
+        return result;
+      }
+
+      // Parses one field from position pos_ in the buffer. Fields are
+      // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of
+      // the next field.
+      Status ParseOneField(IteratorContext* ctx,
+                           std::vector<Tensor>* out_tensors,
+                           bool* end_of_record, bool include)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (pos_ >= buffer_.size()) {
+          // If we get here, this means the previous field's end coincided
+          // with the end of the buffer. We can fill the buffer without abandon.
+          Status s = FillBuffer(&buffer_);
+
+          if (errors::IsOutOfRange(s)) {
+            // Reached EOF, and last field is empty
+            *end_of_record = true;
+            if (include) {
+              return FieldToOutput(ctx, StringPiece(), out_tensors);
+            } else {
+              return Status::OK();
+            }
+          } else if (!s.ok()) {
+            return s;  // Surface other errors back to caller
+          }
+
+          pos_ = 0;
+        }
+
+        if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') {
+          return ParseQuotedField(ctx, out_tensors, end_of_record, include);
+        }
+
+        return ParseUnquotedField(ctx, out_tensors, end_of_record, include);
+      }
+
+      // For keeping track of relevant parts of a field from a previous buffer
+      struct Piece {
+        size_t start;
+        size_t len;
+        string buffer;
+
+        Piece(string buffer, size_t start, size_t len)
+            : start(start), len(len), buffer(std::move(buffer)) {}
+      };
+
+      // Given that pos_ exceeds the buffer, saves the relevant part of the
+      // current buffer (if necessary), fills the buffer, and resets indices to
+      // 0.
+      Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
+                               size_t* start, bool include)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        string temp_buffer;
+
+        buffer_.swap(temp_buffer);
+        if (include && pos_ > *start) {
+          earlier_pieces->push_back(
+              Piece(std::move(temp_buffer), *start, pos_ - *start));
+        }
+        pos_ = 0;
+        *start = 0;
+        return FillBuffer(&buffer_);
+      }
+
+      // Parses quoted field from position pos_ in the buffer. Continually
+      // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+      // Advances pos_ to keep track of our position in the buffer as we go,
+      // stopping at the first character of the next field.
+      Status ParseQuotedField(IteratorContext* ctx,
+                              std::vector<Tensor>* out_tensors,
+                              bool* end_of_record, bool include)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        std::vector<Piece> earlier_pieces;
+        size_t start = pos_;
+        pos_++;  // Starting quotation mark
+
+        Status parse_result;
+        while (true) {  // Each iter reads 1 char, filling buffer if necessary
+          if (pos_ >= buffer_.size()) {
+            Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+            if (errors::IsOutOfRange(s)) {
+              return errors::InvalidArgument(
+                  "Reached end of file without closing quoted field in "
+                  "record");
+            } else if (!s.ok()) {
+              return s;  // Surface all other errors to caller
+            }
+          }
+
+          char ch = buffer_[pos_];
+          if (ch == '"') {
+            // When we encounter a quote, we look ahead to the next character
+            // to decide what to do
+            pos_++;
+            if (pos_ >= buffer_.size()) {
+              Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+              if (errors::IsOutOfRange(s)) {
+                // This was the last field. We are done
+                *end_of_record = true;
+                parse_result.Update(QuotedFieldToOutput(
+                    ctx, StringPiece(), out_tensors, earlier_pieces, include));
+                return parse_result;
+              } else if (!s.ok()) {
+                return s;
+              }
+            }
+
+            char next = buffer_[pos_];
+            pos_++;
+            if (next == dataset()->delim_) {
+              parse_result.Update(QuotedFieldToOutput(
+                  ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+                  out_tensors, earlier_pieces, include));
+              return parse_result;
+
+            } else if (next == '\n' || next == '\r') {
+              *end_of_record = true;
+              parse_result.Update(QuotedFieldToOutput(
+                  ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+                  out_tensors, earlier_pieces, include));
+              if (next == '\r') SkipNewLineIfNecessary();
+              return parse_result;
+            } else if (next != '"') {
+              // Take note of the error, but keep going to end of field.
+              include = false;  // So we don't get funky errors when trying to
+                                // unescape the quotes.
+              parse_result.Update(errors::InvalidArgument(
+                  "Quote inside a string has to be escaped by another quote"));
+            }
+
+          } else {
+            pos_++;
+          }
+        }
+      }
+
+      // Converts quoted field to an output tensor, removing the starting
+      // and ending quotes from it and unescaping double quotations if
+      // necessary.
+      Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+                                 std::vector<Tensor>* out_tensors,
+                                 const std::vector<Piece>& earlier_pieces,
+                                 bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (!include) return Status::OK();
+
+        if (earlier_pieces.empty()) {
+          if (field.find('\"', 1) == field.size() - 1) {
+            // `field` contains no escaped quotation marks.
+            // Exclude framing quotation marks
+            field.remove_prefix(1);
+            field.remove_suffix(1);
+            return FieldToOutput(ctx, field, out_tensors);
+          }
+        }
+        string field_complete;
+        size_t str_len = field.size();
+        for (const Piece& p : earlier_pieces) {
+          str_len += p.len;
+        }
+        field_complete.reserve(str_len);
+
+        // This bool flips every time we see a quote, so that we skip the
+        // second quote of every pair of adjacent quotes in the field. We need
+        // to track this across iterations of the for loop because adjacent
+        // double quotes may be in different buffers. Initialize to true
+        // because we also skip the opening quotation mark of the quoted field.
+        bool skip_next_quote = true;
+        for (const Piece& p : earlier_pieces) {
+          AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len),
+                               &field_complete, &skip_next_quote);
+        }
+        AppendUnescapedPiece(field, &field_complete, &skip_next_quote);
+        StringPiece result = StringPiece(field_complete);
+        result.remove_suffix(1);  // Skip final quote
+
+        return FieldToOutput(ctx, result, out_tensors);
+      }
+
+      void AppendUnescapedPiece(StringPiece piece, string* field_complete,
+                                bool* skip_next_quote) {
+        size_t from = 0;
+        size_t found = piece.find('\"', from);
+        while (found != string::npos) {
+          if (!*skip_next_quote) {
+            // This is the first quote in a pair of adjacent double quotes
+            field_complete->append(piece.data() + from, found + 1 - from);
+          }
+          *skip_next_quote = !*skip_next_quote;
+          from = found + 1;
+          found = piece.find('\"', from);
+        }
+        // Include the chunk after the last quotation mark in the string
+        if (from < piece.size()) {
+          field_complete->append(piece.data() + from, piece.size() - from);
+        }
+      }
+
+      // Parses unquoted field from position pos_ in the buffer. Continually
+      // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+      // Advances pos_ to keep track of our position in the buffer as we go,
+      // stopping at the first character of the next field.
+      Status ParseUnquotedField(IteratorContext* ctx,
+                                std::vector<Tensor>* out_tensors,
+                                bool* end_of_record, bool include)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        std::vector<Piece> earlier_pieces;
+        size_t start = pos_;
+        Status parse_result;
+
+        while (true) {  // Each iter reads 1 char, filling buffer if necessary
+          if (pos_ >= buffer_.size()) {
+            Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+            // Handle errors
+            if (errors::IsOutOfRange(s)) {
+              // Whatever we have is the last field of the last record
+              *end_of_record = true;
+              parse_result.Update(UnquotedFieldToOutput(
+                  ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+                  earlier_pieces, include));
+              return parse_result;
+            } else if (!s.ok()) {
+              return s;  // Surface all other errors to caller
+            }
+          }
+
+          char ch = buffer_[pos_];
+
+          if (ch == dataset()->delim_) {
+            parse_result.Update(UnquotedFieldToOutput(
+                ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+                earlier_pieces, include));
+            pos_++;
+            return parse_result;
+          }
+          if (ch == '\n' || ch == '\r') {
+            // need special case to skip over first \n of record if the line
+            // breaks are \r\n
+            parse_result.Update(UnquotedFieldToOutput(
+                ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+                earlier_pieces, include));
+            *end_of_record = true;
+            pos_++;
+            if (ch == '\r') SkipNewLineIfNecessary();
+            return parse_result;
+          }
+          if (dataset()->use_quote_delim_ && ch == '"') {
+            // Take note of the error, but keep going to end of field.
+            parse_result.Update(errors::InvalidArgument(
+                "Unquoted fields cannot have quotes inside"));
+          }
+          // Otherwise, go to next character
+          pos_++;
+        }
+      }
+
+      Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        result->clear();
+        ++num_buffer_reads_;
+        Status s = input_stream_->ReadNBytes(
+            dataset()->options_.input_buffer_size, result);
+
+        if (errors::IsOutOfRange(s) && !result->empty()) {
+          // Ignore OutOfRange error when ReadNBytes read < N bytes.
+          return Status::OK();
+        }
+        return s;
+      }
+
+      // Given a field, converts it to the right output tensor type
+      Status FieldToOutput(IteratorContext* ctx, StringPiece field,
+                           std::vector<Tensor>* out_tensors) {
+        size_t output_idx = out_tensors->size();
+        if (output_idx >= dataset()->out_type_.size()) {
+          // We can get here if we're selecting all columns, but the number of
+          // fields exceeds the number of defaults provided
+          return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+                                         " fields but have more in record");
+        }
+        const DataType& dtype = dataset()->out_type_[output_idx];
+        Tensor component(ctx->allocator({}), dtype, {});
+        if ((field.empty() || field == dataset()->na_value_) &&
+            dataset()->record_defaults_[output_idx].NumElements() != 1) {
+          // If the field is empty or NA value, and default is not given,
+          // report error.
+          return errors::InvalidArgument("Field ", output_idx,
+                                         " is required but missing in record!");
+        }
+
+        switch (dtype) {
+          // For each case, if the field is empty, we use the default.
+          // Otherwise, we convert it to the right type.
+          case DT_INT32: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<int32>()() =
+                  dataset()->record_defaults_[output_idx].flat<int32>()(0);
+            } else {
+              int32 value;
+              if (!strings::safe_strto32(field, &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid int32: ", field);
+              }
+              component.scalar<int32>()() = value;
+            }
+            break;
+          }
+          case DT_INT64: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<int64>()() =
+                  dataset()->record_defaults_[output_idx].flat<int64>()(0);
+            } else {
+              int64 value;
+              if (!strings::safe_strto64(field, &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid int64: ", field);
+              }
+              component.scalar<int64>()() = value;
+            }
+            break;
+          }
+          case DT_FLOAT: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<float>()() =
+                  dataset()->record_defaults_[output_idx].flat<float>()(0);
+            } else {
+              float value;
+              if (!strings::safe_strtof(field, &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid float: ", field);
+              }
+              component.scalar<float>()() = value;
+            }
+            break;
+          }
+          case DT_DOUBLE: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<double>()() =
+                  dataset()->record_defaults_[output_idx].flat<double>()(0);
+            } else {
+              double value;
+              if (!strings::safe_strtod(field, &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid double: ", field);
+              }
+              component.scalar<double>()() = value;
+            }
+            break;
+          }
+          case DT_STRING: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<string>()() =
+                  dataset()->record_defaults_[output_idx].flat<string>()(0);
+            } else {
+              component.scalar<string>()() = string(field);
+            }
+            break;
+          }
+          default:
+            return errors::InvalidArgument("csv: data type ", dtype,
+                                           " not supported in field ",
+                                           output_idx);
+        }
+        out_tensors->push_back(std::move(component));
+        return Status::OK();
+      }
+
+      // Records can be delimited by "\r\n" line breaks. When we encounter a
+      // '\r', we have to check the next character to see if it is part of the
+      // linebreak, and ignore it if so.
+      void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (pos_ >= buffer_.size()) {
+          Status s = FillBuffer(&buffer_);
+          pos_ = 0;
+          // If we failed to fill buffer, it doesn't matter because we're done
+          // with the record
+          if (!s.ok()) return;
+        }
+        if (buffer_[pos_] == '\n') {
+          pos_++;
+        }
+      }
+
+      // Given a string field, and its index in the output,
+      // converts it to a Tensor of the right type and adds it to the
+      // out_tensors vector.
+      Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+                                   std::vector<Tensor>* out_tensors,
+                                   const std::vector<Piece>& earlier_pieces,
+                                   bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (!include) return Status::OK();
+
+        if (earlier_pieces.empty()) {
+          return FieldToOutput(ctx, field, out_tensors);
+        }
+
+        size_t str_len = field.size();
+        for (const Piece& p : earlier_pieces) {
+          str_len += p.len;
+        }
+        string field_complete;
+        field_complete.reserve(str_len);
+
+        for (const Piece& p : earlier_pieces) {
+          field_complete.append(p.buffer, p.start, p.len);
+        }
+
+        field_complete.append(field.data(), field.size());
+        return FieldToOutput(ctx, field_complete, out_tensors);
+      }
+
+      // Sets up reader streams to read from the file at `current_file_index_`.
+      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (current_file_index_ >= dataset()->filenames_.size()) {
+          return errors::InvalidArgument(
+              "current_file_index_:", current_file_index_,
+              " >= filenames_.size():", dataset()->filenames_.size());
+        }
+
+        // Actually move on to next file.
+        TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
+            dataset()->filenames_[current_file_index_], &file_));
+        random_access_input_stream_ =
+            std::make_shared<io::RandomAccessInputStream>(file_.get(), false);
+
+        if (dataset()->use_compression_) {
+          input_stream_ = std::make_shared<io::ZlibInputStream>(
+              random_access_input_stream_.get(),
+              dataset()->options_.input_buffer_size,
+              dataset()->options_.input_buffer_size, dataset()->options_);
+        } else {
+          input_stream_ = random_access_input_stream_;
+        }
+        buffer_.clear();
+        pos_ = 0;
+        num_buffer_reads_ = 0;
+        if (dataset()->header_) {
+          // Read one line, but don't include it. Pass nullptrs as dummy
+          // pointers to objects that shouldn't be invoked anyway
+          // We need to process this as a record here instead of just finding
+          // the first newline because it might contain quoted fields with
+          // newlines in the header as well
+          std::vector<int64> empty;
+          Status s = ReadRecord(nullptr, nullptr, false, empty);
+          if (!s.ok()) {
+            return errors::InvalidArgument("Can't read header of file");
+          }
+        }
+        return Status::OK();
+      }
+
+      // Resets all reader streams.
+      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        input_stream_.reset();
+        file_.reset();
+      }
+
+      mutex mu_;
+      string buffer_ GUARDED_BY(mu_);  // Maintain our own buffer
+      size_t pos_ GUARDED_BY(
+          mu_);  // Index into the buffer must be maintained between iters
+      size_t num_buffer_reads_ GUARDED_BY(mu_);
+      std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
+          GUARDED_BY(mu_);
+      std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
+      size_t current_file_index_ GUARDED_BY(mu_) = 0;
+      std::unique_ptr<RandomAccessFile> file_
+          GUARDED_BY(mu_);  // must outlive input_stream_
+    };  // class Iterator
+
+    const std::vector<string> filenames_;
+    const bool header_;
+    const DataTypeVector out_type_;
+    const std::vector<PartialTensorShape> output_shapes_;
+    const std::vector<Tensor> record_defaults_;
+    const std::vector<int64> select_cols_;
+    const bool use_quote_delim_;
+    const char delim_;
+    const string na_value_;
+    const bool use_compression_;
+    const string compression_type_;
+    const io::ZlibCompressionOptions options_;
+  };  // class Dataset
+
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+};  // class CSVDatasetOp
+
+// Register the kernel implementation for CSVDataset.
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+                        CSVDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
new file mode 100644
index 0000000000..c47a9099c4
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class DirectedInterleaveDatasetOp : public DatasetOpKernel {
+ public:
+  explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
+      : DatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    DatasetBase* selector_input;
+    OP_REQUIRES_OK(ctx,
+                   GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
+
+    OP_REQUIRES(
+        ctx,
+        selector_input->output_dtypes().size() == 1 &&
+            selector_input->output_dtypes()[0] == DT_INT64 &&
+            selector_input->output_shapes().size() == 1 &&
+            selector_input->output_shapes()[0].IsCompatibleWith(
+                PartialTensorShape({})),
+        errors::InvalidArgument(
+            "The selector input must be a dataset of scalar int64 elements."));
+
+    std::vector<DatasetBase*> data_inputs;
+    for (size_t i = 1; i < ctx->num_inputs(); ++i) {
+      DatasetBase* input;
+      OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
+      data_inputs.push_back(input);
+
+      OP_REQUIRES(
+          ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
+          errors::InvalidArgument(
+              "All inputs must have the same output_dtypes. First input "
First input " + "has types ", + DataTypeVectorString(data_inputs[0]->output_dtypes()), + ", and input ", i - 1, " has types ", + DataTypeVectorString(input->output_dtypes()))); + } + *output = new Dataset(ctx, selector_input, std::move(data_inputs)); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, + std::vector<DatasetBase*> data_inputs) + : DatasetBase(DatasetContext(ctx)), + selector_input_(selector_input), + data_inputs_(std::move(data_inputs)) { + selector_input_->Ref(); + + output_shapes_ = data_inputs_[0]->output_shapes(); + data_inputs_[0]->Ref(); + for (size_t i = 1; i < data_inputs_.size(); ++i) { + const DatasetBase* data_input = data_inputs_[i]; + data_input->Ref(); + for (size_t j = 0; j < output_shapes_.size(); ++j) { + output_shapes_[j] = MostSpecificCompatibleShape( + output_shapes_[j], data_input->output_shapes()[j]); + } + } + } + + ~Dataset() override { + selector_input_->Unref(); + for (DatasetBase* data_input : data_inputs_) { + data_input->Unref(); + } + } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::DirectedInterleave")})); + } + + const DataTypeVector& output_dtypes() const override { + return data_inputs_[0]->output_dtypes(); + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* selector_input_node; + TF_RETURN_IF_ERROR( + b->AddInputDataset(ctx, selector_input_, &selector_input_node)); + std::vector<Node*> data_input_nodes(data_inputs_.size()); + for (size_t i = 0; i < data_inputs_.size(); ++i) { + TF_RETURN_IF_ERROR( + b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i])); + } + TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, + {{1, data_input_nodes}}, {}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (!selector_input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + + while (true) { + std::vector<Tensor> selector_result; + *end_of_sequence = false; + TF_RETURN_IF_ERROR(selector_input_impl_->GetNext( + ctx, &selector_result, end_of_sequence)); + if (*end_of_sequence) { + selector_input_impl_.reset(); + for (auto& data_input_impl : data_input_impls_) { + data_input_impl.reset(); + } + return Status::OK(); + } + + 
+          int64 selected_input = selector_result[0].scalar<int64>()();
+          if (selected_input < 0 ||
+              selected_input >= data_input_impls_.size()) {
+            return errors::InvalidArgument(
+                "Selector index out of range: ", selected_input,
+                " >= ", data_input_impls_.size());
+          }
+
+          if (data_input_impls_[selected_input]) {
+            bool end_of_selected_input = false;
+            TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
+                ctx, out_tensors, &end_of_selected_input));
+
+            if (!end_of_selected_input) {
+              return Status::OK();
+            }
+
+            data_input_impls_[selected_input].reset();
+            --num_active_inputs_;
+
+            if (num_active_inputs_ == 0) {
+              selector_input_impl_.reset();
+              *end_of_sequence = true;
+              return Status::OK();
+            }
+          }
+
+          LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
+                       << selected_input;
+        }
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        if (selector_input_impl_) {
+          TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_));
+        } else {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
+        }
+        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+          const auto& data_input_impl = data_input_impls_[i];
+          if (data_input_impl) {
+            TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl));
+          } else {
+            TF_RETURN_IF_ERROR(writer->WriteScalar(
+                full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
+                ""));
+          }
+        }
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
+        } else {
+          selector_input_impl_.reset();
+        }
+        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+          if (!reader->Contains(full_name(
+                  strings::StrCat("data_input_impl_empty[", i, "]")))) {
+            TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
+          } else {
+            data_input_impls_[i].reset();
+          }
+        }
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
+      std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+          GUARDED_BY(mu_);
+      int64 num_active_inputs_ GUARDED_BY(mu_);
+    };
+
+    static PartialTensorShape MostSpecificCompatibleShape(
+        const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+      PartialTensorShape output_tensorshape;
+      if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+        return output_tensorshape;
+      auto dims1 = ts1.dim_sizes();
+      auto dims2 = ts2.dim_sizes();
+      for (int d = 0; d < ts1.dims(); d++) {
+        if (dims1[d] == dims2[d])
+          output_tensorshape.Concatenate(dims1[d]);
+        else
+          output_tensorshape.Concatenate(-1);
+      }
+      return output_tensorshape;
+    }
+
+    const DatasetBase* const selector_input_;
+    const std::vector<DatasetBase*> data_inputs_;
+    std::vector<PartialTensorShape> output_shapes_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+    DirectedInterleaveDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
new file mode 100644
index 0000000000..2141f118ca
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -0,0 +1,156 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
+ public:
+  using IndexedDatasetOpKernel::IndexedDatasetOpKernel;
+
+  void MakeIndexedDataset(OpKernelContext* ctx,
+                          IndexedDataset** output) override {
+    uint64 size = -1;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size));
+    OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0"));
+    *output = new Dataset(ctx, size);
+  }
+
+  class Dataset : public IndexedDataset {
+   public:
+    Dataset(OpKernelContext* ctx, uint64 size)
+        : IndexedDataset(DatasetContext(ctx)), size_(size) {}
+
+    Status MaterializeDataset(
+        std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
+      materialized->reset(new Materialized(this));
+      return Status::OK();
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64});
+      return *dtypes;
+    }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      static std::vector<PartialTensorShape>* shapes =
+          new std::vector<PartialTensorShape>({{}});
+      return *shapes;
+    }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(new Iterator(
+          {this, strings::StrCat(prefix, "::IdentityIndexedDataset")}));
+    }
+
+    string DebugString() const override {
+      return "IdentityIndexedDataset::Dataset";
+    }
+
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** node) const override {
+      return errors::Unimplemented(
+          "identity_indexed_dataset.AsGraphDefInternal");
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params) {}
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        if (cur_ < dataset()->size_) {
+          Tensor result_tensor(ctx->allocator({}), DT_UINT64, {});
+          result_tensor.scalar<uint64>()() = cur_++;
+          out_tensors->emplace_back(std::move(result_tensor));
+          *end_of_sequence = false;
+          return Status::OK();
+        }
+        *end_of_sequence = true;
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      uint64 cur_ GUARDED_BY(mu_) = 0;
+    };
+
+    class Materialized : public MaterializedIndexedDataset {
+     public:
+      explicit Materialized(Dataset* dataset) : dataset_(dataset) {
+        dataset->Ref();
+      }
+
+      ~Materialized() override {
+        // TODO(saeta): Pull this into MaterializedIndexedDataset
+        dataset_->Unref();
+      }
+
+      const DataTypeVector& output_dtypes() const override {
+        return dataset_->output_dtypes();
+      }
+
+      const std::vector<PartialTensorShape>& output_shapes() const override {
+        return dataset_->output_shapes();
+      }
+
+      Status Get(IteratorContext&& ctx, uint64 index,
+                 std::vector<Tensor>* out_tensors) const override {
+        LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index
+                  << ")";
+        if (index >= dataset_->size_) {
+          // Note: use InvalidArgument instead of OutOfRange error because many
+          // things consider OutOfRange to be a "clean termination" error.
+          return errors::InvalidArgument(
+              "Index ", index,
+              " is out of range for this dataset. (Size is: ", dataset_->size_,
+              ".)");
+        }
+        Tensor result_tensor(ctx.allocator({}), DT_UINT64, {});
+        result_tensor.scalar<uint64>()() = index;
+        out_tensors->emplace_back(std::move(result_tensor));
+        return Status::OK();
+      }
+
+      Status Size(uint64* size) const override {
+        *size = dataset_->size_;
+        return Status::OK();
+      }
+
+     private:
+      const Dataset* const dataset_;  // Not owned.
+    };
+
+    const uint64 size_;
+    std::shared_ptr<Materialized> materialized_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+    IdentityIndexedDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
new file mode 100644
index 0000000000..b34377c642
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    *output = new Dataset(ctx, input);
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+        : DatasetBase(DatasetContext(ctx)), input_(input) {
+      input_->Ref();
+    }
+
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return input_->output_dtypes();
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return input_->output_shapes();
+    }
+
+    string DebugString() const override {
+      return "IgnoreErrorsDatasetOp::Dataset";
+    }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+      TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        {
+          tf_shared_lock l(mu_);
+          if (!input_impl_) {
+            *end_of_sequence = true;
+            return Status::OK();
+          }
+          Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+          while (!s.ok()) {
+            out_tensors->clear();
+            s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+          }
+        }
+        if (*end_of_sequence) {
+          mutex_lock l(mu_);
+          input_impl_.reset();
+        }
+        return Status::OK();
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        if (input_impl_)
+          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+        else
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impls_empty"), ""));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        if (reader->Contains(full_name("input_impls_empty")))
+          input_impl_.reset();
+        else
+          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    };
+
+    const DatasetBase* const input_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+    IgnoreErrorsDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
new file mode 100644
index 0000000000..75ea462f40
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -0,0 +1,375 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h" + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" + +namespace tensorflow { +namespace data { +namespace { + +Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " types but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != received[i]) { + return errors::InvalidArgument("Data type mismatch at component ", i, + ": expected ", DataTypeString(expected[i]), + " but got ", DataTypeString(received[i]), + "."); + } + } + return Status::OK(); +} + +Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, + const std::vector<PartialTensorShape>& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " shapes but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (!expected[i].IsCompatibleWith(received[i])) { + return errors::InvalidArgument("Incompatible shapes at component ", i, + ": expected ", expected[i].DebugString(), + " but got ", received[i].DebugString(), + "."); + } + } + + return Status::OK(); +} + +class MaterializedDatasetResource : public ResourceBase { + public: + MaterializedDatasetResource( + const DataTypeVector& output_dtypes, + const std::vector<PartialTensorShape>& output_shapes) + : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {} + + string DebugString() override { + return "Materialized IndexedDataset resource"; + } + + Status Get(IteratorContext&& ctx, uint64 index, + std::vector<Tensor>* out_tensors) { + std::shared_ptr<MaterializedIndexedDataset> captured(materialized_); + if (captured) { + return captured->Get(std::move(ctx), index, out_tensors); + } else { + return errors::FailedPrecondition( + "Get() failed because the MaterializedIndexedDataset has not been " + "initialized. 
Ensure that you have run the materialization operation "
+          "for this MaterializedIndexedDataset before retrieving elements.");
+    }
+  }
+
+  // TODO(saeta): Implement Save and Restore
+
+  const DataTypeVector& output_dtypes() const { return output_dtypes_; }
+  const std::vector<PartialTensorShape>& output_shapes() const {
+    return output_shapes_;
+  }
+
+  Status set_materialized_dataset(
+      const std::shared_ptr<MaterializedIndexedDataset>& dataset) {
+    if (dataset) {
+      TF_RETURN_IF_ERROR(
+          VerifyTypesMatch(output_dtypes_, dataset->output_dtypes()));
+      TF_RETURN_IF_ERROR(
+          VerifyShapesCompatible(output_shapes_, dataset->output_shapes()));
+    }
+    materialized_ = dataset;
+    return Status::OK();
+  }
+
+ private:
+  std::shared_ptr<MaterializedIndexedDataset> materialized_;
+  const DataTypeVector output_dtypes_;
+  const std::vector<PartialTensorShape> output_shapes_;
+};
+
+// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT
+// tensor. Objects of the wrapper class own a reference on an instance of an
+// `IndexedDataset` and the wrapper's copy constructor and destructor take care
+// of managing the reference count.
+//
+// NOTE: This is not a feature-complete implementation of the DT_VARIANT
+// specification. In particular, we cannot currently serialize an arbitrary
+// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not
+// implemented.
+//
+// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just
+// use `tensorflow::DatasetVariantWrapper`.
+class IndexedDatasetVariantWrapper {
+ public:
+  IndexedDatasetVariantWrapper() : dataset_(nullptr) {}
+
+  // Transfers ownership of `dataset` to `*this`.
+  explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset)
+      : dataset_(dataset) {}
+
+  IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other)
+      : dataset_(other.dataset_) {
+    if (dataset_) dataset_->Ref();
+  }
+
+  ~IndexedDatasetVariantWrapper() {
+    if (dataset_) dataset_->Unref();
+  }
+
+  IndexedDataset* get() const { return dataset_; }
+
+  string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; }
+  string DebugString() const {
+    if (dataset_) {
+      return dataset_->DebugString();
+    } else {
+      return "<Uninitialized IndexedDatasetVariantWrapper>";
+    }
+  }
+
+  void Encode(VariantTensorData* data) const {
+    LOG(ERROR) << "The Encode() method is not implemented for "
+                  "IndexedDatasetVariantWrapper objects.";
+  }
+
+  bool Decode(const VariantTensorData& data) {
+    LOG(ERROR) << "The Decode() method is not implemented for "
+                  "IndexedDatasetVariantWrapper objects.";
+    return false;
+  }
+
+ private:
+  IndexedDataset* const dataset_;  // Owns one reference.
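+
+  // NOTE: The raw pointer is intentional: the wrapper manages the reference
+  // count itself (see the copy constructor and destructor above), in the same
+  // way as `tensorflow::DatasetVariantWrapper` does in core.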
+};
+
+}  // namespace
+
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+                                          IndexedDataset** out_dataset) {
+  if (!(tensor.dtype() == DT_VARIANT &&
+        TensorShapeUtils::IsScalar(tensor.shape()))) {
+    return errors::InvalidArgument(
+        "IndexedDataset tensor must be a scalar of dtype DT_VARIANT.");
+  }
+  const Variant& variant = tensor.scalar<Variant>()();
+  const IndexedDatasetVariantWrapper* wrapper =
+      variant.get<IndexedDatasetVariantWrapper>();
+  if (wrapper == nullptr) {
+    return errors::InvalidArgument("Tensor must be an IndexedDataset object.");
+  }
+  *out_dataset = wrapper->get();
+  if (*out_dataset == nullptr) {
+    return errors::Internal("Read uninitialized IndexedDataset variant.");
+  }
+  return Status::OK();
+}
+
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+                                          Tensor* tensor) {
+  if (!(tensor->dtype() == DT_VARIANT &&
+        TensorShapeUtils::IsScalar(tensor->shape()))) {
+    return errors::InvalidArgument(
+        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
+  }
+  tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset);
+  return Status::OK();
+}
+
+void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) {
+  IndexedDataset* dataset = nullptr;
+  MakeIndexedDataset(ctx, &dataset);
+
+  if (ctx->status().ok()) {
+    OP_REQUIRES(ctx, dataset != nullptr,
+                errors::Internal("MakeIndexedDataset did not correctly "
+                                 "construct the IndexedDataset"));
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+    OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output));
+  }
+}
+
+namespace {
+
+class MaterializedHandleOp : public OpKernel {
+ public:
+  explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
+
+  ~MaterializedHandleOp() override {
+    if (resource_ != nullptr) {
+      resource_->Unref();
+      if (cinfo_.resource_is_private_to_kernel()) {
+        if (!cinfo_.resource_manager()
+                 ->template Delete<MaterializedDatasetResource>(
+                     cinfo_.container(), cinfo_.name())
+                 .ok()) {
+          // Do nothing; the resource can have been deleted by session resets.
+          // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h
+        }
+      }
+    }
+  }
+
+  void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+    {
+      mutex_lock l(mu_);
+      if (resource_ == nullptr) {
+        ResourceMgr* mgr = context->resource_manager();
+        OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+        MaterializedDatasetResource* resource;
+        OP_REQUIRES_OK(context,
+                       mgr->LookupOrCreate<MaterializedDatasetResource>(
+                           cinfo_.container(), cinfo_.name(), &resource,
+                           [this](MaterializedDatasetResource** ret)
+                               EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                                 *ret = new MaterializedDatasetResource(
+                                     output_dtypes_, output_shapes_);
+                                 return Status::OK();
+                               }));
+        Status s = VerifyResource(resource);
+        if (TF_PREDICT_FALSE(!s.ok())) {
+          resource->Unref();
+          context->SetStatus(s);
+          return;
+        }
+
+        resource_ = resource;
+      }
+    }
+    OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+                                context, 0, cinfo_.container(), cinfo_.name(),
+                                MakeTypeIndex<MaterializedDatasetResource>()));
+  }
+
+ private:
+  // During the first Compute(), the resource is either created or looked up
+  // using shared_name. In the latter case, the resource found should be
+  // verified for compatibility with this op's configuration.
The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + Status VerifyResource(MaterializedDatasetResource* resource) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, resource->output_shapes())); + return Status::OK(); + } + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. + MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr; + DataTypeVector output_dtypes_; + std::vector<PartialTensorShape> output_shapes_; +}; + +// TODO(saeta): Make async. +class MaterializeDatasetOp : public OpKernel { + public: + explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + IndexedDataset* dataset; + OP_REQUIRES_OK(ctx, + GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset)); + + MaterializedDatasetResource* materialized_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &materialized_resource)); + core::ScopedUnref unref(materialized_resource); + std::shared_ptr<MaterializedIndexedDataset> materialized; + OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized)); + OP_REQUIRES_OK( + ctx, materialized_resource->set_materialized_dataset(materialized)); + } +}; + +// TODO(saeta): Make async +class IndexedDatasetGet : public OpKernel { + public: + explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + MaterializedDatasetResource* materialized_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), + &materialized_resource)); + auto cleanup = gtl::MakeCleanup([materialized_resource] { + materialized_resource->Unref(); // Note: can't use core::ScopedUnref. + }); + + const Tensor* index_t; + OP_REQUIRES_OK(ctx, ctx->input("index", &index_t)); + // TODO(saeta): Support batch reads (indexes should be non-scalar!) + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()), + errors::InvalidArgument("index must be a scalar")); + const uint64 index = index_t->scalar<uint64>()(); + + std::vector<Tensor> out_tensors; + Status s = + materialized_resource->Get(IteratorContext(ctx), index, &out_tensors); + + // Note: Unref materialized_resource to avoid destruction races. (Important + // in a [future] async op implementation.) + cleanup.release()(); + + if (!s.ok()) { + ctx->SetStatus(s); + } else { + auto expected_shapes = materialized_resource->output_shapes(); + auto expected_types = materialized_resource->output_dtypes(); + for (size_t i = 0; i < out_tensors.size(); ++i) { + OP_REQUIRES( + ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()), + errors::Internal( + "Materialized dataset output at index ", i, + " is incompatible with the expected shape. (Expected: ", + expected_shapes[i], ", got: ", out_tensors[i].shape(), ")")); + OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i], + errors::Internal("Materialized dataset output at index ", i, + " was not the expected dtype. 
(Expected: ",
+                        expected_types[i],
+                        ", got: ", out_tensors[i].dtype(), ")"));
+        ctx->set_output(i, out_tensors[i]);
+      }
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+    MaterializedHandleOp);
+REGISTER_KERNEL_BUILDER(
+    Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+    MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+    Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+    IndexedDatasetGet);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
new file mode 100644
index 0000000000..27a8360cbc
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace data {
+
+// TODO(saeta): Urgh, this is ugly.
+class MaterializedIndexedDataset {
+ public:
+  virtual ~MaterializedIndexedDataset() = default;
+
+  // Retrieve the element at a given index. The output tensors are stored in
+  // out_tensors.
+  //
+  // If `index` is greater than or equal to `Size()`, an InvalidArgument error
+  // is returned (rather than OutOfRange, which many callers treat as a clean
+  // end-of-sequence signal).
+  //
+  // Get is thread-safe.
+  virtual Status Get(IteratorContext&& ctx, uint64 index,
+                     std::vector<Tensor>* out_tensors) const = 0;
+
+  // Size determines the number of elements in this IndexedDataset.
+  //
+  // Size is thread-safe.
+  virtual Status Size(uint64* size) const = 0;
+
+  // Returns a vector of DataType values, representing the respective
+  // element types of each tuple component in the outputs of this dataset.
+  virtual const DataTypeVector& output_dtypes() const = 0;
+
+  // Returns a vector of tensor shapes, representing the respective
+  // (and possibly partially defined) shapes of each tuple component
+  // in the outputs of this dataset.
+  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+};
+
+// IndexedDataset represents a dataset that supports random access in addition
+// to iterator-based sequential access.
+//
+// Note: IndexedDatasets are HIGHLY experimental at this time. Expect
+// significant (backwards incompatible) changes!
+class IndexedDataset : public DatasetBase {
+ public:
+  IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {}
+
+  // Materialize (if necessary) the dataset, and return a pointer.
+  // TODO(saeta): Add in `IteratorContext* ctx` when materializing.
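+  //
+  // A minimal override might look like the following sketch, where
+  // `MyMaterialized` stands in for a hypothetical MaterializedIndexedDataset
+  // subclass:
+  //
+  //   Status MaterializeDataset(
+  //       std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
+  //     *materialized = std::make_shared<MyMaterialized>(this);
+  //     return Status::OK();
+  //   }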
+  virtual Status MaterializeDataset(
+      std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0;
+};
+
+// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the
+// rest of the TensorFlow runtime.
+//
+// Most IndexedDatasets will be private members of classes inheriting from this
+// class.
+class IndexedDatasetOpKernel : public OpKernel {
+ public:
+  IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) final;
+
+ protected:
+  // Subclasses should implement this method. It will be called during Compute
+  // execution.
+  virtual void MakeIndexedDataset(OpKernelContext* ctx,
+                                  IndexedDataset** output) = 0;
+
+  template <typename T>
+  Status ParseScalarArgument(OpKernelContext* ctx,
+                             const StringPiece& argument_name, T* output) {
+    const Tensor* argument_t;
+    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+      return errors::InvalidArgument(argument_name, " must be a scalar");
+    }
+    *output = argument_t->scalar<T>()();
+    return Status::OK();
+  }
+};
+
+// Validates and extracts an `IndexedDataset` object from `tensor`.
+//
+// `tensor` must have been written by a call to
+// `StoreIndexedDatasetInVariantTensor`.
+//
+// The retrieved pointer is a borrowed reference to the dataset, which is owned
+// by the tensor. The consumer must either acquire its own reference to the
+// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
+// destroyed or mutated while the retrieved pointer is in use.
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+                                          IndexedDataset** out_dataset);
+
+// Stores an `IndexedDataset` object in `tensor`.
+//
+// The ownership of `dataset` is transferred to `tensor`.
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+                                          Tensor* tensor);
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
new file mode 100644
index 0000000000..8a88d32f0c
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -0,0 +1,218 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include <sys/stat.h> + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/platform/file_system.h" + +#include "lmdb.h" // NOLINT(build/include) + +namespace tensorflow { +namespace data { +namespace { + +class LMDBDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + std::vector<string> filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat<string>()(i)); + } + + *output = new Dataset(ctx, filenames); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::vector<string>& filenames) + : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::LMDB")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = + new DataTypeVector({DT_STRING, DT_STRING}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}, {}}); + return *shapes; + } + + string DebugString() const override { return "LMDBDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filenames = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + if (mdb_cursor_) { + Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); + key_tensor.scalar<string>()() = string( + static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size); + out_tensors->emplace_back(std::move(key_tensor)); + + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar<string>()() = + string(static_cast<const char*>(mdb_value_.mv_data), + mdb_value_.mv_size); + out_tensors->emplace_back(std::move(value_tensor)); + + int val; + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + ++current_file_index_; + } + *end_of_sequence = false; + return Status::OK(); + } + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + 
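        // Note: the iterator's position is held in the live MDB_txn/MDB_cursor
+        // handles below, which have no serialized form, so checkpointing is
+        // not supported for this dataset.
+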
return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + private: + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + const string& filename = dataset()->filenames_[current_file_index_]; + + int val = mdb_env_create(&mdb_env_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; + + struct stat source_stat; + if (stat(filename.c_str(), &source_stat) == 0 && + (source_stat.st_mode & S_IFREG)) { + flags |= MDB_NOSUBDIR; + } + val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + } + return Status::OK(); + } + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (mdb_env_ != nullptr) { + if (mdb_cursor_) { + mdb_cursor_close(mdb_cursor_); + mdb_cursor_ = nullptr; + } + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + mdb_env_close(mdb_env_); + mdb_txn_ = nullptr; + mdb_dbi_ = 0; + mdb_env_ = nullptr; + } + } + mutex mu_; + size_t current_file_index_ GUARDED_BY(mu_) = 0; + MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr; + MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr; + MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0; + MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr; + + MDB_val mdb_key_ GUARDED_BY(mu_); + MDB_val mdb_value_ GUARDED_BY(mu_); + }; + + const std::vector<string> filenames_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU), + LMDBDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc new file mode 100644 index 0000000000..2c6179d9f5 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc @@ -0,0 +1,482 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <deque> + +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace data { +namespace { + +struct BufferElement { + // The producer sets `status` if getting the input element fails. + Status status; + // The buffered data element. + std::vector<Tensor> value; +}; + +using FunctionBufferCallback = std::function<void(const BufferElement&)>; + +class FunctionBufferingResource : public ResourceBase { + public: + FunctionBufferingResource(FunctionLibraryRuntime* lib, + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, + const NameAttrList& func, int64 buffer_size, + const string& source_device, + const string& target_device, + const std::vector<Tensor>& func_args, + const DataTypeVector& output_types) + : lib_(lib), + pflr_(std::move(pflr)), + func_(func), + buffer_size_(buffer_size), + source_device_(source_device), + target_device_(target_device), + func_args_(func_args), + output_types_(output_types), + handle_(kInvalidHandle), + is_buffering_(false), + end_of_sequence_(false), + cancelled_(false) {} + + ~FunctionBufferingResource() override { + Cancel(); + } + + string DebugString() override { + return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_, + "; target_device: ", target_device_); + } + + // Instantiates the function the first time it's called. After that it caches + // the handle. + Status Instantiate() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + // Re-use existing handle if it's been set, effectively caching it. + if (handle_ != kInvalidHandle) { + return Status::OK(); + } + AttrValueMap attr_values = func_.attr(); + FunctionLibraryRuntime::InstantiateOptions opts; + opts.target = target_device_; + return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts, + &handle_); + } + + // Returns true if we've got to the end of the sequence and exhausted the + // buffer. + bool Finished() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return end_of_sequence_ && buffer_.empty(); + } + + // Cancels any buffering / prefetching going on. + void Cancel() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + cancelled_ = true; + while (is_buffering_) { + cond_var_.wait(l); + } + } + + // Cancels all pending operations and then clears out the state. + void Reset() LOCKS_EXCLUDED(mu_) { + Cancel(); + mutex_lock l(mu_); + buffer_.clear(); + requests_.clear(); + is_buffering_ = false; + end_of_sequence_ = false; + cancelled_ = false; + } + + // If the buffer has anything, runs `callback` on the first element in the + // buffer, else schedules the `callback` to be called. Requires `args` and + // `lib` in case more function calls need to be scheduled. 
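+  // (Here `args` and `lib` refer to the `func_args_` and `lib_` members,
+  // which are captured at construction time; the callback is the only
+  // per-call argument.)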
+ void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) { + bool start_buffering = false; + bool produced_output = false; + BufferElement buffer_element; + { + mutex_lock l(mu_); + if (!is_buffering_ && !end_of_sequence_) { + start_buffering = true; + } + if (!buffer_.empty()) { + produced_output = true; + std::swap(buffer_element, buffer_.front()); + buffer_.pop_front(); + } else { + produced_output = false; + requests_.push_back(std::move(callback)); + } + } + if (produced_output) { + callback(buffer_element); + } + if (start_buffering) { + FillBuffer(); + } + } + + private: + void FillBuffer() LOCKS_EXCLUDED(mu_) { + FunctionLibraryRuntime::Handle handle; + std::vector<FunctionBufferCallback> cancellation_callbacks; + std::vector<BufferElement> cancellation_buffer_elements; + bool cancelled = false; + { + mutex_lock l(mu_); + handle = handle_; + if (cancelled_) { + cancelled = true; + // Run through and fulfill all pending requests, if possible. + while (!requests_.empty()) { + if (!buffer_.empty()) { + cancellation_buffer_elements.push_back(std::move(buffer_.front())); + buffer_.pop_front(); + cancellation_callbacks.push_back(std::move(requests_.front())); + requests_.pop_front(); + } else { + LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: " + << requests_.size() << " requests"; + break; + } + } + is_buffering_ = false; + } else { + is_buffering_ = true; + } + } + if (cancelled) { + for (int i = 0; i < cancellation_callbacks.size(); ++i) { + cancellation_callbacks[i](cancellation_buffer_elements[i]); + } + cond_var_.notify_all(); + return; + } + FunctionLibraryRuntime::Options opts; + // Copied from CapturedFunction::generate_step_id(); + opts.step_id = -std::abs(static_cast<int64>(random::New64())); + opts.source_device = source_device_; + AllocatorAttributes arg_alloc_attr; + arg_alloc_attr.set_on_host(true); + opts.args_alloc_attrs.push_back(arg_alloc_attr); + for (const auto& dtype : output_types_) { + AllocatorAttributes ret_alloc_attrs; + if (DataTypeAlwaysOnHost(dtype)) { + ret_alloc_attrs.set_on_host(true); + } + opts.rets_alloc_attrs.push_back(ret_alloc_attrs); + } + if (opts.source_device != target_device_) { + opts.remote_execution = true; + } + opts.create_rendezvous = true; + auto* rets = new std::vector<Tensor>; + lib_->Run(opts, handle, func_args_, rets, + [this, rets](const Status& status) { + FunctionBufferCallback callback = nullptr; + BufferElement buffer_front; + bool restart_buffering = false; + { + mutex_lock l(mu_); + BufferElement buffer_element; + buffer_element.status = status; + if (status.ok()) { + buffer_element.value.swap(*rets); + } else { + end_of_sequence_ = true; + is_buffering_ = false; + } + buffer_.push_back(std::move(buffer_element)); + if (!requests_.empty()) { + buffer_front = std::move(buffer_.front()); + buffer_.pop_front(); + callback = std::move(requests_.front()); + requests_.pop_front(); + } + if (buffer_.size() < buffer_size_ && !end_of_sequence_) { + restart_buffering = true; + } else { + // When the buffer is full, we don't want to call + // FillBuffer() unless we're in cancellation phase in which + // case FillBuffer() will do the final cleanup post + // cancellation. 
+ if (cancelled_) { + restart_buffering = true; + } + is_buffering_ = false; + } + } + if (callback != nullptr) { + callback(buffer_front); + } + if (restart_buffering) { + FillBuffer(); + } + }); + } + + mutex mu_; + FunctionLibraryRuntime* lib_; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + NameAttrList func_; + const int64 buffer_size_; + const string source_device_; + const string target_device_; + const std::vector<Tensor> func_args_; + const DataTypeVector output_types_; + FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); + std::deque<BufferElement> buffer_ GUARDED_BY(mu_); + std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_); + bool is_buffering_ GUARDED_BY(mu_); + bool end_of_sequence_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_); + condition_variable cond_var_; +}; + +class FunctionBufferResourceHandleOp : public OpKernel { + public: + explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx), flib_def_(nullptr) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + } + + ~FunctionBufferResourceHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete<FunctionBufferingResource>(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* string_arg; + OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg)); + std::vector<Tensor> func_args; + func_args.push_back(*string_arg); + + const string& source_device = ctx->device()->name(); + + // Obtain and canonicalize target_device. + const Tensor* target_arg; + OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); + string target_device; + OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName( + target_arg->scalar<string>()(), source_device, + &target_device)); + + FunctionLibraryRuntime* lib = ctx->function_library(); + OP_REQUIRES(ctx, lib != nullptr, + errors::Internal("No function library is provided.")); + + mutex_lock l(mu_); + if (!initialized_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); + FunctionLibraryRuntime* clone_lib; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr; + OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib)); + // Create the resource. 
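+      // (LookupOrCreate either returns a resource previously registered under
+      // the same (container, shared_name) pair, or runs the factory lambda
+      // below to create one.)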
+ FunctionBufferingResource* buffer; + OP_REQUIRES_OK( + ctx, + ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>( + cinfo_.container(), cinfo_.name(), &buffer, + [clone_lib, &pflr, &source_device, &target_device, func_args, + this](FunctionBufferingResource** ptr) { + *ptr = new FunctionBufferingResource( + clone_lib, std::move(pflr), func_, buffer_size_, + source_device, target_device, func_args, output_types_); + return Status::OK(); + })); + core::ScopedUnref s(buffer); + OP_REQUIRES_OK(ctx, buffer->Instantiate()); + initialized_ = true; + } + + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<FunctionBufferingResource>())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + std::unique_ptr<FunctionLibraryDefinition> flib_def_; + NameAttrList func_; + int64 buffer_size_; + string container_; + string name_; + DataTypeVector output_types_; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource") + .Device(DEVICE_CPU) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource") + .Device(DEVICE_GPU) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource") + .Device(DEVICE_SYCL) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +#endif // TENSORFLOW_USE_SYCL + +// Prefetches and fills up a buffer by calling a function that provides the +// elements to buffer. 
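+//
+// Concretely, the op below pops a single element from the buffering resource,
+// registering a callback if the buffer happens to be empty, and reports
+// OutOfRange once the function has signaled end-of-sequence and the buffer
+// has drained.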
+class FunctionBufferingResourceGetNextOp : public AsyncOpKernel { + public: + explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx) {} + + ~FunctionBufferingResourceGetNextOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + ResourceHandle handle; + OP_REQUIRES_OK_ASYNC( + ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer), + done); + + if (buffer->Finished()) { + buffer->Unref(); + ctx->SetStatus(errors::OutOfRange("end_of_sequence")); + done(); + return; + } + + FunctionBufferCallback callback = + [ctx, buffer, done](const BufferElement& buffer_element) { + Status s = buffer_element.status; + if (!s.ok()) { + ctx->SetStatus(s); + buffer->Unref(); + done(); + return; + } + for (size_t i = 0; i < buffer_element.value.size(); ++i) { + ctx->set_output(i, buffer_element.value[i]); + } + buffer->Unref(); + done(); + }; + buffer->MaybeGet(std::move(callback)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +#endif // TENSORFLOW_USE_SYCL + +// Resets the FunctionBufferingResource, cancelling all pending requests and +// clearing out the buffer. +class FunctionBufferingResourceResetOp : public OpKernel { + public: + explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + ~FunctionBufferingResourceResetOp() override {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle; + OP_REQUIRES_OK(ctx, + HandleFromInput(ctx, "function_buffer_resource", &handle)); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK( + ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer)); + core::ScopedUnref s(buffer); + + buffer->Reset(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#endif // TENSORFLOW_USE_SYCL + +class IteratorGetDeviceOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* ctx) override { + // NOTE(mrry): We do not currently Validate that the handle + // corresponds to a real IteratorResource, because that symbol is + // not exposed from the framework library. 
+ Tensor* device_name_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &device_name_t)); + // NOTE(mrry): Since the operation's input is a resource, we must be + // colocated with it, and so we can simply return the current device's + // name without looking at the input. + device_name_t->scalar<string>()() = ctx->device()->name(); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU), + IteratorGetDeviceOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc new file mode 100644 index 0000000000..c80493d3a1 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -0,0 +1,220 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +namespace data { +namespace { + +class ThreadPoolResource : public ResourceBase { + public: + ThreadPoolResource(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads, bool low_latency_hint, + int max_intra_op_parallelism) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint), + max_intra_op_parallelism_(max_intra_op_parallelism) {} + + // Schedules fn() for execution in the pool of threads. + void Schedule(std::function<void()> fn) { + if (max_intra_op_parallelism_ < 0) { + thread_pool_.Schedule(std::move(fn)); + } else { + thread_pool_.Schedule(std::bind( + [this](std::function<void()> bound_fn) { + // TODO(mrry): Consider moving this thread-local configuration to + // the threads themselves. + ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_); + bound_fn(); + }, + std::move(fn))); + } + } + + string DebugString() override { return "ThreadPoolResource"; } + + private: + thread::ThreadPool thread_pool_; + const int max_intra_op_parallelism_; +}; + +// Creates a handle to a ThreadPool resource. Note that we don't use +// ResourceOpKernel here because the ThreadPoolResource constructor requires +// access to `OpKernelContext::env()`, which isn't provided by +// `ResourceOpKernel<T>::CreateResource()`. 
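+//
+// Rough sketch of how the two ops below compose (pseudocode, not real graph
+// construction syntax):
+//
+//   handle = ExperimentalThreadPoolHandle(...)           // ThreadPoolResource
+//   output = ExperimentalThreadPoolDataset(input, handle)
+//
+// Iterating over `output` then runs `input`'s computation with a runner that
+// schedules work on the custom thread pool.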
+class ThreadPoolHandleOp : public OpKernel {
+ public:
+  explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
+                                     &max_intra_op_parallelism_));
+    OP_REQUIRES(
+        ctx, num_threads_ > 0,
+        errors::InvalidArgument("`num_threads` must be greater than zero."));
+  }
+
+  // The resource is deleted from the resource manager only when it is private
+  // to the kernel. Ideally the resource should be deleted when it is no longer
+  // held by anyone, but doing so would break backward compatibility.
+  ~ThreadPoolHandleOp() override {
+    if (cinfo_.resource_is_private_to_kernel()) {
+      if (!cinfo_.resource_manager()
+               ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name())
+               .ok()) {
+        // Do nothing; the resource can have been deleted by session resets.
+      }
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    if (!initialized_) {
+      ResourceMgr* mgr = ctx->resource_manager();
+      OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+      ThreadPoolResource* resource;
+      OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>(
+                              cinfo_.container(), cinfo_.name(), &resource,
+                              [this, ctx](ThreadPoolResource** ret)
+                                  EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                                    *ret = new ThreadPoolResource(
+                                        ctx->env(), {}, display_name_,
+                                        num_threads_,
+                                        false /* low_latency_hint */,
+                                        max_intra_op_parallelism_);
+                                    return Status::OK();
+                                  }));
+      initialized_ = true;
+    }
+    OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+                            ctx, 0, cinfo_.container(), cinfo_.name(),
+                            MakeTypeIndex<ThreadPoolResource>()));
+  }
+
+ private:
+  mutex mu_;
+  ContainerInfo cinfo_ GUARDED_BY(mu_);
+  bool initialized_ GUARDED_BY(mu_) = false;
+  string display_name_;
+  int num_threads_;
+  int max_intra_op_parallelism_;
+};
+
+class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    ThreadPoolResource* threadpool_resource;
+    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+                                       &threadpool_resource));
+    core::ScopedUnref unref_iterator(threadpool_resource);
+
+    *output = new Dataset(ctx, input, threadpool_resource);
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, const DatasetBase* input,
+            ThreadPoolResource* threadpool)
+        : DatasetBase(DatasetContext(ctx)),
+          input_(input),
+          threadpool_(threadpool) {
+      input_->Ref();
+      threadpool_->Ref();
+    }
+
+    ~Dataset() override {
+      input_->Unref();
+      threadpool_->Unref();
+    }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return input_->output_dtypes();
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return input_->output_shapes();
+    }
+
+    string DebugString() const override {
+      return "ThreadPoolDatasetOp::Dataset";
+    }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
+    }
+
+   private:
+    class
Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + ThreadPoolResource* pool = dataset()->threadpool_; + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = [pool](std::function<void()> c) { + pool->Schedule(std::move(c)); + }; + params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + params.lib = ctx->lib(); + params.function_library = ctx->function_library(); + params.allocator_getter = ctx->allocator_getter(); + IteratorContext threadpool_ctx(params); + return input_impl_->GetNext(&threadpool_ctx, out_tensors, + end_of_sequence); + } + + private: + std::unique_ptr<IteratorBase> input_impl_; + }; + + const DatasetBase* const input_; + ThreadPoolResource* const threadpool_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU), + ThreadPoolHandleOp); +REGISTER_KERNEL_BUILDER( + Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU), + ThreadPoolDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc new file mode 100644 index 0000000000..cd612e0eb2 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -0,0 +1,224 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { +namespace data { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. 
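+//
+// Informally, this dataset yields each distinct element of its input exactly
+// once, in order of first occurrence: an input of {1, 2, 1, 3, 2} produces
+// {1, 2, 3}.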
+ +class UniqueDatasetOp : public UnaryDatasetOpKernel { + public: + explicit UniqueDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OP_REQUIRES(ctx, input->output_dtypes().size() == 1, + errors::InvalidArgument("UniqueDataset only supports " + "inputs with a single component.")); + + DataType input_dtype = input->output_dtypes()[0]; + OP_REQUIRES(ctx, + input_dtype == DT_INT32 || input_dtype == DT_INT64 || + input_dtype == DT_STRING, + errors::InvalidArgument( + "UniqueDataset only supports inputs with a single " + "`tf.int32`, `tf.int64`, or `tf.string` component.")); + + *output = new Dataset(ctx, input); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Unique")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { + return strings::StrCat("UniqueDatasetOp::Dataset"); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const typename Iterator::Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + bool saw_new_value; + do { + saw_new_value = false; + out_tensors->clear(); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence) { + break; + } + DCHECK_EQ(1, out_tensors->size()); + saw_new_value = unique_elements_.insert((*out_tensors)[0]).second; + } while (!saw_new_value); + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("unique_elements_size"), unique_elements_.size())); + size_t i = 0; + for (const Tensor& t : unique_elements_) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("unique_elements[", i++, "]")), t)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } + int64 num_unique_elements; + unique_elements_.clear(); + 
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
+                                              &num_unique_elements));
+        for (int64 i = 0; i < num_unique_elements; ++i) {
+          Tensor unique_element;
+          TF_RETURN_IF_ERROR(reader->ReadTensor(
+              full_name(strings::StrCat("unique_elements[", i, "]")),
+              &unique_element));
+          auto insert_result = unique_elements_.insert(unique_element);
+          if (!insert_result.second) {
+            return errors::InvalidArgument(
+                "Checkpoint contained two unique elements with the same "
+                "value.");
+          }
+        }
+        return Status::OK();
+      }
+
+     private:
+      struct TensorHash {
+        size_t operator()(const Tensor& t) const {
+          if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
+            return Hash64(t.tensor_data().data(), t.tensor_data().size());
+          } else {
+            DCHECK_EQ(DT_STRING, t.dtype());
+            auto flat_t = t.flat<string>();
+            uint64 hash = 0;
+            for (int64 i = 0; i < t.NumElements(); ++i) {
+              hash = Hash64Combine(hash, Hash64(flat_t(i)));
+            }
+            return static_cast<size_t>(hash);
+          }
+        }
+      };
+
+      struct TensorKeyEqual {
+        bool operator()(const Tensor& lhs, const Tensor& rhs) const {
+          if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
+            return false;
+          }
+          switch (lhs.dtype()) {
+#define HANDLE_TYPE(T)                                     \
+  case T:                                                  \
+    do {                                                   \
+      auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
+      auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
+      for (int64 i = 0; i < lhs.NumElements(); ++i) {      \
+        if (lhs_flat(i) != rhs_flat(i)) {                  \
+          return false;                                    \
+        }                                                  \
+      }                                                    \
+      return true;                                         \
+    } while (0)
+
+            HANDLE_TYPE(DT_INT32);
+            HANDLE_TYPE(DT_INT64);
+            HANDLE_TYPE(DT_STRING);
+            default:
+              DCHECK(false) << "UniqueDataset unhandled data type: "
+                            << DataTypeString(lhs.dtype());
+              return false;
+          }
+        }
+      };
+
+      mutex mu_;
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
+          GUARDED_BY(mu_);
+    };
+
+    const DatasetBase* const input_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
+                        UniqueDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow