aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/generator_dataset_op.cc
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-17 16:41:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 16:50:56 -0700
commit6805a8b27759a530f0ebab0670593a05455a64a0 (patch)
tree1ea29c728b29b5b5641ee00997debb7737bcb13c /tensorflow/core/kernels/data/generator_dataset_op.cc
parent0cdf60ff8239a68326af9610e715f42c773be731 (diff)
Changing `OpInputList` so that it is a forward iterator and taking advantage of the fact in the tf.data kernels.
PiperOrigin-RevId: 213361953
Diffstat (limited to 'tensorflow/core/kernels/data/generator_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc44
1 files changed, 9 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ac5cc1b2c1..71a36314a0 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -145,44 +145,18 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
- &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ init_func_, ctx, "init_func_other_args", &init_func));
+
std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
- &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
+ next_func_, ctx, "next_func_other_args", &next_func));
+
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_args",
+ &finalize_func));
*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),