aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training/queue_runner_test.cc
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2016-11-01 12:49:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-01 14:17:00 -0700
commit99f55f806f426a50c01dd06bd71a478009a84af2 (patch)
treecb6ecc95f412765fa9be10fcd5dbe7bb8ff82dd3 /tensorflow/cc/training/queue_runner_test.cc
parent4863a6074f19e9546e195ab495061a6df7b18ce2 (diff)
Add C++ Coordinator
Change: 137866409
Diffstat (limited to 'tensorflow/cc/training/queue_runner_test.cc')
-rw-r--r--tensorflow/cc/training/queue_runner_test.cc84
1 files changed, 31 insertions, 53 deletions
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index df48b25543..73ea5a307f 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -111,7 +112,7 @@ TEST(QueueRunnerTest, BasicTest) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
+ kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
@@ -164,7 +165,8 @@ GraphDef BuildDoubleQueueGraph() {
auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
QueueClose::CancelPendingEnqueues(true));
- auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32});
+ auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32},
+ FIFOQueue::Capacity(3));
auto dequeue0 =
QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
@@ -252,34 +254,34 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
EXPECT_EQ(join_succeeded, true);
}
-TEST(QueueRunnerTest, Stop) {
- auto graph_def = BuildDoubleQueueGraph();
+TEST(QueueRunnerTest, EmptyEnqueueOps) {
+ QueueRunnerDef queue_runner_def =
+ BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
+ std::unique_ptr<QueueRunner> qr;
+ EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
+ Code::INVALID_ARGUMENT);
+}
+
+TEST(QueueRunnerTest, StartTimeout) {
+ GraphDef graph_def = BuildDoubleQueueGraph();
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
- QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
- {Code::OUT_OF_RANGE, Code::CANCELLED});
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
+
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
- TF_CHECK_OK(qr->Start(session.get()));
-
- TF_EXPECT_OK(qr->Stop(session.get()));
-
- TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
-
- EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
- Code::OUT_OF_RANGE);
-
- // qr is already stopped
- TF_EXPECT_OK(qr->Join());
+ // This will timeout since queue0 is not fed and queue1 is fetching data from
+ // queue0.
+ EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
+ session->Close();
}
-TEST(QueueRunnerTest, StopTwoQueues) {
+TEST(QueueRunnerTest, TestCoordinatorStop) {
auto graph_def = BuildDoubleQueueGraph();
-
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
@@ -290,48 +292,24 @@ TEST(QueueRunnerTest, StopTwoQueues) {
QueueRunnerDef queue_runner1 =
BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
{Code::OUT_OF_RANGE, Code::CANCELLED});
+
+ Coordinator coord;
std::unique_ptr<QueueRunner> qr0;
- TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0));
+ TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0));
TF_CHECK_OK(qr0->Start(session.get()));
std::unique_ptr<QueueRunner> qr1;
- TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1));
+ TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1));
TF_CHECK_OK(qr1->Start(session.get()));
+ coord.RegisterRunner(std::move(qr0));
+ coord.RegisterRunner(std::move(qr1));
+
std::vector<Tensor> dq;
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
- TF_EXPECT_OK(qr0->Stop(session.get()));
- TF_EXPECT_OK(qr1->Stop(session.get()));
-
- TF_EXPECT_OK(qr0->Join());
- TF_EXPECT_OK(qr1->Join());
-}
-
-TEST(QueueRunnerTest, EmptyEnqueueOps) {
- QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
-
- std::unique_ptr<QueueRunner> qr;
- EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
- Code::INVALID_ARGUMENT);
-}
-
-TEST(QueueRunnerTest, StartTimeout) {
- GraphDef graph_def = BuildDoubleQueueGraph();
- SessionOptions options;
- std::unique_ptr<Session> session(NewSession(options));
- TF_CHECK_OK(session->Create(graph_def));
-
- QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
-
- std::unique_ptr<QueueRunner> qr;
- TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
- // This will timeout since queue0 is not fed and queue1 is fetching data from
- // queue0.
- EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
- session->Close();
+ TF_EXPECT_OK(coord.RequestStop());
+ TF_EXPECT_OK(coord.Join());
}
} // namespace