aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-12-13 13:46:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-13 14:02:52 -0800
commit2b771b9fc34190d0acea439f9a930a078e54c37a (patch)
tree8d9ddf664639e56b7a2113b5be8a3909bbc70b91
parent63fbdc8ec9560e6b40d8a55a3a0ad279f2269709 (diff)
Fix use-after-free in the DirectSession timeout handling code.
Change: 141934447
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc4
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc53
2 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0950ca8d8b..e1e794badb 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1211,6 +1211,10 @@ void DirectSession::WaitForNotification(RunState* run_state,
run_state->status.Update(status);
}
cm->StartCancel();
+ // We must wait for the executors to complete, because they have borrowed
+ // references to `cm` and other per-step state. After this notification, it
+ // is safe to clean up the step.
+ run_state->executors_done.WaitForNotification();
}
}
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 4b0165bae7..d25e15f414 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -777,6 +777,59 @@ TEST(DirectSessionTest, TimeoutSession) {
session->Close();
}
+// Test kernel that polls the step's cancellation manager until it reports
+// cancellation, so the manager is still accessed after the step is cancelled.
+class CancellationMgrPollingOp : public OpKernel {
+ public:
+ explicit CancellationMgrPollingOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) override {
+ CancellationManager* cm = ctx->cancellation_manager();
+ while (!cm->IsCancelled()) {  // Loops until cancellation is observed.
+ ctx->env()->SleepForMicroseconds(1000);  // Re-check every 1000 us.
+ }
+ notification.Notify();  // Signals that Compute() got past the polling loop.
+ }
+ static Notification notification;  // Shared by all instances; read by the test.
+};
+Notification CancellationMgrPollingOp::notification;  // Static member definition.
+
+REGISTER_KERNEL_BUILDER(Name("CancellationMgrPollingOp").Device(DEVICE_CPU),
+ CancellationMgrPollingOp);  // CPU kernel for the test-only op.
+REGISTER_OP("CancellationMgrPollingOp").Doc("");  // Op registered with an empty doc string.
+
+TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
+ GraphDef graph;
+ // Creates a graph with a single CancellationMgrPollingOp node.
+ protobuf::TextFormat::ParseFromString(R"proto(
+ node {
+ name: 'cm_polling'
+ op: 'CancellationMgrPollingOp'
+ device: '/device:CPU:0'
+ }
+ versions {
+ producer: 9
+ }
+ )proto",
+ &graph);
+
+ // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
+ SessionOptions options;
+ options.config.set_operation_timeout_in_ms(100);
+ std::unique_ptr<Session> session(NewSession(options));
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(graph));
+
+ // The op only returns once cancelled; verifies the code is DEADLINE_EXCEEDED.
+ Status s = session->Run({}, {}, {"cm_polling"}, nullptr);
+ ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
+
+ // Verifies that the op had already run to completion when Run() returned.
+ ASSERT_TRUE(CancellationMgrPollingOp::notification.HasBeenNotified());
+
+ session->Close();
+}
+
class BlockingOpState {
public:
void AwaitState(int awaiting_state) {