diff options
author | 2016-10-31 13:26:16 -0800 | |
---|---|---|
committer | 2016-10-31 14:31:31 -0700 | |
commit | 10351d64fb59cdd693561b7bc847f3f221eb498a (patch) | |
tree | 7f0b3e5a266503597a9f53feaab14f3a209e434c /tensorflow/cc/training/queue_runner_test.cc | |
parent | 41734d78d3facf652c25b2a2761aadd978b3f2ef (diff) |
Replace the C++ QueueRunner constructor with a static NewQueueRunner function.
Change: 137749204
Diffstat (limited to 'tensorflow/cc/training/queue_runner_test.cc')
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 89 |
1 files changed, 41 insertions, 48 deletions
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 29165778c5..0d06c62056 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -113,9 +113,10 @@ TEST(QueueRunnerTest, BasicTest) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {}); - QueueRunner qr(queue_runner_def); - qr.Start(session.get()); - TF_EXPECT_OK(qr.Join()); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); + TF_EXPECT_OK(qr->Join()); std::vector<Tensor> outputs; TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); @@ -131,9 +132,10 @@ TEST(QueueRunnerTest, QueueClosedCode) { BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "", {Code::OUT_OF_RANGE, Code::CANCELLED}); - QueueRunner qr(queue_runner_def); - qr.Start(session.get()); - TF_EXPECT_OK(qr.Join()); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_EXPECT_OK(qr->Start(session.get())); + TF_EXPECT_OK(qr->Join()); std::vector<Tensor> outputs; TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); @@ -148,9 +150,10 @@ TEST(QueueRunnerDef, CatchErrorInJoin) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {}); - QueueRunner qr(queue_runner_def); - qr.Start(session.get()); - EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_EXPECT_OK(qr->Start(session.get())); + EXPECT_EQ(qr->Join().code(), Code::NOT_FOUND); } GraphDef BuildDoubleQueueGraph() { @@ -185,16 +188,16 @@ TEST(QueueRunnerTest, RealEnqueueDequeue) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {}); - QueueRunner qr; - qr.Init(queue_runner_def); - TF_CHECK_OK(qr.Start(session.get())); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); // Closing queue 0 would also close the queue runner. TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr)); - TF_EXPECT_OK(qr.Join()); + TF_EXPECT_OK(qr->Join()); std::vector<Tensor> dq1; TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1)); EXPECT_EQ(*dq1[0].scalar<int>().data(), 10); @@ -222,9 +225,9 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); - QueueRunner qr; - qr.Init(queue_runner_def); - TF_CHECK_OK(qr.Start(session.get())); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); @@ -237,7 +240,7 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { bool join_succeeded = false; Notification join_done; Env::Default()->SchedClosure( - std::bind(&JoinThread, &qr, &join_succeeded, &join_done)); + std::bind(&JoinThread, qr.get(), &join_succeeded, &join_done)); Env::Default()->SleepForMicroseconds(10000000); EXPECT_EQ(join_succeeded, false); @@ -256,13 +259,14 @@ TEST(QueueRunnerTest, Stop) { std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); - QueueRunner qr; - qr.Init(queue_runner_def); - TF_CHECK_OK(qr.Start(session.get())); + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, + {Code::OUT_OF_RANGE, Code::CANCELLED}); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); - TF_EXPECT_OK(qr.Stop(session.get())); + TF_EXPECT_OK(qr->Stop(session.get())); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); @@ -270,7 +274,7 @@ TEST(QueueRunnerTest, Stop) { Code::OUT_OF_RANGE); // qr is already stopped - TF_EXPECT_OK(qr.Join()); + TF_EXPECT_OK(qr->Join()); } TEST(QueueRunnerTest, StopTwoQueues) { @@ -286,42 +290,31 @@ TEST(QueueRunnerTest, StopTwoQueues) { QueueRunnerDef queue_runner1 = BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {Code::OUT_OF_RANGE, Code::CANCELLED}); - QueueRunner qr0; - qr0.Init(queue_runner0); - TF_CHECK_OK(qr0.Start(session.get())); - QueueRunner qr1; - qr1.Init(queue_runner1); - TF_CHECK_OK(qr1.Start(session.get())); + std::unique_ptr<QueueRunner> qr0; + TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0)); + TF_CHECK_OK(qr0->Start(session.get())); + std::unique_ptr<QueueRunner> qr1; + TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1)); + TF_CHECK_OK(qr1->Start(session.get())); std::vector<Tensor> dq; TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq)); EXPECT_EQ(*dq[0].scalar<int>().data(), 10); - TF_EXPECT_OK(qr0.Stop(session.get())); - TF_EXPECT_OK(qr1.Stop(session.get())); + TF_EXPECT_OK(qr0->Stop(session.get())); + TF_EXPECT_OK(qr1->Stop(session.get())); - TF_EXPECT_OK(qr0.Join()); - TF_EXPECT_OK(qr1.Join()); + TF_EXPECT_OK(qr0->Join()); + TF_EXPECT_OK(qr1->Join()); } TEST(QueueRunnerTest, EmptyEnqueueOps) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); - QueueRunner qr; - EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT); -} - -TEST(QueueRunnerTest, InitAfterStart) { - GraphDef graph_def = BuildSimpleGraph(); - auto session = BuildSessionAndInitVariable(graph_def); - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kCountUpToOpName}, kCountUpToOpName, "", {}); - - QueueRunner qr; - TF_EXPECT_OK(qr.Init(queue_runner_def)); - TF_EXPECT_OK(qr.Start(session.get())); - EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::ALREADY_EXISTS); + std::unique_ptr<QueueRunner> qr; + EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(), + Code::INVALID_ARGUMENT); } } // namespace |