diff options
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_map_iterator.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_map_iterator.cc | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index ebf41925c9..13bd4b6036 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -180,7 +179,7 @@ class ParallelMapIterator : public DatasetBaseIterator { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - auto ctx_copy = std::make_shared<IteratorContext>(*ctx); + std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); @@ -209,15 +208,15 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } + // Call `func_(input_element)`, store the result in `result->return_values`, + // and notify `result->notification` to unblock a consumer. auto done = [this, result](Status status) { result->status.Update(status); CallCompleted(result); }; - // Apply the map function on `input_element`, storing the result in - // `result->return_values`, and invoking `done` when finished. - map_func_(ctx.get(), prefix(), std::move(input_element), - &result->return_values, std::move(done)); + map_func_(ctx.get(), std::move(input_element), &result->return_values, + std::move(done)); } Status ProcessResult(const std::shared_ptr<InvocationResult>& result, @@ -350,9 +349,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( const DatasetBase* input_dataset, std::function<Status(IteratorContext*)> init_func, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return MakeUnique<ParallelMapIterator>( - params, input_dataset, std::move(init_func), std::move(map_func), - num_parallel_calls); + return std::unique_ptr<IteratorBase>( + new ParallelMapIterator(params, input_dataset, std::move(init_func), + std::move(map_func), num_parallel_calls)); } } // namespace data |