diff options
author | 2016-11-01 12:49:38 -0800 | |
---|---|---|
committer | 2016-11-01 14:17:00 -0700 | |
commit | 99f55f806f426a50c01dd06bd71a478009a84af2 (patch) | |
tree | cb6ecc95f412765fa9be10fcd5dbe7bb8ff82dd3 /tensorflow/cc/training/queue_runner_test.cc | |
parent | 4863a6074f19e9546e195ab495061a6df7b18ce2 (diff) |
Add C++ Coordinator
Change: 137866409
Diffstat (limited to 'tensorflow/cc/training/queue_runner_test.cc')
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 84 |
1 files changed, 31 insertions, 53 deletions
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index df48b25543..73ea5a307f 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/training/coordinator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -111,7 +112,7 @@ TEST(QueueRunnerTest, BasicTest) { auto session = BuildSessionAndInitVariable(graph_def); QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {}); + kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); std::unique_ptr<QueueRunner> qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); @@ -164,7 +165,8 @@ GraphDef BuildDoubleQueueGraph() { auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, QueueClose::CancelPendingEnqueues(true)); - auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32}); + auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32}, + FIFOQueue::Capacity(3)); auto dequeue0 = QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32}); auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]}); @@ -252,34 +254,34 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { EXPECT_EQ(join_succeeded, true); } -TEST(QueueRunnerTest, Stop) { - auto graph_def = BuildDoubleQueueGraph(); +TEST(QueueRunnerTest, EmptyEnqueueOps) { + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); + std::unique_ptr<QueueRunner> qr; + EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(), + Code::INVALID_ARGUMENT); +} + +TEST(QueueRunnerTest, StartTimeout) { + GraphDef graph_def = BuildDoubleQueueGraph(); SessionOptions options; std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); - QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, - {Code::OUT_OF_RANGE, Code::CANCELLED}); + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); + std::unique_ptr<QueueRunner> qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); - TF_CHECK_OK(qr->Start(session.get())); - - TF_EXPECT_OK(qr->Stop(session.get())); - - TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); - - EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(), - Code::OUT_OF_RANGE); - - // qr is already stopped - TF_EXPECT_OK(qr->Join()); + // This will timeout since queue0 is not fed and queue1 is fetching data from + // queue0. + EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED); + session->Close(); } -TEST(QueueRunnerTest, StopTwoQueues) { +TEST(QueueRunnerTest, TestCoordinatorStop) { auto graph_def = BuildDoubleQueueGraph(); - SessionOptions options; std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); @@ -290,48 +292,24 @@ TEST(QueueRunnerTest, StopTwoQueues) { QueueRunnerDef queue_runner1 = BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {Code::OUT_OF_RANGE, Code::CANCELLED}); + + Coordinator coord; std::unique_ptr<QueueRunner> qr0; - TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0)); + TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0)); TF_CHECK_OK(qr0->Start(session.get())); std::unique_ptr<QueueRunner> qr1; - TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1)); + TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1)); TF_CHECK_OK(qr1->Start(session.get())); + coord.RegisterRunner(std::move(qr0)); + coord.RegisterRunner(std::move(qr1)); + std::vector<Tensor> dq; TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq)); EXPECT_EQ(*dq[0].scalar<int>().data(), 10); - TF_EXPECT_OK(qr0->Stop(session.get())); - TF_EXPECT_OK(qr1->Stop(session.get())); - - TF_EXPECT_OK(qr0->Join()); - TF_EXPECT_OK(qr1->Join()); -} - -TEST(QueueRunnerTest, EmptyEnqueueOps) { - QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); - - std::unique_ptr<QueueRunner> qr; - EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(), - Code::INVALID_ARGUMENT); -} - -TEST(QueueRunnerTest, StartTimeout) { - GraphDef graph_def = BuildDoubleQueueGraph(); - SessionOptions options; - std::unique_ptr<Session> session(NewSession(options)); - TF_CHECK_OK(session->Create(graph_def)); - - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); - - std::unique_ptr<QueueRunner> qr; - TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); - // This will timeout since queue0 is not fed and queue1 is fetching data from - // queue0. - EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED); - session->Close(); + TF_EXPECT_OK(coord.RequestStop()); + TF_EXPECT_OK(coord.Join()); } } // namespace |