diff options
author | 2018-08-07 16:42:16 -0700 | |
---|---|---|
committer | 2018-08-07 16:51:57 -0700 | |
commit | 75ee9f5e9bc04c312363d9c0836dd9c3851ead64 (patch) | |
tree | f17ca719eb10753eeb45206c7d7ae4303a7379cd /tensorflow/core/kernels/data/generator_dataset_op.cc | |
parent | 6f07c3f37425ca0f893cd7f2f457830662ffed02 (diff) |
Makes PrefetchToDevice work in XLA Compile-on-Demand mode. Also adds a number of dataset/iterator kernel registrations for XLA.
PiperOrigin-RevId: 207802858
Diffstat (limited to 'tensorflow/core/kernels/data/generator_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/generator_dataset_op.cc | 292 |
1 file changed, 137 insertions, 155 deletions
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 0981e42ba1..c4dd849b8b 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -15,192 +15,174 @@ limitations under the License. #include <iterator> #include <vector> -#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" + #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { -namespace { - // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class GeneratorDatasetOp : public DatasetOpKernel { +class GeneratorDatasetOp::Dataset : public GraphDatasetBase { public: - explicit GeneratorDatasetOp(OpKernelConstruction* ctx) - : DatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func, + std::unique_ptr<CapturedFunction> next_func, + std::unique_ptr<CapturedFunction> finalize_func, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : GraphDatasetBase(ctx), + init_func_(std::move(init_func)), + next_func_(std::move(next_func)), + finalize_func_(std::move(finalize_func)), + output_types_(output_types), + output_shapes_(output_shapes) {} + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new 
Iterator({this, strings::StrCat(prefix, "::Generator")})); } - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - OpInputList init_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", - &init_func_other_args_input)); - std::vector<Tensor> init_func_other_args; - init_func_other_args.reserve(init_func_other_args_input.size()); - for (const Tensor& t : init_func_other_args_input) { - init_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> init_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - init_func_, std::move(init_func_other_args), &init_func)); - - OpInputList next_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", - &next_func_other_args_input)); - std::vector<Tensor> next_func_other_args; - next_func_other_args.reserve(next_func_other_args_input.size()); - for (const Tensor& t : next_func_other_args_input) { - next_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> next_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - next_func_, std::move(next_func_other_args), &next_func)); - - OpInputList finalize_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", - &finalize_func_other_args_input)); - std::vector<Tensor> finalize_func_other_args; - finalize_func_other_args.reserve(finalize_func_other_args_input.size()); - for (const Tensor& t : finalize_func_other_args_input) { - finalize_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> finalize_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - finalize_func_, std::move(finalize_func_other_args), - &finalize_func)); - - *output = - new Dataset(ctx, std::move(init_func), std::move(next_func), - std::move(finalize_func), output_types_, output_shapes_); + const DataTypeVector& output_dtypes() const override { return output_types_; } + const std::vector<PartialTensorShape>& output_shapes() const override { + return 
output_shapes_; } + string DebugString() const override { return "GeneratorDatasetOp::Dataset"; } + private: - class Dataset : public GraphDatasetBase { + class Iterator : public DatasetIterator<Dataset> { public: - Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func, - std::unique_ptr<CapturedFunction> next_func, - std::unique_ptr<CapturedFunction> finalize_func, - const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes) - : GraphDatasetBase(ctx), - init_func_(std::move(init_func)), - next_func_(std::move(next_func)), - finalize_func_(std::move(finalize_func)), - output_types_(output_types), - output_shapes_(output_shapes) {} - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Generator")})); - } - - const DataTypeVector& output_dtypes() const override { - return output_types_; - } - const std::vector<PartialTensorShape>& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { - return "GeneratorDatasetOp::Dataset"; - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - ~Iterator() override { - if (!finalized_) { - std::vector<Tensor> ignored; - Status s = - dataset()->finalize_func_->RunInstantiated(state_, &ignored); - if (!s.ok()) { - LOG(WARNING) - << "Error occurred when finalizing GeneratorDataset iterator: " - << s; - } + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + ~Iterator() override { + if (!finalized_) { + std::vector<Tensor> ignored; + Status s = dataset()->finalize_func_->RunInstantiated(state_, &ignored); + if (!s.ok()) { + LOG(WARNING) + << "Error occurred when finalizing GeneratorDataset iterator: " + << s; } } + } - Status GetNextInternal(IteratorContext* ctx, - 
std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - - if (!initialized_) { - TF_RETURN_IF_ERROR( - dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); - // Explicitly instantiate the finalize function here so that - // we can invoke it in the destructor. - TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); - initialized_ = true; - } + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + + if (!initialized_) { + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + // Explicitly instantiate the finalize function here so that + // we can invoke it in the destructor. + TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); + initialized_ = true; + } - if (finalized_) { - *end_of_sequence = true; - return Status::OK(); - } + if (finalized_) { + *end_of_sequence = true; + return Status::OK(); + } - Status s = dataset()->next_func_->RunWithBorrowedArgs(ctx, state_, - out_tensors); - if (s.ok()) { - *end_of_sequence = false; - } else if (errors::IsOutOfRange(s)) { - // `next_func` may deliberately raise `errors::OutOfRange` - // to indicate that we should terminate the iteration. - s = Status::OK(); - *end_of_sequence = true; - - // NOTE(mrry): We ignore any tensors returned by the - // finalize function. - std::vector<Tensor> ignored; - TF_RETURN_IF_ERROR( - dataset()->finalize_func_->RunInstantiated(state_, &ignored)); - finalized_ = true; - } - return s; + Status s = + dataset()->next_func_->RunWithBorrowedArgs(ctx, state_, out_tensors); + if (s.ok()) { + *end_of_sequence = false; + } else if (errors::IsOutOfRange(s)) { + // `next_func` may deliberately raise `errors::OutOfRange` + // to indicate that we should terminate the iteration. + s = Status::OK(); + *end_of_sequence = true; + + // NOTE(mrry): We ignore any tensors returned by the + // finalize function. 
+ std::vector<Tensor> ignored; + TF_RETURN_IF_ERROR( + dataset()->finalize_func_->RunInstantiated(state_, &ignored)); + finalized_ = true; } + return s; + } - private: - mutex mu_; - bool initialized_ GUARDED_BY(mu_) = false; - bool finalized_ GUARDED_BY(mu_) = false; - std::vector<Tensor> state_ GUARDED_BY(mu_); - }; - - const std::unique_ptr<CapturedFunction> init_func_; - const std::unique_ptr<CapturedFunction> next_func_; - const std::unique_ptr<CapturedFunction> finalize_func_; - const DataTypeVector output_types_; - const std::vector<PartialTensorShape> output_shapes_; + private: + mutex mu_; + bool initialized_ GUARDED_BY(mu_) = false; + bool finalized_ GUARDED_BY(mu_) = false; + std::vector<Tensor> state_ GUARDED_BY(mu_); }; - DataTypeVector output_types_; - std::vector<PartialTensorShape> output_shapes_; - NameAttrList init_func_; - NameAttrList next_func_; - NameAttrList finalize_func_; + const std::unique_ptr<CapturedFunction> init_func_; + const std::unique_ptr<CapturedFunction> next_func_; + const std::unique_ptr<CapturedFunction> finalize_func_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; }; +GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); +} + +void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase** output) { + OpInputList init_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", + &init_func_other_args_input)); + std::vector<Tensor> init_func_other_args; + init_func_other_args.reserve(init_func_other_args_input.size()); + for (const Tensor& t : init_func_other_args_input) { + 
init_func_other_args.push_back(t); + } + std::unique_ptr<CapturedFunction> init_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args), + &init_func)); + + OpInputList next_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", + &next_func_other_args_input)); + std::vector<Tensor> next_func_other_args; + next_func_other_args.reserve(next_func_other_args_input.size()); + for (const Tensor& t : next_func_other_args_input) { + next_func_other_args.push_back(t); + } + std::unique_ptr<CapturedFunction> next_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args), + &next_func)); + + OpInputList finalize_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", + &finalize_func_other_args_input)); + std::vector<Tensor> finalize_func_other_args; + finalize_func_other_args.reserve(finalize_func_other_args_input.size()); + for (const Tensor& t : finalize_func_other_args_input) { + finalize_func_other_args.push_back(t); + } + std::unique_ptr<CapturedFunction> finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + finalize_func_, std::move(finalize_func_other_args), + &finalize_func)); + + *output = + new Dataset(ctx, std::move(init_func), std::move(next_func), + std::move(finalize_func), output_types_, output_shapes_); +} + REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU), GeneratorDatasetOp); REGISTER_KERNEL_BUILDER( Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"), GeneratorDatasetOp); -} // namespace - } // namespace tensorflow |