diff options
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 89 |
1 files changed, 47 insertions, 42 deletions
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 2e6e0465f7..6b6b3d6ab9 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -1084,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), // The above design choices were made with automated optimizations in mind, // isolating the degree of parallelism as the single tunable knob of this // implementation. +// +// TODO(b/116852688): Make coordination between the performance model and this +// transformation more robust. class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx) @@ -1214,7 +1217,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), - num_parallel_calls_(params.dataset->num_parallel_calls_), + mu_(std::make_shared<mutex>()), + cond_var_(std::make_shared<condition_variable>()), + num_parallel_calls_(std::make_shared<model::SharedState>( + params.dataset->num_parallel_calls_, mu_, cond_var_)), args_list_(params.dataset->cycle_length_), current_elements_(params.dataset->cycle_length_), element_in_use_(params.dataset->cycle_length_, false), @@ -1224,25 +1230,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { false /* low_latency_hint */)) {} ~Iterator() override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Cancel the runner thread. cancelled_ = true; - cond_var_.notify_all(); + cond_var_->notify_all(); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } } Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); - if (num_parallel_calls_ == kAutoTune) { - num_parallel_calls_ = 1; - AddTunableParameter(ctx, "parallelism", - &num_parallel_calls_ /* value */, 1 /* min */, - dataset()->cycle_length_ /* max */, &cond_var_); + mutex_lock l(*mu_); + if (num_parallel_calls_->value == kAutoTune) { + num_parallel_calls_->value = 1; + AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, + dataset()->cycle_length_); } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); } AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_); TF_RETURN_IF_ERROR( @@ -1256,12 +1261,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { std::shared_ptr<InvocationResult> result; do { { - mutex_lock l(mu_); + mutex_lock l(*mu_); EnsureRunnerThreadStarted(ctx); while (invocation_results_.empty() && (!end_of_input_ || num_open_ > 0)) { RecordStop(ctx); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx); } if (!invocation_results_.empty()) { @@ -1271,7 +1276,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } - cond_var_.notify_all(); + cond_var_->notify_all(); } RecordStop(ctx); result->notification.WaitForNotification(); @@ -1287,10 +1292,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); @@ -1328,7 +1333,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64 invocation_results_size; TF_RETURN_IF_ERROR(reader->ReadScalar( @@ -1381,7 +1386,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { }; void EnsureRunnerThreadStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( @@ -1398,7 +1403,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { void FetchOutputs( const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index, const std::vector<std::shared_ptr<InvocationResult>>& results) - LOCKS_EXCLUDED(mu_) { + LOCKS_EXCLUDED(*mu_) { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); bool end_of_input = false; @@ -1421,14 +1426,14 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { if (end_of_input) { current_elements_[cycle_index].reset(); } - mutex_lock l(mu_); + mutex_lock l(*mu_); element_in_use_[cycle_index] = false; num_calls_--; if (end_of_input) { args_list_[cycle_index].clear(); num_open_--; } - cond_var_.notify_all(); + cond_var_->notify_all(); } // Method responsible for 1) creating iterators out of input elements, 2) @@ -1439,20 +1444,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); - auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { return element_in_use_[cycle_index_] || - num_calls_ >= num_parallel_calls_ || + num_calls_ >= num_parallel_calls_->value || invocation_results_.size() >= dataset()->cycle_length_ * dataset()->block_length_; }; while (true) { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Wait until this thread is cancelled, the end of input has been // reached, or the cycle element at the `cycle_index_` position is // not in use and there is space in the `invocation_results_` queue. while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) { RecordStop(ctx.get()); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx.get()); } @@ -1506,13 +1511,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; } - cond_var_.notify_all(); + cond_var_->notify_all(); } } Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, const Status& status) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar( CodeKey(index), static_cast<int64>(status.code()))); if (!status.ok()) { @@ -1523,7 +1528,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } Status ReadStatusLocked(IteratorStateReader* reader, size_t index, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64 code_int; TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); error::Code code = static_cast<error::Code>(code_int); @@ -1550,7 +1555,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } Status WriteCurrentElements(IteratorStateWriter* writer) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { for (int idx = 0; idx < current_elements_.size(); idx++) { if (current_elements_[idx]) { TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx])); @@ -1569,7 +1574,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { Status ReadCurrentElements(IteratorContext* ctx, IteratorStateReader* reader) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { for (int idx = 0; idx < current_elements_.size(); idx++) { if (reader->Contains( full_name(strings::StrCat("args_size[", idx, "]")))) { @@ -1597,7 +1602,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // Used for coordination between the main thread, the runner thread, and // the worker threads. - mutex mu_; + const std::shared_ptr<mutex> mu_; // Used for coordination between the main thread, the runner thread, and // the worker threads. In particular, the runner thread should only @@ -1605,45 +1610,45 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // user specified level of parallelism, there are slots available in the // `invocation_results_` buffer, the current cycle element is not in use, // and there are elements left to be fetched. - condition_variable cond_var_; + const std::shared_ptr<condition_variable> cond_var_; // Identifies the maximum number of parallel calls. - std::atomic<int64> num_parallel_calls_; + const std::shared_ptr<model::SharedState> num_parallel_calls_; // Iterator for input elements. - std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_); // Identifies current cycle element. int64 cycle_index_ = 0; // Arguments for creating an iterator for cycle elements. - std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_); + std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_); // Iterators for the current cycle elements. Concurrent access is // protected by `element_in_use_`. std::vector<std::unique_ptr<IteratorBase>> current_elements_; // Identifies cycle elements that are in use by worker threads. - std::vector<bool> element_in_use_ GUARDED_BY(mu_); + std::vector<bool> element_in_use_ GUARDED_BY(*mu_); // Buffer for storing the invocation results. std::deque<std::shared_ptr<InvocationResult>> invocation_results_ - GUARDED_BY(mu_); + GUARDED_BY(*mu_); // Identifies whether end of input has been reached. - bool end_of_input_ GUARDED_BY(mu_) = false; + bool end_of_input_ GUARDED_BY(*mu_) = false; // Identifies the number of open iterators. - int64 num_open_ GUARDED_BY(mu_) = 0; + int64 num_open_ GUARDED_BY(*mu_) = 0; // Identifies the number of outstanding calls. - int64 num_calls_ GUARDED_BY(mu_) = 0; + int64 num_calls_ GUARDED_BY(*mu_) = 0; std::unique_ptr<thread::ThreadPool> thread_pool_; - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); // Identifies whether background activity should be cancelled. - bool cancelled_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(*mu_) = false; }; const DatasetBase* const input_; |