diff options
author | 2018-09-17 16:41:56 -0700 | |
---|---|---|
committer | 2018-09-17 16:50:56 -0700 | |
commit | 6805a8b27759a530f0ebab0670593a05455a64a0 (patch) | |
tree | 1ea29c728b29b5b5641ee00997debb7737bcb13c /tensorflow | |
parent | 0cdf60ff8239a68326af9610e715f42c773be731 (diff) |
Changing `OpInputList` so that it is a forward iterator and taking advantage of the fact in the tf.data kernels.
PiperOrigin-RevId: 213361953
Diffstat (limited to 'tensorflow')
16 files changed, 88 insertions, 231 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index e752599de1..4bbd6c3d7d 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -372,18 +372,37 @@ class OpKernelConstruction { template <typename ListType, typename ElementType> class OpArgIterator { public: - typedef OpArgIterator<ListType, ElementType> ME; + using iterator_category = std::forward_iterator_tag; + using value_type = ElementType; + using pointer = ElementType*; + using reference = ElementType&; + using difference_type = ptrdiff_t; + OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} - bool operator==(const ME& rhs) { + + bool operator==(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ == rhs.i_; } - bool operator!=(const ME& rhs) { + + bool operator!=(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ != rhs.i_; } - void operator++() { ++i_; } - ElementType& operator*() { return (*list_)[i_]; } + + OpArgIterator operator++() { // prefix ++it + ++i_; + return *this; + } + + OpArgIterator operator++(int) { // postfix it++ + OpArgIterator old_value = *this; + ++i_; + return old_value; + } + + reference operator*() { return (*list_)[i_]; } + pointer operator->() { return &(*list_)[i_]; } private: const ListType* const list_; @@ -394,7 +413,7 @@ class OpArgIterator { // that are passed to the op as a single named argument. class OpInputList { public: - typedef OpArgIterator<OpInputList, const Tensor&> Iterator; + typedef OpArgIterator<OpInputList, const Tensor> Iterator; OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} OpInputList(OpKernelContext* ctx, int start, int stop) : ctx_(ctx), start_(start), stop_(stop) {} diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 31c8f5c0ea..b3ab7e2bc6 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -22,41 +22,30 @@ limitations under the License. #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { /* static */ Status CapturedFunction::Create( - const NameAttrList& func, std::vector<Tensor> captured_inputs, + const NameAttrList& func, OpKernelContext* ctx, const string& argument, std::unique_ptr<CapturedFunction>* out_function) { - return Create(func, std::move(captured_inputs), true, out_function); + return CapturedFunction::Create(func, ctx, argument, true, out_function); } -/* static */ Status CapturedFunction::Create( - const NameAttrList& func, std::vector<Tensor> captured_inputs, + const NameAttrList& func, OpKernelContext* ctx, const string& argument, bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction>* out_function) { - out_function->reset(new CapturedFunction(func, std::move(captured_inputs), - use_inter_op_parallelism)); + OpInputList inputs; + TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs)); + std::vector<Tensor> arguments(inputs.begin(), inputs.end()); + *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments), + use_inter_op_parallelism)); return Status::OK(); } -/* static */ -Status CapturedFunction::Create( - const NameAttrList& func, OpKernelContext* ctx, const string& argument, - std::unique_ptr<CapturedFunction>* out_function) { - OpInputList argument_inputs; - TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs)); - std::vector<Tensor> arguments_t; - arguments_t.reserve(argument_inputs.size()); - for (const Tensor& t : argument_inputs) { - arguments_t.push_back(t); - } - return CapturedFunction::Create(func, std::move(arguments_t), out_function); -} - CapturedFunction::~CapturedFunction() { if (lib_ != nullptr && f_handle_ != kInvalidHandle) { lib_->ReleaseHandle(f_handle_).IgnoreError(); diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index 8b420fa5db..a10376bf97 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -42,27 +42,19 @@ namespace data { // context. class CapturedFunction { public: - // Creates a new instance from a list of named attributes and captured inputs. - // - // NOTE(mrry): The `captured_inputs` are passed by value. For - // efficiency, you are recommended to move this argument into the call. - static Status Create(const NameAttrList& func, - std::vector<Tensor> captured_inputs, + // Creates a new instance using a list of named attributes, fetching captured + // inputs from a context argument. + static Status Create(const NameAttrList& func, OpKernelContext* ctx, + const string& argument, std::unique_ptr<CapturedFunction>* out_function); - // Creates a new instance from a list of named attributes and captured inputs. + // Creates a new instance using a list of named attributes, fetching captured + // inputs from a context argument. // // If `use_inter_op_parallelism` is false, the runtime may use an executor // that is optimized for small functions. - static Status Create(const NameAttrList& func, - std::vector<Tensor> captured_inputs, - bool use_inter_op_parallelism, - std::unique_ptr<CapturedFunction>* out_function); - - // Creates a new instance using a list of named attributes, fetching captured - // inputs from a context argument. static Status Create(const NameAttrList& func, OpKernelContext* ctx, - const string& argument, + const string& argument, bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction>* out_function); ~CapturedFunction(); diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index bf0aecaf3c..19c35f94a6 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -37,14 +37,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - FunctionLibraryRuntime::Handle pred_handle; OP_REQUIRES_OK(ctx, ctx->function_library()->Instantiate( @@ -61,9 +53,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { Node* ret_node = pred_body->ret_nodes[0]; Node* ret_input_node; OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); + std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); if (ret_input_node->def().op() == "_Arg") { int32 index = -1; diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index e3c45ef86c..2fada22a21 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -39,18 +39,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); - + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); } diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index ac5cc1b2c1..71a36314a0 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -145,44 +145,18 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { - OpInputList init_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", - &init_func_other_args_input)); - std::vector<Tensor> init_func_other_args; - init_func_other_args.reserve(init_func_other_args_input.size()); - for (const Tensor& t : init_func_other_args_input) { - init_func_other_args.push_back(t); - } std::unique_ptr<CapturedFunction> init_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args), - &init_func)); - - OpInputList next_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", - &next_func_other_args_input)); - std::vector<Tensor> next_func_other_args; - next_func_other_args.reserve(next_func_other_args_input.size()); - for (const Tensor& t : next_func_other_args_input) { - next_func_other_args.push_back(t); - } + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + init_func_, ctx, "init_func_other_args", &init_func)); + std::unique_ptr<CapturedFunction> next_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args), - &next_func)); - - OpInputList finalize_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", - &finalize_func_other_args_input)); - std::vector<Tensor> finalize_func_other_args; - finalize_func_other_args.reserve(finalize_func_other_args_input.size()); - for (const Tensor& t : finalize_func_other_args_input) { - finalize_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> finalize_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - finalize_func_, std::move(finalize_func_other_args), - &finalize_func)); + next_func_, ctx, "next_func_other_args", &next_func)); + + std::unique_ptr<CapturedFunction> finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx, + "finalize_func_other_args", + &finalize_func)); *output = new Dataset(ctx, std::move(init_func), std::move(next_func), diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index e4fa557598..8b417bb1c2 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -42,50 +42,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - // Get captured inputs for the key, reduce, and window_size functions. - OpInputList key_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments", - &key_func_other_argument_inputs)); - std::vector<Tensor> key_func_other_arguments; - key_func_other_arguments.reserve(key_func_other_argument_inputs.size()); - for (const Tensor& t : key_func_other_argument_inputs) { - key_func_other_arguments.push_back(t); - } - OpInputList reduce_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments", - &reduce_func_other_argument_inputs)); - std::vector<Tensor> reduce_func_other_arguments; - reduce_func_other_arguments.reserve( - reduce_func_other_argument_inputs.size()); - for (const Tensor& t : reduce_func_other_argument_inputs) { - reduce_func_other_arguments.push_back(t); - } - OpInputList window_size_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, - ctx->input_list("window_size_func_other_arguments", - &window_size_func_other_argument_inputs)); - std::vector<Tensor> window_size_func_other_arguments; - window_size_func_other_arguments.reserve( - window_size_func_other_argument_inputs.size()); - for (const Tensor& t : window_size_func_other_argument_inputs) { - window_size_func_other_arguments.push_back(t); - } - // TODO(mrry): Refactor CapturedFunction to share the runtime - // state between multiple functions? std::unique_ptr<CapturedFunction> captured_key_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - key_func_, std::move(key_func_other_arguments), - &captured_key_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx, + "key_func_other_arguments", + &captured_key_func)); std::unique_ptr<CapturedFunction> captured_reduce_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(reduce_func_, - std::move(reduce_func_other_arguments), - &captured_reduce_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx, + "reduce_func_other_arguments", + &captured_reduce_func)); std::unique_ptr<CapturedFunction> captured_window_size_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - window_size_func_, std::move(window_size_func_other_arguments), - &captured_window_size_func)); + OP_REQUIRES_OK(ctx, + CapturedFunction::Create(window_size_func_, ctx, + "window_size_func_other_arguments", + &captured_window_size_func)); *output = new Dataset( ctx, input, key_func_, reduce_func_, window_size_func_, diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 0768f46665..0aa802b874 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -39,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - const Tensor* cycle_length_t; OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t)); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()), @@ -66,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { errors::InvalidArgument("block_length must be greater than zero.")); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), cycle_length, diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 80efac5d4b..83896219a3 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -49,14 +49,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int64 batch_size; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); OP_REQUIRES( @@ -93,8 +85,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, batch_size, num_parallel_calls, drop_remainder, output_types_, output_shapes_, func_, diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index af301e2b42..f112e1dc43 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -38,18 +38,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), - use_inter_op_parallelism_, &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + use_inter_op_parallelism_, + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc index 6180df5af2..346e4ceebd 100644 --- a/tensorflow/core/kernels/data/optional_ops.cc +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -108,11 +108,8 @@ class OptionalFromValueOp : public OpKernel { void Compute(OpKernelContext* ctx) override { OpInputList components_input; OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input)); - std::vector<Tensor> components; - components.reserve(components_input.size()); - for (const Tensor& component_t : components_input) { - components.push_back(component_t); - } + std::vector<Tensor> components(components_input.begin(), + components_input.end()); OP_REQUIRES_OK( ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components))); } diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 2f2db09508..9cd46bf5dd 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -44,14 +44,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int64 cycle_length = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "cycle_length", &cycle_length)); @@ -83,8 +75,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - interleave_func_, std::move(other_arguments), &captured_func)); + ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, interleave_func_, std::move(captured_func), @@ -1102,9 +1094,6 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - int64 cycle_length = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "cycle_length", &cycle_length)); @@ -1128,16 +1117,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { errors::InvalidArgument( "num_parallel_calls must less than or equal to cycle_length.")); - // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`. - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - interleave_func_, std::move(other_arguments), &captured_func)); + ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, interleave_func_, std::move(captured_func), cycle_length, block_length, diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index b584316d69..6abe6c8338 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -44,14 +44,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int32 num_parallel_calls; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); @@ -60,9 +52,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { "num_parallel_calls must be greater than zero.")); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), - use_inter_op_parallelism_, &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + use_inter_op_parallelism_, + &captured_func)); *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, output_shapes_, use_inter_op_parallelism_, diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 0cf5db017b..c28c06da62 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -87,11 +87,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { "Expected len(dense_defaults) == len(dense_keys) but got: ", dense_default_tensors.size(), " vs. ", dense_keys_.size())); - std::vector<Tensor> dense_defaults; - dense_defaults.reserve(dense_default_tensors.size()); - for (const Tensor& dense_default_t : dense_default_tensors) { - dense_defaults.push_back(dense_default_t); - } + std::vector<Tensor> dense_defaults(dense_default_tensors.begin(), + dense_default_tensors.end()); for (int d = 0; d < dense_keys_.size(); ++d) { const Tensor& def_value = dense_defaults[d]; diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index 6e515d6cc8..dbe31f37b8 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -45,23 +45,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { OpInputList initial_state_inputs; OP_REQUIRES_OK(ctx, ctx->input_list("initial_state", &initial_state_inputs)); - std::vector<Tensor> initial_state; - initial_state.reserve(initial_state_inputs.size()); - for (const Tensor& t : initial_state_inputs) { - initial_state.push_back(t); - } - - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } + std::vector<Tensor> initial_state(initial_state_inputs.begin(), + initial_state_inputs.end()); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(initial_state), std::move(captured_func), state_types_, output_types_, diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index e1cefd23d8..ca4ea25b89 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -33,11 +33,7 @@ class TensorDatasetOp : public DatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs)); // TODO(mrry): Validate that the shapes of the "components" tensors match // the "shapes" attr.; - std::vector<Tensor> components; - components.reserve(inputs.size()); - for (const Tensor& t : inputs) { - components.push_back(t); - } + std::vector<Tensor> components(inputs.begin(), inputs.end()); *output = new Dataset(ctx, std::move(components)); } |