author    Derek Murray <mrry@google.com>  2018-09-20 13:36:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-20 13:45:10 -0700
commit    88cfc00ad2a33ef1440d8474fa830bce44c13056 (patch)
tree      91e020af6c2beb2b71b8a0d2130b4b7989d1b5f5
parent    c367ba02acc1d292738e3213173acbc0fe04335e (diff)
[tf.data] Fixes for two recently introduced use-after-free bugs.
1. In ParallelMapIterator, do not call `cond_var_.notify_all()` without holding the associated mutex. In some cases, the iterator may have been deleted between releasing the lock and notifying the condition variable, which leads to a use-after-free. This change applies this style to all uses of condition variables in tensorflow/core/kernels/data/.

2. In CapturedFunction::RunAsync(), do not use `shared_ptr` to manage the lifetime of objects that (potentially) borrow from runtime objects. The present code runs the destructors after the `done()` callback is called, but the `done()` callback may be the last action in a session and may thus trigger destruction of the objects borrowed from. In that case, the `shared_ptr` destructors may use the borrowed objects after they are freed.

PiperOrigin-RevId: 213872829
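To make fix (1) concrete, here is a minimal standalone sketch of the corrected pattern, written against standard C++ primitives rather than TensorFlow's mutex_lock/condition_variable wrappers; the class and member names are illustrative, not taken from the patch:

#include <condition_variable>
#include <mutex>

class PendingCalls {
 public:
  // Buggy ordering (what the patch removes):
  //   { std::lock_guard<std::mutex> l(mu_); num_calls_--; }  // lock released...
  //   cond_var_.notify_all();                                // ...so a waiter may
  // wake (spuriously or via another notifier), observe num_calls_ == 0, return,
  // and destroy *this -- making this notify_all() touch a freed object.
  //
  // Fixed ordering: notify while the mutex is still held, so no waiter can
  // observe the new state and tear the object down before we are done with it.
  void CallCompleted() {
    std::lock_guard<std::mutex> l(mu_);
    num_calls_--;
    cond_var_.notify_all();  // Safe: *this is pinned while we hold mu_.
  }

  void WaitForAllCalls() {
    std::unique_lock<std::mutex> l(mu_);
    cond_var_.wait(l, [this] { return num_calls_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cond_var_;
  int num_calls_ = 1;
};

Every hunk below that moves a `cond_var_.notify_all()` inside the enclosing lock scope is an instance of this same reordering.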
-rw-r--r--  tensorflow/core/kernels/data/captured_function.cc                37
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc          6
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc  130
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_iterator.cc             8
4 files changed, 87 insertions, 94 deletions
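Fix (2) appears in the captured_function.cc hunks below. The following is a minimal sketch of the hazard and the corrected ordering, using hypothetical stand-in types (Runtime, StepContainer) rather than TensorFlow's ResourceMgr/ScopedStepContainer. The key point: objects that borrow from the runtime must be deleted before `done()` can tear the runtime down, and `shared_ptr`s bound into a `std::bind` callback cannot guarantee that, because bound arguments are destroyed only when the callable itself is destroyed, after the invocation (and thus `done()`) has returned.

#include <functional>
#include <iostream>

// Stand-in for a runtime-owned object (e.g. a resource manager).
struct Runtime {
  bool alive = true;
};

// Stand-in for an object that borrows from the runtime and touches it on
// destruction, like ScopedStepContainer cleaning up via the ResourceMgr.
struct StepContainer {
  explicit StepContainer(Runtime* rt) : rt_(rt) {}
  ~StepContainer() {
    // Would be a use-after-free if the runtime were already destroyed.
    std::cout << "cleanup, runtime alive: " << rt_->alive << "\n";
  }
  Runtime* rt_;
};

using DoneCallback = std::function<void()>;

// Corrected ordering: the callback owns raw pointers and deletes them
// explicitly *before* invoking done().
void RunAsync(Runtime* rt, DoneCallback done) {
  StepContainer* step_container = new StepContainer(rt);
  auto callback = [step_container, done]() {
    delete step_container;  // Borrowed-from objects are still alive here.
    done();                 // done() may now be the last action in the session.
  };
  callback();  // Stands in for the asynchronous completion of the function.
}

int main() {
  Runtime rt;
  RunAsync(&rt, [&rt] { rt.alive = false; });  // done() "tears down" the runtime.
  return 0;
}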
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 8a5d30a27c..b5f4072e89 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -427,17 +427,17 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
done(s);
return;
}
- std::shared_ptr<OwnedArgsCallFrame> frame(
- new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_));
+ OwnedArgsCallFrame* frame =
+ new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- std::shared_ptr<ScopedStepContainer> step_container(new ScopedStepContainer(
+ ScopedStepContainer* step_container = new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
- }));
- f_opts.step_container = step_container.get();
+ });
+ f_opts.step_container = step_container;
f_opts.runner = ctx->runner();
if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
f_opts.create_rendezvous = true;
@@ -448,8 +448,8 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- std::shared_ptr<CancellationManager> c_mgr(new CancellationManager);
- f_opts.cancellation_manager = c_mgr.get();
+ CancellationManager* c_mgr = new CancellationManager;
+ f_opts.cancellation_manager = c_mgr;
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
std::shared_ptr<model::Node> node;
if (ctx->model()) {
@@ -460,19 +460,19 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
}
f_opts.stats_collector = stats_collector.get();
- OwnedArgsCallFrame* raw_frame = frame.get();
auto callback = std::bind(
- [rets](const std::shared_ptr<CancellationManager>& c_mgr,
- const FunctionLibraryRuntime::DoneCallback& done,
- const std::shared_ptr<OwnedArgsCallFrame>& frame,
- const std::shared_ptr<model::Node>& node,
- const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
- const std::shared_ptr<ScopedStepContainer>& step_container,
- // Begin unbound arguments.
- Status s) {
+ [rets, step_container, c_mgr, frame](
+ const FunctionLibraryRuntime::DoneCallback& done,
+ const std::shared_ptr<model::Node>& node,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
+ // Begin unbound arguments.
+ Status s) {
+ delete step_container;
+ delete c_mgr;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
}
+ delete frame;
if (node) {
node->add_processing_time(stats_collector->processing_time());
node->start_work();
@@ -482,11 +482,10 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
node->stop_work();
}
},
- std::move(c_mgr), std::move(done), std::move(frame), std::move(node),
- std::move(stats_collector), std::move(step_container),
+ std::move(done), std::move(node), std::move(stats_collector),
std::placeholders::_1);
- ctx->lib()->Run(f_opts, handle, raw_frame, std::move(callback));
+ ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 83896219a3..fb022ddf12 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -206,9 +206,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
{
mutex_lock l(mu_);
num_parallel_calls_ = value;
+ cond_var_.notify_all();
}
VLOG(2) << "setting parallelism knob to " << value;
- cond_var_.notify_all();
};
AddTunableParameter(
ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
@@ -236,8 +236,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
@@ -340,11 +340,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(mu_) {
- {
mutex_lock l(mu_);
num_calls_--;
result->num_calls--;
- }
cond_var_.notify_all();
}
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 9cd46bf5dd..3dac7902f0 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1241,9 +1241,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
{
mutex_lock l(mu_);
num_parallel_calls_ = value;
+ cond_var_.notify_all();
}
VLOG(2) << "setting parallelism knob to " << value;
- cond_var_.notify_all();
};
AddTunableParameter(
ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
@@ -1278,8 +1278,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
StopWork(ctx);
result->notification.WaitForNotification();
StartWork(ctx);
@@ -1425,17 +1425,15 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Release the ownership of the cycle element iterator, closing the
// iterator if end of input was encountered.
- {
- if (end_of_input) {
- current_elements_[cycle_index].reset();
- }
- mutex_lock l(mu_);
- element_in_use_[cycle_index] = false;
- num_calls_--;
- if (end_of_input) {
- args_list_[cycle_index].clear();
- num_open_--;
- }
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
}
cond_var_.notify_all();
}
@@ -1453,32 +1451,44 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
StartWork(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
while (true) {
- {
- mutex_lock l(mu_);
- // Wait until this thread is cancelled, the end of input has been
- // reached, or the cycle element at the `cycle_index_` position is
- // not in use and there is space in the `invocation_results_` queue.
- while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
- (element_in_use_[cycle_index_] ||
- num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- StopWork(ctx.get());
- cond_var_.wait(l);
- StartWork(ctx.get());
- }
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
+ (element_in_use_[cycle_index_] ||
+ num_calls_ >= num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ StopWork(ctx.get());
+ cond_var_.wait(l);
+ StartWork(ctx.get());
+ }
- if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
- return;
- }
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
- while (!element_in_use_[cycle_index_] &&
- (!end_of_input_ || num_open_ > 0) &&
- num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
- if (!current_elements_[cycle_index_]) {
- // Try to create a new iterator from the next input element.
- Status status = input_impl_->GetNext(
- ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ while (!element_in_use_[cycle_index_] &&
+ (!end_of_input_ || num_open_ > 0) &&
+ num_calls_ < num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
@@ -1487,39 +1497,25 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
result->notification.Notify();
break;
}
- if (!end_of_input_) {
- Status status = MakeIteratorFromInputElement(
- ctx.get(), args_list_[cycle_index_], cycle_index_,
- dataset()->captured_func_.get(), prefix(),
- &current_elements_[cycle_index_]);
- if (!status.ok()) {
- invocation_results_.emplace_back(new InvocationResult());
- std::shared_ptr<InvocationResult>& result =
- invocation_results_.back();
- result->status.Update(status);
- result->notification.Notify();
- break;
- }
- ++num_open_;
- }
+ ++num_open_;
}
- if (current_elements_[cycle_index_]) {
- // Pre-allocate invocation results for outputs to be fetched
- // and then fetch the outputs asynchronously.
- std::vector<std::shared_ptr<InvocationResult>> results;
- results.reserve(dataset()->block_length_);
- for (int i = 0; i < dataset()->block_length_; ++i) {
- invocation_results_.emplace_back(new InvocationResult());
- results.push_back(invocation_results_.back());
- }
- num_calls_++;
- element_in_use_[cycle_index_] = true;
- thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
- ctx, cycle_index_,
- std::move(results)));
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
}
- cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
}
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
cond_var_.notify_all();
}
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 5f6052ce83..20ac518f37 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -63,9 +63,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
{
mutex_lock l(mu_);
num_parallel_calls_ = value;
+ cond_var_.notify_all();
}
VLOG(2) << "setting parallelism knob to " << value;
- cond_var_.notify_all();
};
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
// use it here for the maximum.
@@ -96,8 +96,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
StopWork(ctx);
result->notification.WaitForNotification();
StartWork(ctx);
@@ -201,9 +201,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
{
mutex_lock l(mu_);
num_calls_--;
+ cond_var_.notify_all();
}
result->notification.Notify();
- cond_var_.notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
@@ -275,8 +275,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}