aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_map_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc79
1 files changed, 56 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 6abe6c8338..3a14924fba 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/random/random.h"
@@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ ParallelMapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors,
+ std::move(done), prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args),
+ out_tensors, std::move(done)));
+ };
+ }
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args, std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
- std::move(captured_func));
+ std::move(captured_func), std::move(map_func));
}
private:
@@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction> captured_func)
+ std::unique_ptr<CapturedFunction> captured_func,
+ ParallelMapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
@@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes),
use_inter_op_parallelism_(use_inter_op_parallelism),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
- ParallelMapIteratorFunction map_func =
- [this, new_prefix](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done), new_prefix);
- };
- if (!use_inter_op_parallelism_) {
- map_func = [map_func](
- IteratorContext* ctx, std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
- result, std::move(done)));
- };
- }
-
- return NewParallelMapIterator({this, new_prefix}, input_,
- std::move(init_func), std::move(map_func),
- num_parallel_calls_);
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
+ std::move(init_func), map_func_, num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
+ const ParallelMapIteratorFunction map_func_;
};
DataTypeVector output_types_;