path: root/tensorflow/core/kernels
author    Jiri Simsa <jsimsa@google.com>              2018-10-01 17:18:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-01 17:22:39 -0700
commit  bfbe2bbe6a83a4acfa8f87aa5c8228e74b37bb61 (patch)
tree    18a274c3c1a8f917fc8addf9630ddff55436a4fd /tensorflow/core/kernels
parent  80f8931682aeaae89786f0940892a6557b4cfd67 (diff)
[tf.data] More robust solution for input pipeline <--> performance model coordination.
PiperOrigin-RevId: 215309735
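
The change, applied uniformly to the three kernels below, replaces each iterator's privately owned `mutex mu_`, `condition_variable cond_var_`, and `std::atomic<int64> num_parallel_calls_` with heap-allocated primitives held through `std::shared_ptr` and bundled, together with the tunable value, into a `model::SharedState`. The performance model then receives the whole bundle instead of a raw pointer to the atomic plus a separate pointer to the condition variable, so both sides are guaranteed to synchronize on the same lock. A minimal sketch in standard C++ of what such a bundle looks like (an illustrative stand-in, not the actual TensorFlow definition):

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>

// Illustrative stand-in for model::SharedState: a tunable value bundled
// with the mutex and condition variable that guard it. The iterator and
// the performance model hold the same shared_ptrs, so either side updates
// the value under the same lock and can wake the other through the same
// condition variable.
struct SharedState {
  SharedState(std::int64_t value, std::shared_ptr<std::mutex> mu,
              std::shared_ptr<std::condition_variable> cond_var)
      : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}

  std::int64_t value;  // guarded by *mu
  const std::shared_ptr<std::mutex> mu;
  const std::shared_ptr<std::condition_variable> cond_var;
};

With this bundle in place, `AddTunableParameter` no longer needs the trailing `&cond_var_` argument that the old call sites passed; the pairing of value and wakeup primitive is structural rather than a per-call-site convention.
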
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc        90
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc  86
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_iterator.cc           77
3 files changed, 130 insertions(+), 123 deletions(-)
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index b4c7f9e510..bf08970560 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -187,29 +187,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- num_parallel_calls_(params.dataset->num_parallel_calls_) {}
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
- AddTunableParameter(ctx, "parallelism",
- &num_parallel_calls_ /* value */, 1 /* min */,
- port::NumSchedulableCPUs() /* max */, &cond_var_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -221,27 +223,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override {
std::shared_ptr<BatchResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -257,7 +259,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("call_counter"), &call_counter_));
@@ -298,7 +300,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void Callback(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<BatchResult>& result,
const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
+ int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
result->UpdateStatus(status);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
@@ -334,16 +336,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void CallCompleted(const std::shared_ptr<BatchResult>& result)
- LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ LOCKS_EXCLUDED(*mu_) {
+ mutex_lock l(*mu_);
num_calls_--;
result->num_calls--;
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
const std::shared_ptr<BatchResult>& result,
- int64 offset) LOCKS_EXCLUDED(mu_) {
+ int64 offset) LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
@@ -400,7 +402,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -476,14 +478,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
RecordStart(ctx.get());
auto stop_cleanup =
gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
- new_calls.reserve(num_parallel_calls_);
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
- int64 num_parallel_calls = num_parallel_calls_;
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
int64 max_batch_results =
(num_parallel_calls + dataset()->batch_size_ - 1) /
dataset()->batch_size_;
@@ -494,10 +496,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
};
while (true) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
@@ -524,7 +526,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
- size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
@@ -569,7 +571,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadStatus(IteratorStateReader* reader, const string& prefix,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_code")), &code_int));
@@ -587,7 +589,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -628,7 +630,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
static_cast<int64>(status.code())));
@@ -642,24 +644,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the callback threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the callback threads. In particular, the runner thread should only
- // schedule new calls when the number of in-flight calls is less than the
- // user specified level of parallelism and there are slots available in
- // the `batch_results_` buffer.
- condition_variable cond_var_;
+ // schedule new calls when the number of in-flight calls is less than
+ // `num_parallel_calls_->value` and there are slots available in the
+ // `batch_results_` buffer.
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
- int64 call_counter_ GUARDED_BY(mu_) = 0;
+ int64 call_counter_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
- std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
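
Distilled from the `RunnerThread` changes above: the `busy()` predicate now reads `num_parallel_calls_->value` while holding `*mu_` and waits on `*cond_var_`. A hedged sketch of that wait, reusing the `SharedState` stand-in from the first sketch (`num_calls` and `cancelled` play the role of the iterator members of the same names; buffer bookkeeping is elided):

void WaitForCapacity(const std::shared_ptr<SharedState>& parallelism,
                     std::int64_t& num_calls, const bool& cancelled) {
  std::unique_lock<std::mutex> l(*parallelism->mu);
  // The check and the wait run under the same mutex the model takes when
  // it rewrites `value`, so a concurrent update is either observed by the
  // check or delivered as a notification; it cannot be lost.
  while (!cancelled && num_calls >= parallelism->value) {
    parallelism->cond_var->wait(l);
  }
}

This appears to be the robustness the commit message refers to: with the previous `std::atomic<int64>`, the model wrote the value and notified without holding the iterator's mutex, so a notification issued between the predicate check and the wait could in principle go unheard.
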
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 2bb38bf0b9..6b6b3d6ab9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1217,7 +1217,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- num_parallel_calls_(params.dataset->num_parallel_calls_),
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
@@ -1227,25 +1230,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
false /* low_latency_hint */)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
- AddTunableParameter(ctx, "parallelism",
- &num_parallel_calls_ /* value */, 1 /* min */,
- dataset()->cycle_length_ /* max */, &cond_var_);
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ dataset()->cycle_length_);
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
@@ -1259,12 +1261,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
std::shared_ptr<InvocationResult> result;
do {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
if (!invocation_results_.empty()) {
@@ -1274,7 +1276,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
@@ -1290,10 +1292,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -1331,7 +1333,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
@@ -1384,7 +1386,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -1401,7 +1403,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void FetchOutputs(
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
@@ -1424,14 +1426,14 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
if (end_of_input) {
current_elements_[cycle_index].reset();
}
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
element_in_use_[cycle_index] = false;
num_calls_--;
if (end_of_input) {
args_list_[cycle_index].clear();
num_open_--;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
// Method responsible for 1) creating iterators out of input elements, 2)
@@ -1442,20 +1444,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return element_in_use_[cycle_index_] ||
- num_calls_ >= num_parallel_calls_ ||
+ num_calls_ >= num_parallel_calls_->value ||
invocation_results_.size() >=
dataset()->cycle_length_ * dataset()->block_length_;
};
while (true) {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait until this thread is cancelled, the end of input has been
// reached, or the cycle element at the `cycle_index_` position is
// not in use and there is space in the `invocation_results_` queue.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
@@ -1509,13 +1511,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -1526,7 +1528,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -1553,7 +1555,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status WriteCurrentElements(IteratorStateWriter* writer)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
@@ -1572,7 +1574,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status ReadCurrentElements(IteratorContext* ctx,
IteratorStateReader* reader)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(
full_name(strings::StrCat("args_size[", idx, "]")))) {
@@ -1600,7 +1602,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the worker threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the worker threads. In particular, the runner thread should only
@@ -1608,45 +1610,45 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// user specified level of parallelism, there are slots available in the
// `invocation_results_` buffer, the current cycle element is not in use,
// and there are elements left to be fetched.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Iterator for input elements.
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);
// Identifies current cycle element.
int64 cycle_index_ = 0;
// Arguments for creating an iterator for cycle elements.
- std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_);
// Iterators for the current cycle elements. Concurrent access is
// protected by `element_in_use_`.
std::vector<std::unique_ptr<IteratorBase>> current_elements_;
// Identifies cycle elements that are in use by worker threads.
- std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+ std::vector<bool> element_in_use_ GUARDED_BY(*mu_);
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
+ GUARDED_BY(*mu_);
// Identifies whether end of input has been reached.
- bool end_of_input_ GUARDED_BY(mu_) = false;
+ bool end_of_input_ GUARDED_BY(*mu_) = false;
// Identifies the number of open iterators.
- int64 num_open_ GUARDED_BY(mu_) = 0;
+ int64 num_open_ GUARDED_BY(*mu_) = 0;
// Identifies the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled.
- bool cancelled_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
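
The model side of the handshake is not part of this diff, but the intended interaction follows from the shared primitives: to retune the parallelism, the optimizer locks the shared mutex, rewrites `value`, and broadcasts, which wakes the runner-thread wait sketched above. A hypothetical helper (not a TensorFlow API) showing that update:

void SetParallelism(const std::shared_ptr<SharedState>& state,
                    std::int64_t new_value) {
  {
    std::lock_guard<std::mutex> l(*state->mu);
    state->value = new_value;  // same lock the waiting thread's check holds
  }
  state->cond_var->notify_all();  // wake threads blocked on the old value
}
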
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index da067a4e6f..13bd4b6036 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -40,30 +40,32 @@ class ParallelMapIterator : public DatasetBaseIterator {
input_dataset_(input_dataset),
init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
- num_parallel_calls_(num_parallel_calls) {}
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ num_parallel_calls, mu_, cond_var_)) {}
~ParallelMapIterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
// use it here for the maximum.
- AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
- 1 /* min */, port::NumSchedulableCPUs() /* max */,
- &cond_var_);
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -77,16 +79,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
@@ -96,10 +98,10 @@ class ParallelMapIterator : public DatasetBaseIterator {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -128,7 +130,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
@@ -175,7 +177,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -185,18 +187,18 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void CallCompleted(const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
num_calls_--;
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
result->notification.Notify();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
@@ -239,18 +241,18 @@ class ParallelMapIterator : public DatasetBaseIterator {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(num_parallel_calls_);
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
- int64 num_parallel_calls = num_parallel_calls_;
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
@@ -261,7 +263,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
@@ -271,7 +273,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -282,7 +285,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -312,23 +315,23 @@ class ParallelMapIterator : public DatasetBaseIterator {
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
// Used for coordination between the main thread and the runner thread.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
} // namespace
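
One pattern shared by all three iterators, unchanged in substance but now expressed through the shared primitives, is the shutdown handshake: the destructor sets `cancelled_`, broadcasts on `*cond_var_`, and then blocks until `num_calls_` drains to zero so no in-flight callback can outlive the iterator. A self-contained sketch of that handshake with standard C++ stand-ins for the TensorFlow types:

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>

class Worker {
 public:
  ~Worker() {
    std::unique_lock<std::mutex> l(*mu_);
    cancelled_ = true;        // the runner loop checks this and exits
    cond_var_->notify_all();  // wake the runner and any blocked waiters
    while (num_calls_ > 0) {  // each completing call decrements and notifies
      cond_var_->wait(l);
    }
  }

 private:
  const std::shared_ptr<std::mutex> mu_ = std::make_shared<std::mutex>();
  const std::shared_ptr<std::condition_variable> cond_var_ =
      std::make_shared<std::condition_variable>();
  std::int64_t num_calls_ = 0;  // in-flight calls, guarded by *mu_
  bool cancelled_ = false;      // guarded by *mu_
};
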