diff options
author | 2016-12-13 13:46:55 -0800 | |
---|---|---|
committer | 2016-12-13 14:02:52 -0800 | |
commit | 2b771b9fc34190d0acea439f9a930a078e54c37a (patch) | |
tree | 8d9ddf664639e56b7a2113b5be8a3909bbc70b91 | |
parent | 63fbdc8ec9560e6b40d8a55a3a0ad279f2269709 (diff) |
Fix use-after-free in the DirectSession timeout handling code.
Change: 141934447
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session_test.cc | 53 |
2 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 0950ca8d8b..e1e794badb 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1211,6 +1211,10 @@ void DirectSession::WaitForNotification(RunState* run_state, run_state->status.Update(status); } cm->StartCancel(); + // We must wait for the executors to complete, because they have borrowed + // references to `cm` and other per-step state. After this notification, it + // is safe to clean up the step. + run_state->executors_done.WaitForNotification(); } } diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 4b0165bae7..d25e15f414 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -777,6 +777,59 @@ TEST(DirectSessionTest, TimeoutSession) { session->Close(); } +// Accesses the cancellation manager for the step after the step has been +// cancelled. +class CancellationMgrPollingOp : public OpKernel { + public: + explicit CancellationMgrPollingOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + CancellationManager* cm = ctx->cancellation_manager(); + while (!cm->IsCancelled()) { + ctx->env()->SleepForMicroseconds(1000); + } + notification.Notify(); + } + static Notification notification; +}; +Notification CancellationMgrPollingOp::notification; + +REGISTER_KERNEL_BUILDER(Name("CancellationMgrPollingOp").Device(DEVICE_CPU), + CancellationMgrPollingOp); +REGISTER_OP("CancellationMgrPollingOp").Doc(""); + +TEST(DirectSessionTest, TestTimeoutCleanShutdown) { + GraphDef graph; + // Creates a graph with one FIFOQueue and one dequeue op. + protobuf::TextFormat::ParseFromString(R"proto( + node { + name: 'cm_polling' + op: 'CancellationMgrPollingOp' + device: '/device:CPU:0' + } + versions { + producer: 9 + } + )proto", + &graph); + + // Creates a session with operation_timeout_in_ms set to 100 milliseconds. + SessionOptions options; + options.config.set_operation_timeout_in_ms(100); + std::unique_ptr<Session> session(NewSession(options)); + ASSERT_TRUE(session != nullptr); + TF_ASSERT_OK(session->Create(graph)); + + // Verifies that the error code is DEADLINE_EXCEEDED. + Status s = session->Run({}, {}, {"cm_polling"}, nullptr); + ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code()); + + // Verify that the op ran to completion. + ASSERT_TRUE(CancellationMgrPollingOp::notification.HasBeenNotified()); + + session->Close(); +} + class BlockingOpState { public: void AwaitState(int awaiting_state) { |