aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/generator_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/generator_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc57
1 files changed, 20 insertions, 37 deletions
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ccee690d7e..b4367d5a11 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -85,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
return Status::OK();
}
@@ -95,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ initialized_ = true;
+ }
+
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -122,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
@@ -144,54 +150,31 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
- &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ init_func_, ctx, "init_func_other_args", &init_func));
+
std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
- &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
+ next_func_, ctx, "next_func_other_args", &next_func));
+
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_args",
+ &finalize_func));
*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),
std::move(finalize_func), output_types_, output_shapes_);
}
+namespace {
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
GeneratorDatasetOp);
+} // namespace
+} // namespace data
} // namespace tensorflow