diff options
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_map_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 79 |
1 files changed, 56 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 6abe6c8338..3a14924fba 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/random/random.h" @@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + ParallelMapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors, + std::move(done), prefix); + }; + if (!use_inter_op_parallelism_) { + map_func = [map_func](IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors, + StatusCallback done) { + (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args), + out_tensors, std::move(done))); + }; + } + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, std::vector<Tensor>* out_tensors, + StatusCallback done) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + done(Status::OK()); + }; + } + *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, output_shapes_, use_inter_op_parallelism_, - std::move(captured_func)); + std::move(captured_func), std::move(map_func)); } private: @@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, bool use_inter_op_parallelism, - std::unique_ptr<CapturedFunction> captured_func) + std::unique_ptr<CapturedFunction> captured_func, + ParallelMapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), @@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { output_types_(output_types), output_shapes_(output_shapes), use_inter_op_parallelism_(use_inter_op_parallelism), - captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - const string& new_prefix = strings::StrCat(prefix, "::ParallelMap"); - ParallelMapIteratorFunction map_func = - [this, new_prefix](IteratorContext* ctx, - std::vector<Tensor> input_element, - std::vector<Tensor>* result, StatusCallback done) { - captured_func_->RunAsync(ctx, std::move(input_element), result, - std::move(done), new_prefix); - }; - if (!use_inter_op_parallelism_) { - map_func = [map_func]( - IteratorContext* ctx, std::vector<Tensor> input_element, - std::vector<Tensor>* result, StatusCallback done) { - (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element), - result, std::move(done))); - }; - } - - return NewParallelMapIterator({this, new_prefix}, input_, - std::move(init_func), std::move(map_func), - num_parallel_calls_); + return NewParallelMapIterator( + {this, strings::StrCat(prefix, "::ParallelMap")}, input_, + std::move(init_func), map_func_, num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { @@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; const bool use_inter_op_parallelism_; const std::unique_ptr<CapturedFunction> captured_func_; + const ParallelMapIteratorFunction map_func_; }; DataTypeVector output_types_; |