diff options
author | 2018-09-17 16:41:56 -0700 | |
---|---|---|
committer | 2018-09-17 16:50:56 -0700 | |
commit | 6805a8b27759a530f0ebab0670593a05455a64a0 (patch) | |
tree | 1ea29c728b29b5b5641ee00997debb7737bcb13c /tensorflow/core/kernels/data/generator_dataset_op.cc | |
parent | 0cdf60ff8239a68326af9610e715f42c773be731 (diff) |
Changing `OpInputList` so that it is a forward iterator and taking advantage of the fact in the tf.data kernels.
PiperOrigin-RevId: 213361953
Diffstat (limited to 'tensorflow/core/kernels/data/generator_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/generator_dataset_op.cc | 44 |
1 files changed, 9 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index ac5cc1b2c1..71a36314a0 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -145,44 +145,18 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { - OpInputList init_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", - &init_func_other_args_input)); - std::vector<Tensor> init_func_other_args; - init_func_other_args.reserve(init_func_other_args_input.size()); - for (const Tensor& t : init_func_other_args_input) { - init_func_other_args.push_back(t); - } std::unique_ptr<CapturedFunction> init_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args), - &init_func)); - - OpInputList next_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", - &next_func_other_args_input)); - std::vector<Tensor> next_func_other_args; - next_func_other_args.reserve(next_func_other_args_input.size()); - for (const Tensor& t : next_func_other_args_input) { - next_func_other_args.push_back(t); - } + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + init_func_, ctx, "init_func_other_args", &init_func)); + std::unique_ptr<CapturedFunction> next_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args), - &next_func)); - - OpInputList finalize_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", - &finalize_func_other_args_input)); - std::vector<Tensor> finalize_func_other_args; - finalize_func_other_args.reserve(finalize_func_other_args_input.size()); - for (const Tensor& t : finalize_func_other_args_input) { - finalize_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> finalize_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - finalize_func_, std::move(finalize_func_other_args), - &finalize_func)); + next_func_, ctx, "next_func_other_args", &next_func)); + + std::unique_ptr<CapturedFunction> finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx, + "finalize_func_other_args", + &finalize_func)); *output = new Dataset(ctx, std::move(init_func), std::move(next_func), |