diff options
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_map_iterator.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_map_iterator.cc | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 13bd4b6036..ebf41925c9 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -179,7 +180,7 @@ class ParallelMapIterator : public DatasetBaseIterator { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); + auto ctx_copy = std::make_shared<IteratorContext>(*ctx); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); @@ -208,15 +209,15 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } - // Call `func_(input_element)`, store the result in `result->return_values`, - // and notify `result->notification` to unblock a consumer. auto done = [this, result](Status status) { result->status.Update(status); CallCompleted(result); }; - map_func_(ctx.get(), std::move(input_element), &result->return_values, - std::move(done)); + // Apply the map function on `input_element`, storing the result in + // `result->return_values`, and invoking `done` when finished. + map_func_(ctx.get(), prefix(), std::move(input_element), + &result->return_values, std::move(done)); } Status ProcessResult(const std::shared_ptr<InvocationResult>& result, @@ -349,9 +350,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( const DatasetBase* input_dataset, std::function<Status(IteratorContext*)> init_func, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return std::unique_ptr<IteratorBase>( - new ParallelMapIterator(params, input_dataset, std::move(init_func), - std::move(map_func), num_parallel_calls)); + return MakeUnique<ParallelMapIterator>( + params, input_dataset, std::move(init_func), std::move(map_func), + num_parallel_calls); } } // namespace data |