about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc')
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc  89
1 files changed, 47 insertions, 42 deletions
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 2e6e0465f7..6b6b3d6ab9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1084,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
+//
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
@@ -1214,7 +1217,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- num_parallel_calls_(params.dataset->num_parallel_calls_),
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
@@ -1224,25 +1230,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
false /* low_latency_hint */)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
- AddTunableParameter(ctx, "parallelism",
- &num_parallel_calls_ /* value */, 1 /* min */,
- dataset()->cycle_length_ /* max */, &cond_var_);
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ dataset()->cycle_length_);
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
@@ -1256,12 +1261,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
std::shared_ptr<InvocationResult> result;
do {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
if (!invocation_results_.empty()) {
@@ -1271,7 +1276,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
@@ -1287,10 +1292,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -1328,7 +1333,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
@@ -1381,7 +1386,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -1398,7 +1403,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void FetchOutputs(
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
@@ -1421,14 +1426,14 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
if (end_of_input) {
current_elements_[cycle_index].reset();
}
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
element_in_use_[cycle_index] = false;
num_calls_--;
if (end_of_input) {
args_list_[cycle_index].clear();
num_open_--;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
// Method responsible for 1) creating iterators out of input elements, 2)
@@ -1439,20 +1444,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return element_in_use_[cycle_index_] ||
- num_calls_ >= num_parallel_calls_ ||
+ num_calls_ >= num_parallel_calls_->value ||
invocation_results_.size() >=
dataset()->cycle_length_ * dataset()->block_length_;
};
while (true) {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait until this thread is cancelled, the end of input has been
// reached, or the cycle element at the `cycle_index_` position is
// not in use and there is space in the `invocation_results_` queue.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
@@ -1506,13 +1511,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -1523,7 +1528,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -1550,7 +1555,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status WriteCurrentElements(IteratorStateWriter* writer)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
@@ -1569,7 +1574,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status ReadCurrentElements(IteratorContext* ctx,
IteratorStateReader* reader)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(
full_name(strings::StrCat("args_size[", idx, "]")))) {
@@ -1597,7 +1602,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the worker threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the worker threads. In particular, the runner thread should only
@@ -1605,45 +1610,45 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// user specified level of parallelism, there are slots available in the
// `invocation_results_` buffer, the current cycle element is not in use,
// and there are elements left to be fetched.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Iterator for input elements.
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);
// Identifies current cycle element.
int64 cycle_index_ = 0;
// Arguments for creating an iterator for cycle elements.
- std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_);
// Iterators for the current cycle elements. Concurrent access is
// protected by `element_in_use_`.
std::vector<std::unique_ptr<IteratorBase>> current_elements_;
// Identifies cycle elements that are in use by worker threads.
- std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+ std::vector<bool> element_in_use_ GUARDED_BY(*mu_);
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
+ GUARDED_BY(*mu_);
// Identifies whether end of input has been reached.
- bool end_of_input_ GUARDED_BY(mu_) = false;
+ bool end_of_input_ GUARDED_BY(*mu_) = false;
// Identifies the number of open iterators.
- int64 num_open_ GUARDED_BY(mu_) = 0;
+ int64 num_open_ GUARDED_BY(*mu_) = 0;
// Identifies the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled.
- bool cancelled_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;