| author | Jiri Simsa <jsimsa@google.com> | 2018-10-01 17:18:24 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 17:22:39 -0700 |
| commit | bfbe2bbe6a83a4acfa8f87aa5c8228e74b37bb61 (patch) | |
| tree | 18a274c3c1a8f917fc8addf9630ddff55436a4fd /tensorflow/core/kernels | |
| parent | 80f8931682aeaae89786f0940892a6557b4cfd67 (diff) | |
[tf.data] More robust solution for input pipeline <--> performance model coordination.
PiperOrigin-RevId: 215309735
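For context, this change replaces each iterator's plain `mutex` / `condition_variable` / `std::atomic<int64>` members with heap-allocated state that is shared with the autotuning performance model through `model::SharedState`, so the iterator and the model lock the same mutex and signal the same condition variable. Below is a minimal, self-contained sketch of that coordination pattern in plain standard C++ (the `SharedState` struct, thread bodies, and timings are illustrative assumptions, not TensorFlow code): a tuner thread raises the parallelism knob and wakes a runner thread that blocks while it is at its in-flight limit.

```cpp
// Sketch only: a tunable value that travels together with the mutex and
// condition variable used to guard it, mirroring the shape (not the code) of
// the SharedState introduced by this commit.
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

struct SharedState {
  SharedState(int64_t value, std::shared_ptr<std::mutex> mu,
              std::shared_ptr<std::condition_variable> cond_var)
      : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}
  int64_t value;  // Guarded by *mu.
  std::shared_ptr<std::mutex> mu;
  std::shared_ptr<std::condition_variable> cond_var;
};

int main() {
  auto mu = std::make_shared<std::mutex>();
  auto cond_var = std::make_shared<std::condition_variable>();
  auto parallelism = std::make_shared<SharedState>(/*value=*/1, mu, cond_var);

  int64_t num_calls = 0;   // In-flight calls; guarded by *mu.
  bool cancelled = false;  // Guarded by *mu.

  // "Runner" thread: schedules work only while num_calls < parallelism->value.
  std::thread runner([&] {
    std::unique_lock<std::mutex> l(*mu);
    while (!cancelled) {
      // Block while "busy"; a notify on the shared cond_var re-evaluates this.
      cond_var->wait(
          l, [&] { return cancelled || num_calls < parallelism->value; });
      if (cancelled) break;
      ++num_calls;  // Pretend to start an asynchronous call.
      std::cout << "scheduled call " << num_calls << " (limit "
                << parallelism->value << ")\n";
    }
  });

  // "Tuner" thread: raises the parallelism knob and wakes the runner.
  std::thread tuner([&] {
    for (int64_t target = 2; target <= 4; ++target) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      std::lock_guard<std::mutex> l(*mu);
      parallelism->value = target;
      cond_var->notify_all();
    }
    std::lock_guard<std::mutex> l(*mu);
    cancelled = true;
    cond_var->notify_all();
  });

  tuner.join();
  runner.join();
  return 0;
}
```

Sharing ownership of the mutex and condition variable is what makes the wake-up robust: the model no longer needs raw pointers (`&num_parallel_calls_`, `&cond_var_`) into an iterator whose lifetime it does not control.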
Diffstat (limited to 'tensorflow/core/kernels')
3 files changed, 130 insertions, 123 deletions
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index b4c7f9e510..bf08970560 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -187,29 +187,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
    public:
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params),
-          num_parallel_calls_(params.dataset->num_parallel_calls_) {}
+          mu_(std::make_shared<mutex>()),
+          cond_var_(std::make_shared<condition_variable>()),
+          num_parallel_calls_(std::make_shared<model::SharedState>(
+              params.dataset->num_parallel_calls_, mu_, cond_var_)) {}

     ~Iterator() override {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       // Cancel the runner thread.
       cancelled_ = true;
-      cond_var_.notify_all();
+      cond_var_->notify_all();
       // Wait for all in-flight calls to complete.
       while (num_calls_ > 0) {
-        cond_var_.wait(l);
+        cond_var_->wait(l);
       }
     }

     Status Initialize(IteratorContext* ctx) override {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
-      if (num_parallel_calls_ == kAutoTune) {
-        num_parallel_calls_ = 1;
-        AddTunableParameter(ctx, "parallelism",
-                            &num_parallel_calls_ /* value */, 1 /* min */,
-                            port::NumSchedulableCPUs() /* max */, &cond_var_);
+      if (num_parallel_calls_->value == kAutoTune) {
+        num_parallel_calls_->value = 1;
+        AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+                            port::NumSchedulableCPUs());
       } else {
-        AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+        AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
       }
       TF_RETURN_IF_ERROR(
           dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -221,27 +223,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
                            bool* end_of_sequence) override {
       std::shared_ptr<BatchResult> result;
       {
-        mutex_lock l(mu_);
+        mutex_lock l(*mu_);
         EnsureRunnerThreadStarted(ctx);
         while (batch_results_.empty() ||
                batch_results_.front()->num_calls > 0) {
           RecordStop(ctx);
-          cond_var_.wait(l);
+          cond_var_->wait(l);
           RecordStart(ctx);
         }
         std::swap(result, batch_results_.front());
         batch_results_.pop_front();
-        cond_var_.notify_all();
+        cond_var_->notify_all();
       }
       return ProcessResult(ctx, result, out_tensors, end_of_sequence);
     }

    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       // Wait for all in-flight calls to complete.
       while (num_calls_ > 0) {
-        cond_var_.wait(l);
+        cond_var_->wait(l);
       }
       CHECK_EQ(num_calls_, 0);
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -257,7 +259,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {

     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       TF_RETURN_IF_ERROR(
           reader->ReadScalar(full_name("call_counter"), &call_counter_));
@@ -298,7 +300,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     void Callback(const std::shared_ptr<IteratorContext>& ctx,
                   const std::shared_ptr<BatchResult>& result,
                   const std::shared_ptr<std::vector<Tensor>>& return_values,
-                  int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
+                  int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
       result->UpdateStatus(status);
       if (status.ok()) {
         EnsureOutputAllocated(ctx, result, return_values);
@@ -334,16 +336,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     void CallCompleted(const std::shared_ptr<BatchResult>& result)
-        LOCKS_EXCLUDED(mu_) {
-      mutex_lock l(mu_);
+        LOCKS_EXCLUDED(*mu_) {
+      mutex_lock l(*mu_);
       num_calls_--;
       result->num_calls--;
-      cond_var_.notify_all();
+      cond_var_->notify_all();
     }

     void CallFunction(std::shared_ptr<IteratorContext> ctx,
                       const std::shared_ptr<BatchResult>& result,
-                      int64 offset) LOCKS_EXCLUDED(mu_) {
+                      int64 offset) LOCKS_EXCLUDED(*mu_) {
       // Get the next input element.
       std::vector<Tensor> input_element;
       bool end_of_input;
@@ -400,7 +402,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     void EnsureRunnerThreadStarted(IteratorContext* ctx)
-        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
         runner_thread_.reset(ctx->env()->StartThread(
@@ -476,14 +478,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
-        LOCKS_EXCLUDED(mu_) {
+        LOCKS_EXCLUDED(*mu_) {
       std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
       RecordStart(ctx.get());
       auto stop_cleanup =
           gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
-      new_calls.reserve(num_parallel_calls_);
-      auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
-        int64 num_parallel_calls = num_parallel_calls_;
+      new_calls.reserve(num_parallel_calls_->value);
+      auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+        int64 num_parallel_calls = num_parallel_calls_->value;
         int64 max_batch_results =
             (num_parallel_calls + dataset()->batch_size_ - 1) /
             dataset()->batch_size_;
@@ -494,10 +496,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
       };
       while (true) {
         {
-          mutex_lock l(mu_);
+          mutex_lock l(*mu_);
           while (!cancelled_ && busy()) {
             RecordStop(ctx.get());
-            cond_var_.wait(l);
+            cond_var_->wait(l);
             RecordStart(ctx.get());
           }
@@ -524,7 +526,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
-                           size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                           size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
       std::shared_ptr<BatchResult> result = batch_results_.back();
       string prefix = strings::StrCat("batch_results_", index);
@@ -569,7 +571,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     Status ReadStatus(IteratorStateReader* reader, const string& prefix,
-                      Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                      Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       int64 code_int;
       TF_RETURN_IF_ERROR(reader->ReadScalar(
           full_name(strings::StrCat(prefix, "_code")), &code_int));
@@ -587,7 +589,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
-        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       std::shared_ptr<BatchResult> result = batch_results_[index];
       string prefix = strings::StrCat("batch_results_", index);
       mutex_lock l(result->mu);
@@ -628,7 +630,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     }

     Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
-                       const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                       const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       TF_RETURN_IF_ERROR(
           writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
                               static_cast<int64>(status.code())));
@@ -642,24 +644,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {

     // Used for coordination between the main thread, the runner thread, and
     // the callback threads.
-    mutex mu_;
+    const std::shared_ptr<mutex> mu_;
     // Used for coordination between the main thread, the runner thread, and
     // the callback threads. In particular, the runner thread should only
-    // schedule new calls when the number of in-flight calls is less than the
-    // user specified level of parallelism and there are slots available in
-    // the `batch_results_` buffer.
-    condition_variable cond_var_;
+    // schedule new calls when the number of in-flight calls is less than
+    // `num_parallel_calls_->value` and there are slots available in the
+    // `batch_results_` buffer.
+    const std::shared_ptr<condition_variable> cond_var_;
     // Identifies the maximum number of parallel calls.
-    std::atomic<int64> num_parallel_calls_;
+    const std::shared_ptr<model::SharedState> num_parallel_calls_;
     // Counts the number of outstanding calls for this batch.
-    int64 num_calls_ GUARDED_BY(mu_) = 0;
+    int64 num_calls_ GUARDED_BY(*mu_) = 0;
     // Counts the total number of calls.
-    int64 call_counter_ GUARDED_BY(mu_) = 0;
+    int64 call_counter_ GUARDED_BY(*mu_) = 0;
     std::unique_ptr<IteratorBase> input_impl_;
     // Buffer for storing the (intermediate) batch results.
-    std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
-    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
-    bool cancelled_ GUARDED_BY(mu_) = false;
+    std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
+    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+    bool cancelled_ GUARDED_BY(*mu_) = false;
   };

   const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 2bb38bf0b9..6b6b3d6ab9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1217,7 +1217,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
    public:
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params),
-          num_parallel_calls_(params.dataset->num_parallel_calls_),
+          mu_(std::make_shared<mutex>()),
+          cond_var_(std::make_shared<condition_variable>()),
+          num_parallel_calls_(std::make_shared<model::SharedState>(
+              params.dataset->num_parallel_calls_, mu_, cond_var_)),
           args_list_(params.dataset->cycle_length_),
           current_elements_(params.dataset->cycle_length_),
           element_in_use_(params.dataset->cycle_length_, false),
@@ -1227,25 +1230,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
                 false /* low_latency_hint */)) {}

     ~Iterator() override {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       // Cancel the runner thread.
       cancelled_ = true;
-      cond_var_.notify_all();
+      cond_var_->notify_all();
       // Wait for all in-flight calls to complete.
       while (num_calls_ > 0) {
-        cond_var_.wait(l);
+        cond_var_->wait(l);
       }
     }

     Status Initialize(IteratorContext* ctx) override {
-      mutex_lock l(mu_);
-      if (num_parallel_calls_ == kAutoTune) {
-        num_parallel_calls_ = 1;
-        AddTunableParameter(ctx, "parallelism",
-                            &num_parallel_calls_ /* value */, 1 /* min */,
-                            dataset()->cycle_length_ /* max */, &cond_var_);
+      mutex_lock l(*mu_);
+      if (num_parallel_calls_->value == kAutoTune) {
+        num_parallel_calls_->value = 1;
+        AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+                            dataset()->cycle_length_);
       } else {
-        AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+        AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
       }
       AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
       TF_RETURN_IF_ERROR(
@@ -1259,12 +1261,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
       std::shared_ptr<InvocationResult> result;
       do {
         {
-          mutex_lock l(mu_);
+          mutex_lock l(*mu_);
           EnsureRunnerThreadStarted(ctx);
           while (invocation_results_.empty() &&
                  (!end_of_input_ || num_open_ > 0)) {
             RecordStop(ctx);
-            cond_var_.wait(l);
+            cond_var_->wait(l);
             RecordStart(ctx);
           }
           if (!invocation_results_.empty()) {
@@ -1274,7 +1276,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
             *end_of_sequence = true;
             return Status::OK();
           }
-          cond_var_.notify_all();
+          cond_var_->notify_all();
         }
         RecordStop(ctx);
         result->notification.WaitForNotification();
@@ -1290,10 +1292,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {

    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       // Wait for all in-flight calls to complete.
       while (num_calls_ > 0) {
-        cond_var_.wait(l);
+        cond_var_->wait(l);
       }
       CHECK_EQ(num_calls_, 0);
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -1331,7 +1333,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {

     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       int64 invocation_results_size;
       TF_RETURN_IF_ERROR(reader->ReadScalar(
@@ -1384,7 +1386,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     };

     void EnsureRunnerThreadStarted(IteratorContext* ctx)
-        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
         std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
         runner_thread_.reset(ctx->env()->StartThread(
@@ -1401,7 +1403,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     void FetchOutputs(
         const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
         const std::vector<std::shared_ptr<InvocationResult>>& results)
-        LOCKS_EXCLUDED(mu_) {
+        LOCKS_EXCLUDED(*mu_) {
       RecordStart(ctx.get());
       auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
       bool end_of_input = false;
@@ -1424,14 +1426,14 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
       if (end_of_input) {
         current_elements_[cycle_index].reset();
       }
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       element_in_use_[cycle_index] = false;
       num_calls_--;
       if (end_of_input) {
         args_list_[cycle_index].clear();
         num_open_--;
       }
-      cond_var_.notify_all();
+      cond_var_->notify_all();
     }

     // Method responsible for 1) creating iterators out of input elements, 2)
@@ -1442,20 +1444,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
       RecordStart(ctx.get());
       auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
-      auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+      auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
         return element_in_use_[cycle_index_] ||
-               num_calls_ >= num_parallel_calls_ ||
+               num_calls_ >= num_parallel_calls_->value ||
                invocation_results_.size() >=
                    dataset()->cycle_length_ * dataset()->block_length_;
       };
       while (true) {
-        mutex_lock l(mu_);
+        mutex_lock l(*mu_);
         // Wait until this thread is cancelled, the end of input has been
         // reached, or the cycle element at the `cycle_index_` position is
         // not in use and there is space in the `invocation_results_` queue.
         while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
           RecordStop(ctx.get());
-          cond_var_.wait(l);
+          cond_var_->wait(l);
           RecordStart(ctx.get());
         }
@@ -1509,13 +1511,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
           }
           cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
         }
-        cond_var_.notify_all();
+        cond_var_->notify_all();
       }
     }

     Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                              const Status& status)
-        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       TF_RETURN_IF_ERROR(writer->WriteScalar(
           CodeKey(index), static_cast<int64>(status.code())));
       if (!status.ok()) {
@@ -1526,7 +1528,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     }

     Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
-                            Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                            Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       int64 code_int;
       TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
       error::Code code = static_cast<error::Code>(code_int);
@@ -1553,7 +1555,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     }

     Status WriteCurrentElements(IteratorStateWriter* writer)
-        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       for (int idx = 0; idx < current_elements_.size(); idx++) {
         if (current_elements_[idx]) {
           TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
@@ -1572,7 +1574,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {

     Status ReadCurrentElements(IteratorContext* ctx,
                                IteratorStateReader* reader)
-        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       for (int idx = 0; idx < current_elements_.size(); idx++) {
         if (reader->Contains(
                 full_name(strings::StrCat("args_size[", idx, "]")))) {
@@ -1600,7 +1602,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {

     // Used for coordination between the main thread, the runner thread, and
     // the worker threads.
-    mutex mu_;
+    const std::shared_ptr<mutex> mu_;

     // Used for coordination between the main thread, the runner thread, and
     // the worker threads. In particular, the runner thread should only
@@ -1608,45 +1610,45 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
     // user specified level of parallelism, there are slots available in the
     // `invocation_results_` buffer, the current cycle element is not in use,
     // and there are elements left to be fetched.
-    condition_variable cond_var_;
+    const std::shared_ptr<condition_variable> cond_var_;

     // Identifies the maximum number of parallel calls.
-    std::atomic<int64> num_parallel_calls_;
+    const std::shared_ptr<model::SharedState> num_parallel_calls_;

     // Iterator for input elements.
-    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);

     // Identifies current cycle element.
     int64 cycle_index_ = 0;

     // Arguments for creating an iterator for cycle elements.
-    std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+    std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_);

     // Iterators for the current cycle elements. Concurrent access is
     // protected by `element_in_use_`.
     std::vector<std::unique_ptr<IteratorBase>> current_elements_;

     // Identifies cycle elements that are in use by worker threads.
-    std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+    std::vector<bool> element_in_use_ GUARDED_BY(*mu_);

     // Buffer for storing the invocation results.
     std::deque<std::shared_ptr<InvocationResult>> invocation_results_
-        GUARDED_BY(mu_);
+        GUARDED_BY(*mu_);

     // Identifies whether end of input has been reached.
-    bool end_of_input_ GUARDED_BY(mu_) = false;
+    bool end_of_input_ GUARDED_BY(*mu_) = false;

     // Identifies the number of open iterators.
-    int64 num_open_ GUARDED_BY(mu_) = 0;
+    int64 num_open_ GUARDED_BY(*mu_) = 0;

     // Identifies the number of outstanding calls.
-    int64 num_calls_ GUARDED_BY(mu_) = 0;
+    int64 num_calls_ GUARDED_BY(*mu_) = 0;

     std::unique_ptr<thread::ThreadPool> thread_pool_;
-    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);

     // Identifies whether background activity should be cancelled.
-    bool cancelled_ GUARDED_BY(mu_) = false;
+    bool cancelled_ GUARDED_BY(*mu_) = false;
   };

   const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index da067a4e6f..13bd4b6036 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -40,30 +40,32 @@ class ParallelMapIterator : public DatasetBaseIterator {
         input_dataset_(input_dataset),
         init_func_(std::move(init_func)),
         map_func_(std::move(map_func)),
-        num_parallel_calls_(num_parallel_calls) {}
+        mu_(std::make_shared<mutex>()),
+        cond_var_(std::make_shared<condition_variable>()),
+        num_parallel_calls_(std::make_shared<model::SharedState>(
+            num_parallel_calls, mu_, cond_var_)) {}

   ~ParallelMapIterator() override {
-    mutex_lock l(mu_);
+    mutex_lock l(*mu_);
     // Cancel the runner thread.
     cancelled_ = true;
-    cond_var_.notify_all();
+    cond_var_->notify_all();
     // Wait for all in-flight calls to complete.
     while (num_calls_ > 0) {
-      cond_var_.wait(l);
+      cond_var_->wait(l);
     }
   }

   Status Initialize(IteratorContext* ctx) override {
-    mutex_lock l(mu_);
-    if (num_parallel_calls_ == kAutoTune) {
-      num_parallel_calls_ = 1;
+    mutex_lock l(*mu_);
+    if (num_parallel_calls_->value == kAutoTune) {
+      num_parallel_calls_->value = 1;
       // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
       // use it here for the maximum.
-      AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
-                          1 /* min */, port::NumSchedulableCPUs() /* max */,
-                          &cond_var_);
+      AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+                          port::NumSchedulableCPUs());
     } else {
-      AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+      AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
     }
     TF_RETURN_IF_ERROR(
         input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -77,16 +79,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
                          bool* end_of_sequence) override {
     std::shared_ptr<InvocationResult> result;
     {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       EnsureRunnerThreadStarted(ctx);
       while (invocation_results_.empty()) {
         RecordStop(ctx);
-        cond_var_.wait(l);
+        cond_var_->wait(l);
         RecordStart(ctx);
       }
       std::swap(result, invocation_results_.front());
       invocation_results_.pop_front();
-      cond_var_.notify_all();
+      cond_var_->notify_all();
     }
     RecordStop(ctx);
     result->notification.WaitForNotification();
@@ -96,10 +98,10 @@ class ParallelMapIterator : public DatasetBaseIterator {

  protected:
   Status SaveInternal(IteratorStateWriter* writer) override {
-    mutex_lock l(mu_);
+    mutex_lock l(*mu_);
     // Wait for all in-flight calls to complete.
     while (num_calls_ > 0) {
-      cond_var_.wait(l);
+      cond_var_->wait(l);
     }
     CHECK_EQ(num_calls_, 0);
     TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -128,7 +130,7 @@ class ParallelMapIterator : public DatasetBaseIterator {

   Status RestoreInternal(IteratorContext* ctx,
                          IteratorStateReader* reader) override {
-    mutex_lock l(mu_);
+    mutex_lock l(*mu_);
     TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
     int64 invocation_results_size;
     TF_RETURN_IF_ERROR(reader->ReadScalar(
@@ -175,7 +177,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
   };

   void EnsureRunnerThreadStarted(IteratorContext* ctx)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
     if (!runner_thread_) {
       std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
       runner_thread_.reset(ctx->env()->StartThread(
@@ -185,18 +187,18 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }

   void CallCompleted(const std::shared_ptr<InvocationResult>& result)
-      LOCKS_EXCLUDED(mu_) {
+      LOCKS_EXCLUDED(*mu_) {
     {
-      mutex_lock l(mu_);
+      mutex_lock l(*mu_);
       num_calls_--;
-      cond_var_.notify_all();
+      cond_var_->notify_all();
     }
     result->notification.Notify();
   }

   void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                     const std::shared_ptr<InvocationResult>& result)
-      LOCKS_EXCLUDED(mu_) {
+      LOCKS_EXCLUDED(*mu_) {
     // Get the next input element.
     std::vector<Tensor> input_element;
     result->status =
@@ -239,18 +241,18 @@ class ParallelMapIterator : public DatasetBaseIterator {
     RecordStart(ctx.get());
     auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
     std::vector<std::shared_ptr<InvocationResult>> new_calls;
-    new_calls.reserve(num_parallel_calls_);
-    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
-      int64 num_parallel_calls = num_parallel_calls_;
+    new_calls.reserve(num_parallel_calls_->value);
+    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+      int64 num_parallel_calls = num_parallel_calls_->value;
       return num_calls_ >= num_parallel_calls ||
              invocation_results_.size() >= num_parallel_calls;
     };
     while (true) {
       {
-        mutex_lock l(mu_);
+        mutex_lock l(*mu_);
         while (!cancelled_ && busy()) {
           RecordStop(ctx.get());
-          cond_var_.wait(l);
+          cond_var_->wait(l);
           RecordStart(ctx.get());
         }
         if (cancelled_) {
@@ -261,7 +263,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
         new_calls.push_back(invocation_results_.back());
         num_calls_++;
       }
-      cond_var_.notify_all();
+      cond_var_->notify_all();
     }
     for (const auto& call : new_calls) {
       CallFunction(ctx, call);
@@ -271,7 +273,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }

   Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
-                           const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                           const Status& status)
+      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
     TF_RETURN_IF_ERROR(
         writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
     if (!status.ok()) {
@@ -282,7 +285,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }

   Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
-                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
     int64 code_int;
     TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
     error::Code code = static_cast<error::Code>(code_int);
@@ -312,23 +315,23 @@ class ParallelMapIterator : public DatasetBaseIterator {
   const std::function<Status(IteratorContext*)> init_func_;
   const ParallelMapIteratorFunction map_func_;
   // Used for coordination between the main thread and the runner thread.
-  mutex mu_;
+  const std::shared_ptr<mutex> mu_;
   // Used for coordination between the main thread and the runner thread. In
   // particular, the runner thread should only schedule new calls when the
   // number of in-flight calls is less than the user specified level of
   // parallelism and there are slots available in the `invocation_results_`
   // buffer.
-  condition_variable cond_var_;
+  const std::shared_ptr<condition_variable> cond_var_;
   // Identifies the maximum number of parallel calls.
-  std::atomic<int64> num_parallel_calls_;
+  const std::shared_ptr<model::SharedState> num_parallel_calls_;
   // Counts the number of outstanding calls.
-  int64 num_calls_ GUARDED_BY(mu_) = 0;
+  int64 num_calls_ GUARDED_BY(*mu_) = 0;
   std::unique_ptr<IteratorBase> input_impl_;
   // Buffer for storing the invocation results.
   std::deque<std::shared_ptr<InvocationResult>> invocation_results_
-      GUARDED_BY(mu_);
-  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
-  bool cancelled_ GUARDED_BY(mu_) = false;
+      GUARDED_BY(*mu_);
+  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+  bool cancelled_ GUARDED_BY(*mu_) = false;
 };

 }  // namespace