aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/parallel_map_iterator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_map_iterator.cc')
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc17
1 files changed, 9 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 13bd4b6036..ebf41925c9 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -179,7 +180,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
@@ -208,15 +209,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in `result->return_values`,
- // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
- map_func_(ctx.get(), std::move(input_element), &result->return_values,
- std::move(done));
+ // Apply the map function on `input_element`, storing the result in
+ // `result->return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ &result->return_values, std::move(done));
}
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
@@ -349,9 +350,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(
- new ParallelMapIterator(params, input_dataset, std::move(init_func),
- std::move(map_func), num_parallel_calls));
+ return MakeUnique<ParallelMapIterator>(
+ params, input_dataset, std::move(init_func), std::move(map_func),
+ num_parallel_calls);
}
} // namespace data