aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/parallel_map_iterator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_map_iterator.cc')
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc17
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index ebf41925c9..13bd4b6036 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -180,7 +179,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
@@ -209,15 +208,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
+ // Call `func_(input_element)`, store the result in `result->return_values`,
+ // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
- // Apply the map function on `input_element`, storing the result in
- // `result->return_values`, and invoking `done` when finished.
- map_func_(ctx.get(), prefix(), std::move(input_element),
- &result->return_values, std::move(done));
+ map_func_(ctx.get(), std::move(input_element), &result->return_values,
+ std::move(done));
}
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
@@ -350,9 +349,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
- return MakeUnique<ParallelMapIterator>(
- params, input_dataset, std::move(init_func), std::move(map_func),
- num_parallel_calls);
+ return std::unique_ptr<IteratorBase>(
+ new ParallelMapIterator(params, input_dataset, std::move(init_func),
+ std::move(map_func), num_parallel_calls));
}
} // namespace data