diff options
Diffstat (limited to 'tensorflow/core/kernels/data/generator_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/generator_dataset_op.cc | 57 |
1 files changed, 20 insertions, 37 deletions
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index ccee690d7e..b4367d5a11 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -85,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx)); TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx)); TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); - TF_RETURN_IF_ERROR( - dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); return Status::OK(); } @@ -95,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { bool* end_of_sequence) override { mutex_lock l(mu_); + if (!initialized_) { + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + initialized_ = true; + } + if (finalized_) { *end_of_sequence = true; return Status::OK(); @@ -122,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { private: mutex mu_; + bool initialized_ GUARDED_BY(mu_) = false; bool finalized_ GUARDED_BY(mu_) = false; std::vector<Tensor> state_ GUARDED_BY(mu_); }; @@ -144,54 +150,31 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { - OpInputList init_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", - &init_func_other_args_input)); - std::vector<Tensor> init_func_other_args; - init_func_other_args.reserve(init_func_other_args_input.size()); - for (const Tensor& t : init_func_other_args_input) { - init_func_other_args.push_back(t); - } std::unique_ptr<CapturedFunction> init_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args), - &init_func)); - - OpInputList next_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", - &next_func_other_args_input)); - std::vector<Tensor> next_func_other_args; - next_func_other_args.reserve(next_func_other_args_input.size()); - for (const Tensor& t : next_func_other_args_input) { - next_func_other_args.push_back(t); - } + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + init_func_, ctx, "init_func_other_args", &init_func)); + std::unique_ptr<CapturedFunction> next_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args), - &next_func)); - - OpInputList finalize_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", - &finalize_func_other_args_input)); - std::vector<Tensor> finalize_func_other_args; - finalize_func_other_args.reserve(finalize_func_other_args_input.size()); - for (const Tensor& t : finalize_func_other_args_input) { - finalize_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> finalize_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - finalize_func_, std::move(finalize_func_other_args), - &finalize_func)); + next_func_, ctx, "next_func_other_args", &next_func)); + + std::unique_ptr<CapturedFunction> finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx, + "finalize_func_other_args", + &finalize_func)); *output = new Dataset(ctx, std::move(init_func), std::move(next_func), std::move(finalize_func), output_types_, output_shapes_); } +namespace { REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU), GeneratorDatasetOp); REGISTER_KERNEL_BUILDER( Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"), GeneratorDatasetOp); +} // namespace +} // namespace data } // namespace tensorflow |