diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2016-10-27 14:12:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-27 15:19:44 -0700 |
commit | 71b993a63c9f4c62d45303623f926219066902cc (patch) | |
tree | 9d290c67c0b4e9bad47e8afad81db74a58f536aa /tensorflow/cc/training | |
parent | 44546e1e4e87b8127334dec8b0066c7df6b3f037 (diff) |
Add Stop() in C++ QueueRunner.
Change: 137447384
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 39 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 6 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 205 |
3 files changed, 191 insertions, 59 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 585ee15872..79d306f367 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -54,7 +54,8 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { } QueueRunner::~QueueRunner() { - should_stop_ = true; + // Cannot run Stop() here because the session might already be closed or + // destroyed. Join(); } @@ -72,6 +73,15 @@ Status QueueRunner::Start(Session* sess) { return Status::OK(); } +Status QueueRunner::Stop(Session* sess) { + should_stop_ = true; + if (cancel_op_name_.empty()) { + return Status::OK(); + } else { + return sess->Run({}, {}, {cancel_op_name_}, nullptr); + } +} + Status QueueRunner::Join() { thread_pool_.reset(); started_ = false; @@ -81,8 +91,7 @@ Status QueueRunner::Join() { void QueueRunner::Run(Session* sess, const string& enqueue_op) { bool decremented = false; while (!should_stop_.load()) { - std::vector<Tensor> outputs; - auto status = sess->Run({}, {}, {enqueue_op}, &outputs); + auto status = sess->Run({}, {}, {enqueue_op}, nullptr); if (status.ok()) { continue; } else if (queue_closed_exception_types_.count( @@ -94,19 +103,25 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { // If all enqueue ops have finished, run the close op. if (runs_ == 0 && !close_op_name_.empty()) { - std::vector<Tensor> outputs; - auto s = sess->Run({}, {}, {close_op_name_}, &outputs); - if (!s.ok()) { - status_ = status; + auto s = sess->Run({}, {}, {close_op_name_}, nullptr); + if (!s.ok() && status_.ok() && + queue_closed_exception_types_.count(static_cast<int>(s.code())) == + 0) { + status_ = s; } } } else { - mutex_lock l(mu_); - should_stop_ = true; - // Only record the first failure status. - if (status_.ok()) { - status_ = status; + { + mutex_lock l(mu_); + should_stop_ = true; + // Only record the first failure status. + if (status_.ok()) { + status_ = status; + } } + // Stop the queue runner immediately to propagate the error to + // subsequent queues. + Stop(sess); } } diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 09d8d49821..c3fe4026ef 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <unordered_set> #include <vector> + #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -49,6 +50,9 @@ class QueueRunner { // Starts the queue runner with the given session. Status Start(Session* sess); + // Requests to stop and runs the cancel op. + Status Stop(Session* sess); + // Joins all the threads. Returns okay if all threads run successfully; // otherwise returns the first captured failure status. Status Join(); @@ -60,7 +64,6 @@ class QueueRunner { string queue_name_; std::vector<string> enqueue_op_names_; string close_op_name_; - // The cancel op is not being called currently. string cancel_op_name_; // code::Code casted to int to avoid a hash function. std::unordered_set<int> queue_closed_exception_types_; @@ -68,6 +71,7 @@ class QueueRunner { std::unique_ptr<thread::ThreadPool> thread_pool_; std::atomic<bool> should_stop_; std::atomic<bool> started_; + condition_variable wait_to_close_; mutex mu_; // TODO(yuefengz): implement c++ coordinator. int runs_ = 0; diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 8719677274..29165778c5 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/cc/training/queue_runner.h" + #include <string> #include <vector> + #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph.pb.h" @@ -23,39 +25,42 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/public/session.h" +namespace tensorflow { namespace { -using ::tensorflow::DataType; -using ::tensorflow::error::Code; -using ::tensorflow::GraphDef; -using ::tensorflow::ops::Assign; -using ::tensorflow::ops::Const; -using ::tensorflow::ops::CountUpTo; -using ::tensorflow::ops::FIFOQueue; -using ::tensorflow::ops::InputList; -using ::tensorflow::ops::QueueClose; -using ::tensorflow::ops::QueueDequeue; -using ::tensorflow::ops::QueueEnqueue; -using ::tensorflow::ops::Square; -using ::tensorflow::ops::Variable; -using ::tensorflow::QueueRunner; -using ::tensorflow::QueueRunnerDef; -using ::tensorflow::Scope; -using ::tensorflow::Session; -using ::tensorflow::SessionOptions; -using ::tensorflow::Tensor; -using ::tensorflow::TensorShape; +using error::Code; +using ops::Assign; +using ops::Const; +using ops::CountUpTo; +using ops::FIFOQueue; +using ops::QueueClose; +using ops::QueueDequeue; +using ops::QueueEnqueue; +using ops::Square; +using ops::Variable; constexpr char kAssignOpName[] = "assign"; +constexpr char kCancelOp0[] = "cancel0"; +constexpr char kCancelOp1[] = "cancel1"; +constexpr char kCloseOp0[] = "close0"; +constexpr char kCloseOp1[] = "close1"; constexpr char kCountUpToOpName[] = "count"; +constexpr char kDequeueOp0[] = "dequeue0"; +constexpr char kDequeueOp1[] = "dequeue1"; +constexpr char kEnqueueOp0[] = "enqueue0"; +constexpr char kEnqueueOp1[] = "enqueue1"; constexpr char kIllegalOpName1[] = "would fail"; constexpr char kIllegalOpName2[] = "fail again"; constexpr char kQueueName[] = "unit_test"; +constexpr char kQueueName0[] = "q0"; +constexpr char kQueueName1[] = "q1"; constexpr char kSquareOpName[] = "square"; constexpr char kVarOpName[] = "var"; @@ -75,7 +80,7 @@ GraphDef BuildSimpleGraph() { QueueRunnerDef BuildQueueRunnerDef( const std::string& queue_name, const std::vector<std::string>& enqueue_ops, - const std::string& close_op, + const std::string& close_op, const std::string& cancel_op, const std::vector<Code>& queue_closed_error_codes) { QueueRunnerDef queue_runner_def; *queue_runner_def.mutable_queue_name() = kQueueName; @@ -83,6 +88,7 @@ QueueRunnerDef BuildQueueRunnerDef( *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op; } *queue_runner_def.mutable_close_op_name() = close_op; + *queue_runner_def.mutable_cancel_op_name() = cancel_op; for (const auto& error_code : queue_closed_error_codes) { *queue_runner_def.mutable_queue_closed_exception_types()->Add() = error_code; @@ -96,8 +102,7 @@ std::unique_ptr<Session> BuildSessionAndInitVariable( std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); - std::vector<Tensor> nothing; - TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, ¬hing)); + TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr)); return session; } @@ -106,7 +111,7 @@ TEST(QueueRunnerTest, BasicTest) { auto session = BuildSessionAndInitVariable(graph_def); QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, {}); + kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {}); QueueRunner qr(queue_runner_def); qr.Start(session.get()); @@ -123,7 +128,7 @@ TEST(QueueRunnerTest, QueueClosedCode) { auto session = BuildSessionAndInitVariable(graph_def); QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, + BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "", {Code::OUT_OF_RANGE, Code::CANCELLED}); QueueRunner qr(queue_runner_def); @@ -141,60 +146,167 @@ TEST(QueueRunnerDef, CatchErrorInJoin) { auto session = BuildSessionAndInitVariable(graph_def); QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, {}); + kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {}); QueueRunner qr(queue_runner_def); qr.Start(session.get()); EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND); } -TEST(QueueRunnerTest, RealEnqueueDequeue) { +GraphDef BuildDoubleQueueGraph() { Scope root = Scope::NewRootScope(); - auto q0 = FIFOQueue(root.WithOpName("q0"), {DataType::DT_INT32}); + auto q0 = FIFOQueue(root.WithOpName(kQueueName0), {DataType::DT_INT32}); auto ten = Const(root, 10); - auto enqueue0 = QueueEnqueue(root.WithOpName("enqueue0"), q0, {ten}); - auto close0 = QueueClose(root.WithOpName("close0"), q0); - auto q1 = FIFOQueue(root.WithOpName("q1"), {DataType::DT_INT32}); + auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {ten}); + auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); + auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, + QueueClose::CancelPendingEnqueues(true)); + auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32}); auto dequeue0 = - QueueDequeue(root.WithOpName("dequeue0"), q0, {DataType::DT_INT32}); - auto enqueue1 = QueueEnqueue(root.WithOpName("enqueue1"), q1, {dequeue0[0]}); + QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32}); + auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]}); auto dequeue1 = - QueueDequeue(root.WithOpName("dequeue1"), q1, {DataType::DT_INT32}); - auto close1 = QueueClose(root.WithOpName("close1"), q1); + QueueDequeue(root.WithOpName(kDequeueOp1), q1, {DataType::DT_INT32}); + auto close1 = QueueClose(root.WithOpName(kCloseOp1), q1); + auto cancel1 = QueueClose(root.WithOpName(kCancelOp1), q1, + QueueClose::CancelPendingEnqueues(true)); GraphDef graph_def; TF_EXPECT_OK(root.ToGraphDef(&graph_def)); + return graph_def; +} + +TEST(QueueRunnerTest, RealEnqueueDequeue) { + auto graph_def = BuildDoubleQueueGraph(); SessionOptions options; std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName, {"enqueue1"}, "close1", {}); + BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {}); QueueRunner qr; qr.Init(queue_runner_def); TF_CHECK_OK(qr.Start(session.get())); - std::vector<Tensor> outputs; - TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs)); - TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs)); - TF_EXPECT_OK(session->Run({}, {}, {"close0"}, &outputs)); + TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); + TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); + // Closing queue 0 would also close the queue runner. + TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr)); TF_EXPECT_OK(qr.Join()); std::vector<Tensor> dq1; - TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq1)); + TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1)); EXPECT_EQ(*dq1[0].scalar<int>().data(), 10); std::vector<Tensor> dq2; - TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq2)); + TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq2)); EXPECT_EQ(*dq2[0].scalar<int>().data(), 10); - EXPECT_EQ(session->Run({}, {"dequeue1"}, {}, &dq1).code(), + EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(), Code::OUT_OF_RANGE); } +void JoinThread(QueueRunner* queue_runner, bool* join_succeeded, + Notification* join_done) { + EXPECT_EQ(queue_runner->Join().code(), Code::CANCELLED); + *join_succeeded = true; + join_done->Notify(); +} + +TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { + auto graph_def = BuildDoubleQueueGraph(); + + SessionOptions options; + std::unique_ptr<Session> session(NewSession(options)); + TF_CHECK_OK(session->Create(graph_def)); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); + QueueRunner qr; + qr.Init(queue_runner_def); + TF_CHECK_OK(qr.Start(session.get())); + + TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); + + std::vector<Tensor> dq1; + TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1)); + EXPECT_EQ(*dq1[0].scalar<int>().data(), 10); + + // The expected behavior is the QueueRunner::Join() call is blocked until + // Session::Close() is called. + bool join_succeeded = false; + Notification join_done; + Env::Default()->SchedClosure( + std::bind(&JoinThread, &qr, &join_succeeded, &join_done)); + + Env::Default()->SleepForMicroseconds(10000000); + EXPECT_EQ(join_succeeded, false); + + // Closing the session is required to cancel pending enqueue nodes. + TF_EXPECT_OK(session->Close()); + + join_done.WaitForNotification(); + EXPECT_EQ(join_succeeded, true); +} + +TEST(QueueRunnerTest, Stop) { + auto graph_def = BuildDoubleQueueGraph(); + + SessionOptions options; + std::unique_ptr<Session> session(NewSession(options)); + TF_CHECK_OK(session->Create(graph_def)); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); + QueueRunner qr; + qr.Init(queue_runner_def); + TF_CHECK_OK(qr.Start(session.get())); + + TF_EXPECT_OK(qr.Stop(session.get())); + + TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); + + EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(), + Code::OUT_OF_RANGE); + + // qr is already stopped + TF_EXPECT_OK(qr.Join()); +} + +TEST(QueueRunnerTest, StopTwoQueues) { + auto graph_def = BuildDoubleQueueGraph(); + + SessionOptions options; + std::unique_ptr<Session> session(NewSession(options)); + TF_CHECK_OK(session->Create(graph_def)); + + QueueRunnerDef queue_runner0 = + BuildQueueRunnerDef(kQueueName0, {kEnqueueOp0}, kCloseOp0, kCancelOp0, + {Code::OUT_OF_RANGE, Code::CANCELLED}); + QueueRunnerDef queue_runner1 = + BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, + {Code::OUT_OF_RANGE, Code::CANCELLED}); + QueueRunner qr0; + qr0.Init(queue_runner0); + TF_CHECK_OK(qr0.Start(session.get())); + QueueRunner qr1; + qr1.Init(queue_runner1); + TF_CHECK_OK(qr1.Start(session.get())); + + std::vector<Tensor> dq; + TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq)); + EXPECT_EQ(*dq[0].scalar<int>().data(), 10); + + TF_EXPECT_OK(qr0.Stop(session.get())); + TF_EXPECT_OK(qr1.Stop(session.get())); + + TF_EXPECT_OK(qr0.Join()); + TF_EXPECT_OK(qr1.Join()); +} + TEST(QueueRunnerTest, EmptyEnqueueOps) { QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, {}); + BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); QueueRunner qr; EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT); @@ -203,8 +315,8 @@ TEST(QueueRunnerTest, EmptyEnqueueOps) { TEST(QueueRunnerTest, InitAfterStart) { GraphDef graph_def = BuildSimpleGraph(); auto session = BuildSessionAndInitVariable(graph_def); - QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kCountUpToOpName, {}); + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kCountUpToOpName}, kCountUpToOpName, "", {}); QueueRunner qr; TF_EXPECT_OK(qr.Init(queue_runner_def)); @@ -213,3 +325,4 @@ TEST(QueueRunnerTest, InitAfterStart) { } } // namespace +} // namespace tensorflow |