From eb0f862ba60f41e8d0f06ceb6fc65f7f9905a25a Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Mon, 8 Oct 2018 15:27:40 -0700
Subject: Automated rollback of commit 13b47e6c4f9d7b295948b1057139bf676e394b6f

PiperOrigin-RevId: 216260575
---
 tensorflow/core/kernels/data/iterator_ops.cc        |  4 ++++
 .../core/kernels/data/map_and_batch_dataset_op.cc   |  9 ++++----
 tensorflow/core/kernels/data/model_dataset_op.cc    | 10 ++++----
 .../kernels/data/parallel_interleave_dataset_op.cc  | 27 +++++++++-------------
 .../core/kernels/data/parallel_map_iterator.cc      |  9 ++++----
 .../core/kernels/data/prefetch_dataset_op.cc        | 10 ++++----
 tensorflow/core/kernels/data/writer_ops.cc          | 12 +++++-----
 7 files changed, 37 insertions(+), 44 deletions(-)

diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 8acd6cc724..7a833668ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -16,8 +16,10 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/graph_runner.h"
 #include "tensorflow/core/common_runtime/renamed_device.h"
+#include "tensorflow/core/common_runtime/threadpool_device.h"
 #include "tensorflow/core/framework/iterator.pb.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
 #include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/variant_op_registry.h"
@@ -25,11 +27,13 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/optional_ops.h"
 #include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
 namespace data {
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 0fb721cd7c..f45a239793 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -445,10 +445,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
-        runner_thread_ =
-            MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
-        runner_thread_->Schedule(
-            std::bind(&Iterator::RunnerThread, this, ctx_copy));
+        runner_thread_.reset(ctx->env()->StartThread(
+            {}, "runner_thread",
+            std::bind(&Iterator::RunnerThread, this, ctx_copy)));
       }
     }
 
@@ -704,7 +703,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     std::unique_ptr<IteratorBase> input_impl_;
     // Buffer for storing the (intermediate) batch results.
     std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
-    std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
+    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
     bool cancelled_ GUARDED_BY(*mu_) = false;
   };
 
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 859df57962..9aa505f4f1 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -127,10 +126,9 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!optimize_thread_) { std::shared_ptr new_ctx(new IteratorContext(*ctx)); - optimize_thread_ = - MakeUnique(ctx->env(), "optimize_thread"); - optimize_thread_->Schedule( - [this, new_ctx]() { OptimizeThread(new_ctx); }); + optimize_thread_.reset(ctx->env()->StartThread( + {}, "optimize_thread", + [this, new_ctx]() { OptimizeThread(new_ctx); })); } return Status::OK(); } @@ -169,7 +167,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { mutex mu_; condition_variable cond_var_; std::shared_ptr model_; - std::unique_ptr optimize_thread_ GUARDED_BY(mu_); + std::unique_ptr optimize_thread_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_) = false; std::unique_ptr input_impl_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 9c836b836e..6b6b3d6ab9 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -482,10 +481,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { std::shared_ptr new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back( - MakeUnique(ctx->env(), "worker_thread")); - worker_threads_.back()->Schedule( - [this, new_ctx, i]() { WorkerThread(new_ctx, i); }); + worker_threads_.emplace_back(ctx->env()->StartThread( + {}, "worker_thread", + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); } } return Status::OK(); @@ -582,10 +580,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } workers_[i].SetInputs(s, std::move(args)); std::shared_ptr new_ctx(new IteratorContext(*ctx)); - worker_threads_.emplace_back( - MakeUnique(ctx->env(), "worker_thread")); - worker_threads_.back()->Schedule( - [this, new_ctx, i]() { WorkerThread(new_ctx, i); }); + worker_threads_.emplace_back(ctx->env()->StartThread( + {}, "worker_thread", + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); } else { @@ -1050,8 +1047,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // The worker threads. This must be last to ensure the // threads have exited before any other members are deallocated. // TODO(b/65178177): Avoid allocating additional threads. 
-    std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_
-        GUARDED_BY(mu_);
+    std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
   };
 
   const DatasetBase* const input_;
@@ -1393,10 +1389,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
         EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
-        runner_thread_ =
-            MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
-        runner_thread_->Schedule(
-            [this, new_ctx]() { RunnerThread(new_ctx); });
+        runner_thread_.reset(ctx->env()->StartThread(
+            {}, "runner_thread",
+            [this, new_ctx]() { RunnerThread(new_ctx); }));
       }
     }
 
@@ -1650,7 +1645,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     int64 num_calls_ GUARDED_BY(*mu_) = 0;
 
     std::unique_ptr<thread::ThreadPool> thread_pool_;
-    std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
+    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
 
     // Identifies whether background activity should be cancelled.
     bool cancelled_ GUARDED_BY(*mu_) = false;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index e69274e4f2..ebf41925c9 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -181,10 +181,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
       EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
     if (!runner_thread_) {
       auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
-      runner_thread_ =
-          MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
-      runner_thread_->Schedule(
-          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
+      runner_thread_.reset(ctx->env()->StartThread(
+          {}, "runner_thread",
+          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
    }
  }
 
@@ -332,7 +331,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
   // Buffer for storing the invocation results.
   std::deque<std::shared_ptr<InvocationResult>> invocation_results_ GUARDED_BY(*mu_);
-  std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
+  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
   bool cancelled_ GUARDED_BY(*mu_) = false;
 };
 
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index e9c38eb8a0..754ed772db 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -257,11 +256,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status EnsurePrefetchThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!prefetch_thread_) { - prefetch_thread_ = - MakeUnique(ctx->env(), "prefetch_thread"); std::shared_ptr new_ctx(new IteratorContext(*ctx)); - prefetch_thread_->Schedule( - [this, new_ctx]() { PrefetchThread(new_ctx); }); + prefetch_thread_.reset(ctx->env()->StartThread( + {}, "prefetch_thread", + [this, new_ctx]() { PrefetchThread(new_ctx); })); } return Status::OK(); } @@ -365,7 +363,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { string prefix_end_; PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); - std::unique_ptr prefetch_thread_ GUARDED_BY(mu_); + std::unique_ptr prefetch_thread_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_) = false; bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; }; diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 7bb2077b62..3f76695bb1 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel { public: explicit ToTFRecordOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - background_worker_( - ctx->env(), - strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) { - } + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())), + 1 /* num_threads */, false /* low_latency_hint */)) {} template Status ParseScalarArgument(OpKernelContext* ctx, @@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel { // The call to `iterator->GetNext()` may block and depend on an // inter-op thread pool thread, so we issue the call from the // owned thread pool. - background_worker_.Schedule([this, ctx, done]() { + thread_pool_->Schedule([this, ctx, done]() { string filename; OP_REQUIRES_OK_ASYNC( ctx, ParseScalarArgument(ctx, "filename", &filename), done); @@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel { } private: - BackgroundWorker background_worker_; + std::unique_ptr thread_pool_; }; REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), -- cgit v1.2.3