aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2016-10-27 14:12:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-27 15:19:44 -0700
commit71b993a63c9f4c62d45303623f926219066902cc (patch)
tree9d290c67c0b4e9bad47e8afad81db74a58f536aa /tensorflow/cc/training
parent44546e1e4e87b8127334dec8b0066c7df6b3f037 (diff)
Add Stop() in C++ QueueRunner.
Change: 137447384
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r--tensorflow/cc/training/queue_runner.cc39
-rw-r--r--tensorflow/cc/training/queue_runner.h6
-rw-r--r--tensorflow/cc/training/queue_runner_test.cc205
3 files changed, 191 insertions, 59 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 585ee15872..79d306f367 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -54,7 +54,8 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
}
QueueRunner::~QueueRunner() {
- should_stop_ = true;
+ // Cannot run Stop() here because the session might already be closed or
+ // destroyed.
Join();
}
@@ -72,6 +73,15 @@ Status QueueRunner::Start(Session* sess) {
return Status::OK();
}
+Status QueueRunner::Stop(Session* sess) {
+ should_stop_ = true;
+ if (cancel_op_name_.empty()) {
+ return Status::OK();
+ } else {
+ return sess->Run({}, {}, {cancel_op_name_}, nullptr);
+ }
+}
+
Status QueueRunner::Join() {
thread_pool_.reset();
started_ = false;
@@ -81,8 +91,7 @@ Status QueueRunner::Join() {
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
bool decremented = false;
while (!should_stop_.load()) {
- std::vector<Tensor> outputs;
- auto status = sess->Run({}, {}, {enqueue_op}, &outputs);
+ auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
if (status.ok()) {
continue;
} else if (queue_closed_exception_types_.count(
@@ -94,19 +103,25 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
// If all enqueue ops have finished, run the close op.
if (runs_ == 0 && !close_op_name_.empty()) {
- std::vector<Tensor> outputs;
- auto s = sess->Run({}, {}, {close_op_name_}, &outputs);
- if (!s.ok()) {
- status_ = status;
+ auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
+ if (!s.ok() && status_.ok() &&
+ queue_closed_exception_types_.count(static_cast<int>(s.code())) ==
+ 0) {
+ status_ = s;
}
}
} else {
- mutex_lock l(mu_);
- should_stop_ = true;
- // Only record the first failure status.
- if (status_.ok()) {
- status_ = status;
+ {
+ mutex_lock l(mu_);
+ should_stop_ = true;
+ // Only record the first failure status.
+ if (status_.ok()) {
+ status_ = status;
+ }
}
+ // Stop the queue runner immediately to propagate the error to
+ // subsequent queues.
+ Stop(sess);
}
}
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 09d8d49821..c3fe4026ef 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <unordered_set>
#include <vector>
+
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -49,6 +50,9 @@ class QueueRunner {
// Starts the queue runner with the given session.
Status Start(Session* sess);
+ // Requests to stop and runs the cancel op.
+ Status Stop(Session* sess);
+
// Joins all the threads. Returns okay if all threads run successfully;
// otherwise returns the first captured failure status.
Status Join();
@@ -60,7 +64,6 @@ class QueueRunner {
string queue_name_;
std::vector<string> enqueue_op_names_;
string close_op_name_;
- // The cancel op is not being called currently.
string cancel_op_name_;
// code::Code casted to int to avoid a hash function.
std::unordered_set<int> queue_closed_exception_types_;
@@ -68,6 +71,7 @@ class QueueRunner {
std::unique_ptr<thread::ThreadPool> thread_pool_;
std::atomic<bool> should_stop_;
std::atomic<bool> started_;
+ condition_variable wait_to_close_;
mutex mu_;
// TODO(yuefengz): implement c++ coordinator.
int runs_ = 0;
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 8719677274..29165778c5 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/cc/training/queue_runner.h"
+
#include <string>
#include <vector>
+
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -23,39 +25,42 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
+namespace tensorflow {
namespace {
-using ::tensorflow::DataType;
-using ::tensorflow::error::Code;
-using ::tensorflow::GraphDef;
-using ::tensorflow::ops::Assign;
-using ::tensorflow::ops::Const;
-using ::tensorflow::ops::CountUpTo;
-using ::tensorflow::ops::FIFOQueue;
-using ::tensorflow::ops::InputList;
-using ::tensorflow::ops::QueueClose;
-using ::tensorflow::ops::QueueDequeue;
-using ::tensorflow::ops::QueueEnqueue;
-using ::tensorflow::ops::Square;
-using ::tensorflow::ops::Variable;
-using ::tensorflow::QueueRunner;
-using ::tensorflow::QueueRunnerDef;
-using ::tensorflow::Scope;
-using ::tensorflow::Session;
-using ::tensorflow::SessionOptions;
-using ::tensorflow::Tensor;
-using ::tensorflow::TensorShape;
+using error::Code;
+using ops::Assign;
+using ops::Const;
+using ops::CountUpTo;
+using ops::FIFOQueue;
+using ops::QueueClose;
+using ops::QueueDequeue;
+using ops::QueueEnqueue;
+using ops::Square;
+using ops::Variable;
constexpr char kAssignOpName[] = "assign";
+constexpr char kCancelOp0[] = "cancel0";
+constexpr char kCancelOp1[] = "cancel1";
+constexpr char kCloseOp0[] = "close0";
+constexpr char kCloseOp1[] = "close1";
constexpr char kCountUpToOpName[] = "count";
+constexpr char kDequeueOp0[] = "dequeue0";
+constexpr char kDequeueOp1[] = "dequeue1";
+constexpr char kEnqueueOp0[] = "enqueue0";
+constexpr char kEnqueueOp1[] = "enqueue1";
constexpr char kIllegalOpName1[] = "would fail";
constexpr char kIllegalOpName2[] = "fail again";
constexpr char kQueueName[] = "unit_test";
+constexpr char kQueueName0[] = "q0";
+constexpr char kQueueName1[] = "q1";
constexpr char kSquareOpName[] = "square";
constexpr char kVarOpName[] = "var";
@@ -75,7 +80,7 @@ GraphDef BuildSimpleGraph() {
QueueRunnerDef BuildQueueRunnerDef(
const std::string& queue_name, const std::vector<std::string>& enqueue_ops,
- const std::string& close_op,
+ const std::string& close_op, const std::string& cancel_op,
const std::vector<Code>& queue_closed_error_codes) {
QueueRunnerDef queue_runner_def;
*queue_runner_def.mutable_queue_name() = kQueueName;
@@ -83,6 +88,7 @@ QueueRunnerDef BuildQueueRunnerDef(
*queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op;
}
*queue_runner_def.mutable_close_op_name() = close_op;
+ *queue_runner_def.mutable_cancel_op_name() = cancel_op;
for (const auto& error_code : queue_closed_error_codes) {
*queue_runner_def.mutable_queue_closed_exception_types()->Add() =
error_code;
@@ -96,8 +102,7 @@ std::unique_ptr<Session> BuildSessionAndInitVariable(
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
- std::vector<Tensor> nothing;
- TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, &nothing));
+ TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr));
return session;
}
@@ -106,7 +111,7 @@ TEST(QueueRunnerTest, BasicTest) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, {});
+ kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
QueueRunner qr(queue_runner_def);
qr.Start(session.get());
@@ -123,7 +128,7 @@ TEST(QueueRunnerTest, QueueClosedCode) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName,
+ BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "",
{Code::OUT_OF_RANGE, Code::CANCELLED});
QueueRunner qr(queue_runner_def);
@@ -141,60 +146,167 @@ TEST(QueueRunnerDef, CatchErrorInJoin) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, {});
+ kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
QueueRunner qr(queue_runner_def);
qr.Start(session.get());
EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND);
}
-TEST(QueueRunnerTest, RealEnqueueDequeue) {
+GraphDef BuildDoubleQueueGraph() {
Scope root = Scope::NewRootScope();
- auto q0 = FIFOQueue(root.WithOpName("q0"), {DataType::DT_INT32});
+ auto q0 = FIFOQueue(root.WithOpName(kQueueName0), {DataType::DT_INT32});
auto ten = Const(root, 10);
- auto enqueue0 = QueueEnqueue(root.WithOpName("enqueue0"), q0, {ten});
- auto close0 = QueueClose(root.WithOpName("close0"), q0);
- auto q1 = FIFOQueue(root.WithOpName("q1"), {DataType::DT_INT32});
+ auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {ten});
+ auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
+ auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
+ QueueClose::CancelPendingEnqueues(true));
+ auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32});
auto dequeue0 =
- QueueDequeue(root.WithOpName("dequeue0"), q0, {DataType::DT_INT32});
- auto enqueue1 = QueueEnqueue(root.WithOpName("enqueue1"), q1, {dequeue0[0]});
+ QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
+ auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
auto dequeue1 =
- QueueDequeue(root.WithOpName("dequeue1"), q1, {DataType::DT_INT32});
- auto close1 = QueueClose(root.WithOpName("close1"), q1);
+ QueueDequeue(root.WithOpName(kDequeueOp1), q1, {DataType::DT_INT32});
+ auto close1 = QueueClose(root.WithOpName(kCloseOp1), q1);
+ auto cancel1 = QueueClose(root.WithOpName(kCancelOp1), q1,
+ QueueClose::CancelPendingEnqueues(true));
GraphDef graph_def;
TF_EXPECT_OK(root.ToGraphDef(&graph_def));
+ return graph_def;
+}
+
+TEST(QueueRunnerTest, RealEnqueueDequeue) {
+ auto graph_def = BuildDoubleQueueGraph();
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName, {"enqueue1"}, "close1", {});
+ BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {});
QueueRunner qr;
qr.Init(queue_runner_def);
TF_CHECK_OK(qr.Start(session.get()));
- std::vector<Tensor> outputs;
- TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
- TF_EXPECT_OK(session->Run({}, {}, {"enqueue0"}, &outputs));
- TF_EXPECT_OK(session->Run({}, {}, {"close0"}, &outputs));
+ TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
+ TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
+ // Closing queue 0 would also close the queue runner.
+ TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr));
TF_EXPECT_OK(qr.Join());
std::vector<Tensor> dq1;
- TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq1));
+ TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
std::vector<Tensor> dq2;
- TF_EXPECT_OK(session->Run({}, {"dequeue1"}, {}, &dq2));
+ TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq2));
EXPECT_EQ(*dq2[0].scalar<int>().data(), 10);
- EXPECT_EQ(session->Run({}, {"dequeue1"}, {}, &dq1).code(),
+ EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
Code::OUT_OF_RANGE);
}
+void JoinThread(QueueRunner* queue_runner, bool* join_succeeded,
+ Notification* join_done) {
+ EXPECT_EQ(queue_runner->Join().code(), Code::CANCELLED);
+ *join_succeeded = true;
+ join_done->Notify();
+}
+
+TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
+ auto graph_def = BuildDoubleQueueGraph();
+
+ SessionOptions options;
+ std::unique_ptr<Session> session(NewSession(options));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
+ QueueRunner qr;
+ qr.Init(queue_runner_def);
+ TF_CHECK_OK(qr.Start(session.get()));
+
+ TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
+
+ std::vector<Tensor> dq1;
+ TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
+ EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
+
+ // The expected behavior is the QueueRunner::Join() call is blocked until
+ // Session::Close() is called.
+ bool join_succeeded = false;
+ Notification join_done;
+ Env::Default()->SchedClosure(
+ std::bind(&JoinThread, &qr, &join_succeeded, &join_done));
+
+ Env::Default()->SleepForMicroseconds(10000000);
+ EXPECT_EQ(join_succeeded, false);
+
+ // Closing the session is required to cancel pending enqueue nodes.
+ TF_EXPECT_OK(session->Close());
+
+ join_done.WaitForNotification();
+ EXPECT_EQ(join_succeeded, true);
+}
+
+TEST(QueueRunnerTest, Stop) {
+ auto graph_def = BuildDoubleQueueGraph();
+
+ SessionOptions options;
+ std::unique_ptr<Session> session(NewSession(options));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
+ QueueRunner qr;
+ qr.Init(queue_runner_def);
+ TF_CHECK_OK(qr.Start(session.get()));
+
+ TF_EXPECT_OK(qr.Stop(session.get()));
+
+ TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
+
+ EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
+ Code::OUT_OF_RANGE);
+
+ // qr is already stopped
+ TF_EXPECT_OK(qr.Join());
+}
+
+TEST(QueueRunnerTest, StopTwoQueues) {
+ auto graph_def = BuildDoubleQueueGraph();
+
+ SessionOptions options;
+ std::unique_ptr<Session> session(NewSession(options));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ QueueRunnerDef queue_runner0 =
+ BuildQueueRunnerDef(kQueueName0, {kEnqueueOp0}, kCloseOp0, kCancelOp0,
+ {Code::OUT_OF_RANGE, Code::CANCELLED});
+ QueueRunnerDef queue_runner1 =
+ BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
+ {Code::OUT_OF_RANGE, Code::CANCELLED});
+ QueueRunner qr0;
+ qr0.Init(queue_runner0);
+ TF_CHECK_OK(qr0.Start(session.get()));
+ QueueRunner qr1;
+ qr1.Init(queue_runner1);
+ TF_CHECK_OK(qr1.Start(session.get()));
+
+ std::vector<Tensor> dq;
+ TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
+ EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
+
+ TF_EXPECT_OK(qr0.Stop(session.get()));
+ TF_EXPECT_OK(qr1.Stop(session.get()));
+
+ TF_EXPECT_OK(qr0.Join());
+ TF_EXPECT_OK(qr1.Join());
+}
+
TEST(QueueRunnerTest, EmptyEnqueueOps) {
QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, {});
+ BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
QueueRunner qr;
EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT);
@@ -203,8 +315,8 @@ TEST(QueueRunnerTest, EmptyEnqueueOps) {
TEST(QueueRunnerTest, InitAfterStart) {
GraphDef graph_def = BuildSimpleGraph();
auto session = BuildSessionAndInitVariable(graph_def);
- QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kCountUpToOpName, {});
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName, {kCountUpToOpName}, kCountUpToOpName, "", {});
QueueRunner qr;
TF_EXPECT_OK(qr.Init(queue_runner_def));
@@ -213,3 +325,4 @@ TEST(QueueRunnerTest, InitAfterStart) {
}
} // namespace
+} // namespace tensorflow