diff options
author | Derek Murray <mrry@google.com> | 2018-09-20 13:36:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 13:45:10 -0700 |
commit | 88cfc00ad2a33ef1440d8474fa830bce44c13056 (patch) | |
tree | 91e020af6c2beb2b71b8a0d2130b4b7989d1b5f5 | |
parent | c367ba02acc1d292738e3213173acbc0fe04335e (diff) |
[tf.data] Fixes for two recently introduced use-after-free bugs.
1. In ParallelMapIterator, do not call `cond_var_.notify_all()` without holding
the associated mutex. In some cases, the iterator may have been deleted
between releasing the lock and notifying the condition variable, which
leads to a use-after-free. This change applies this style to all use of
condition variables in tensorflow/core/kernels/data/.
2. In CapturedFunction::RunAsync(), do not use `shared_ptr` to manage
the lifetime of objects that (potentially) borrow from runtime
objects. The present code runs the destructor after the `done()`
callback is called, but the `done()` callback may be the last
action in a session, and thus trigger destruction of those borrowed
objects. In that case, the `shared_ptr` destructor may use the
borrowed objects after they are freed.
PiperOrigin-RevId: 213872829
4 files changed, 87 insertions, 94 deletions
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 8a5d30a27c..b5f4072e89 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -427,17 +427,17 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, done(s); return; } - std::shared_ptr<OwnedArgsCallFrame> frame( - new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_)); + OwnedArgsCallFrame* frame = + new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_); FunctionLibraryRuntime::Options f_opts; f_opts.step_id = CapturedFunction::generate_step_id(); ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager(); - std::shared_ptr<ScopedStepContainer> step_container(new ScopedStepContainer( + ScopedStepContainer* step_container = new ScopedStepContainer( f_opts.step_id, [resource_mgr](const string& name) { resource_mgr->Cleanup(name).IgnoreError(); - })); - f_opts.step_container = step_container.get(); + }); + f_opts.step_container = step_container; f_opts.runner = ctx->runner(); if (ctx->lib()->device()->device_type() != DEVICE_CPU) { f_opts.create_rendezvous = true; @@ -448,8 +448,8 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, // (such as queue kernels) that depend on the non-nullness of // `OpKernelContext::cancellation_manager()`, but additional effort // will be required to plumb it through the `IteratorContext`. - std::shared_ptr<CancellationManager> c_mgr(new CancellationManager); - f_opts.cancellation_manager = c_mgr.get(); + CancellationManager* c_mgr = new CancellationManager; + f_opts.cancellation_manager = c_mgr; std::shared_ptr<SimpleStepStatsCollector> stats_collector; std::shared_ptr<model::Node> node; if (ctx->model()) { @@ -460,19 +460,19 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, } f_opts.stats_collector = stats_collector.get(); - OwnedArgsCallFrame* raw_frame = frame.get(); auto callback = std::bind( - [rets](const std::shared_ptr<CancellationManager>& c_mgr, - const FunctionLibraryRuntime::DoneCallback& done, - const std::shared_ptr<OwnedArgsCallFrame>& frame, - const std::shared_ptr<model::Node>& node, - const std::shared_ptr<SimpleStepStatsCollector>& stats_collector, - const std::shared_ptr<ScopedStepContainer>& step_container, - // Begin unbound arguments. - Status s) { + [rets, step_container, c_mgr, frame]( + const FunctionLibraryRuntime::DoneCallback& done, + const std::shared_ptr<model::Node>& node, + const std::shared_ptr<SimpleStepStatsCollector>& stats_collector, + // Begin unbound arguments. + Status s) { + delete step_container; + delete c_mgr; if (s.ok()) { s = frame->ConsumeRetvals(rets); } + delete frame; if (node) { node->add_processing_time(stats_collector->processing_time()); node->start_work(); @@ -482,11 +482,10 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, node->stop_work(); } }, - std::move(c_mgr), std::move(done), std::move(frame), std::move(node), - std::move(stats_collector), std::move(step_container), + std::move(done), std::move(node), std::move(stats_collector), std::placeholders::_1); - ctx->lib()->Run(f_opts, handle, raw_frame, std::move(callback)); + ctx->lib()->Run(f_opts, handle, frame, std::move(callback)); } CapturedFunction::CapturedFunction(const NameAttrList& func, diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 83896219a3..fb022ddf12 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -206,9 +206,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { { mutex_lock l(mu_); num_parallel_calls_ = value; + cond_var_.notify_all(); } VLOG(2) << "setting parallelism knob to " << value; - cond_var_.notify_all(); }; AddTunableParameter( ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */, @@ -236,8 +236,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } std::swap(result, batch_results_.front()); batch_results_.pop_front(); + cond_var_.notify_all(); } - cond_var_.notify_all(); return ProcessResult(ctx, result, out_tensors, end_of_sequence); } @@ -340,11 +340,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void CallCompleted(const std::shared_ptr<BatchResult>& result) LOCKS_EXCLUDED(mu_) { - { mutex_lock l(mu_); num_calls_--; result->num_calls--; - } cond_var_.notify_all(); } diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 9cd46bf5dd..3dac7902f0 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -1241,9 +1241,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { { mutex_lock l(mu_); num_parallel_calls_ = value; + cond_var_.notify_all(); } VLOG(2) << "setting parallelism knob to " << value; - cond_var_.notify_all(); }; AddTunableParameter( ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */, @@ -1278,8 +1278,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } + cond_var_.notify_all(); } - cond_var_.notify_all(); StopWork(ctx); result->notification.WaitForNotification(); StartWork(ctx); @@ -1425,17 +1425,15 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // Release the ownership of the cycle element iterator, closing the // iterator if end of input was encountered. - { - if (end_of_input) { - current_elements_[cycle_index].reset(); - } - mutex_lock l(mu_); - element_in_use_[cycle_index] = false; - num_calls_--; - if (end_of_input) { - args_list_[cycle_index].clear(); - num_open_--; - } + if (end_of_input) { + current_elements_[cycle_index].reset(); + } + mutex_lock l(mu_); + element_in_use_[cycle_index] = false; + num_calls_--; + if (end_of_input) { + args_list_[cycle_index].clear(); + num_open_--; } cond_var_.notify_all(); } @@ -1453,32 +1451,44 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { StartWork(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); while (true) { - { - mutex_lock l(mu_); - // Wait until this thread is cancelled, the end of input has been - // reached, or the cycle element at the `cycle_index_` position is - // not in use and there is space in the `invocation_results_` queue. - while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && - (element_in_use_[cycle_index_] || - num_calls_ >= num_parallel_calls_ || - invocation_results_.size() >= MaxInvocationResults())) { - StopWork(ctx.get()); - cond_var_.wait(l); - StartWork(ctx.get()); - } + mutex_lock l(mu_); + // Wait until this thread is cancelled, the end of input has been + // reached, or the cycle element at the `cycle_index_` position is + // not in use and there is space in the `invocation_results_` queue. + while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && + (element_in_use_[cycle_index_] || + num_calls_ >= num_parallel_calls_ || + invocation_results_.size() >= MaxInvocationResults())) { + StopWork(ctx.get()); + cond_var_.wait(l); + StartWork(ctx.get()); + } - if (cancelled_ || (end_of_input_ && num_open_ == 0)) { - return; - } + if (cancelled_ || (end_of_input_ && num_open_ == 0)) { + return; + } - while (!element_in_use_[cycle_index_] && - (!end_of_input_ || num_open_ > 0) && - num_calls_ < num_parallel_calls_ && - invocation_results_.size() < MaxInvocationResults()) { - if (!current_elements_[cycle_index_]) { - // Try to create a new iterator from the next input element. - Status status = input_impl_->GetNext( - ctx.get(), &args_list_[cycle_index_], &end_of_input_); + while (!element_in_use_[cycle_index_] && + (!end_of_input_ || num_open_ > 0) && + num_calls_ < num_parallel_calls_ && + invocation_results_.size() < MaxInvocationResults()) { + if (!current_elements_[cycle_index_]) { + // Try to create a new iterator from the next input element. + Status status = input_impl_->GetNext( + ctx.get(), &args_list_[cycle_index_], &end_of_input_); + if (!status.ok()) { + invocation_results_.emplace_back(new InvocationResult()); + std::shared_ptr<InvocationResult>& result = + invocation_results_.back(); + result->status.Update(status); + result->notification.Notify(); + break; + } + if (!end_of_input_) { + Status status = MakeIteratorFromInputElement( + ctx.get(), args_list_[cycle_index_], cycle_index_, + dataset()->captured_func_.get(), prefix(), + ¤t_elements_[cycle_index_]); if (!status.ok()) { invocation_results_.emplace_back(new InvocationResult()); std::shared_ptr<InvocationResult>& result = @@ -1487,39 +1497,25 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { result->notification.Notify(); break; } - if (!end_of_input_) { - Status status = MakeIteratorFromInputElement( - ctx.get(), args_list_[cycle_index_], cycle_index_, - dataset()->captured_func_.get(), prefix(), - ¤t_elements_[cycle_index_]); - if (!status.ok()) { - invocation_results_.emplace_back(new InvocationResult()); - std::shared_ptr<InvocationResult>& result = - invocation_results_.back(); - result->status.Update(status); - result->notification.Notify(); - break; - } - ++num_open_; - } + ++num_open_; } - if (current_elements_[cycle_index_]) { - // Pre-allocate invocation results for outputs to be fetched - // and then fetch the outputs asynchronously. - std::vector<std::shared_ptr<InvocationResult>> results; - results.reserve(dataset()->block_length_); - for (int i = 0; i < dataset()->block_length_; ++i) { - invocation_results_.emplace_back(new InvocationResult()); - results.push_back(invocation_results_.back()); - } - num_calls_++; - element_in_use_[cycle_index_] = true; - thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this, - ctx, cycle_index_, - std::move(results))); + } + if (current_elements_[cycle_index_]) { + // Pre-allocate invocation results for outputs to be fetched + // and then fetch the outputs asynchronously. + std::vector<std::shared_ptr<InvocationResult>> results; + results.reserve(dataset()->block_length_); + for (int i = 0; i < dataset()->block_length_; ++i) { + invocation_results_.emplace_back(new InvocationResult()); + results.push_back(invocation_results_.back()); } - cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; + num_calls_++; + element_in_use_[cycle_index_] = true; + thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this, + ctx, cycle_index_, + std::move(results))); } + cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; } cond_var_.notify_all(); } diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 5f6052ce83..20ac518f37 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -63,9 +63,9 @@ class ParallelMapIterator : public DatasetBaseIterator { { mutex_lock l(mu_); num_parallel_calls_ = value; + cond_var_.notify_all(); } VLOG(2) << "setting parallelism knob to " << value; - cond_var_.notify_all(); }; // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and // use it here for the maximum. @@ -96,8 +96,8 @@ class ParallelMapIterator : public DatasetBaseIterator { } std::swap(result, invocation_results_.front()); invocation_results_.pop_front(); + cond_var_.notify_all(); } - cond_var_.notify_all(); StopWork(ctx); result->notification.WaitForNotification(); StartWork(ctx); @@ -201,9 +201,9 @@ class ParallelMapIterator : public DatasetBaseIterator { { mutex_lock l(mu_); num_calls_--; + cond_var_.notify_all(); } result->notification.Notify(); - cond_var_.notify_all(); } void CallFunction(const std::shared_ptr<IteratorContext>& ctx, @@ -275,8 +275,8 @@ class ParallelMapIterator : public DatasetBaseIterator { new_calls.push_back(invocation_results_.back()); num_calls_++; } + cond_var_.notify_all(); } - cond_var_.notify_all(); for (const auto& call : new_calls) { CallFunction(ctx, call); } |