diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2016-10-31 13:26:16 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-31 14:31:31 -0700 |
commit | 10351d64fb59cdd693561b7bc847f3f221eb498a (patch) | |
tree | 7f0b3e5a266503597a9f53feaab14f3a209e434c /tensorflow/cc/training | |
parent | 41734d78d3facf652c25b2a2761aadd978b3f2ef (diff) |
Replace the C++ QueueRunner constructor with a static NewQueueRunner function.
Change: 137749204
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 20 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 16 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 89 |
3 files changed, 54 insertions, 71 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 371d528143..ed1d0a5da0 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -19,18 +19,15 @@ limitations under the License. namespace tensorflow { -QueueRunner::QueueRunner() : started_(false) {} - -QueueRunner::QueueRunner(const QueueRunnerDef& queue_runner_def) - : started_(false) { - TF_CHECK_OK(Init(queue_runner_def)); +Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, + std::unique_ptr<QueueRunner>* result) { + result->reset(new QueueRunner()); + return (*result)->Init(queue_runner_def); } Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { - if (started_.load()) { - return Status(error::ALREADY_EXISTS, "QueueRunner is already running."); - } queue_name_ = queue_runner_def.queue_name(); + enqueue_op_names_.clear(); enqueue_op_names_.insert(enqueue_op_names_.end(), queue_runner_def.enqueue_op_name().begin(), queue_runner_def.enqueue_op_name().end()); @@ -61,12 +58,6 @@ QueueRunner::~QueueRunner() { } Status QueueRunner::Start(Session* sess) { - if (runs_ == 0) { - return Status( - error::INVALID_ARGUMENT, - "No enqueue ops to run. You may want to Init the QueueRunner first."); - } - started_ = true; for (const string& enqueue_op : enqueue_op_names_) { thread_pool_->Schedule( std::bind(&QueueRunner::Run, this, sess, enqueue_op)); @@ -85,7 +76,6 @@ Status QueueRunner::Stop(Session* sess) { Status QueueRunner::Join() { thread_pool_.reset(); - started_ = false; return status_; } diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index fd0cbebe34..9374fe3605 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -34,19 +34,15 @@ namespace tensorflow { // which creates a thread for each enqueue op, runs close op on completion. class QueueRunner { public: - QueueRunner(); - - // The constructor initializes the class from the proto. + // Creates a new QueueRunner from proto. // TODO(yuefengz): we may want to initialize from queues and ops in the // future. - explicit QueueRunner(const QueueRunnerDef& queue_runner_def); + static Status New(const QueueRunnerDef& queue_runner_def, + std::unique_ptr<QueueRunner>* result); // The destructor would join all the threads. ~QueueRunner(); - // Initializes the instance with the QueueRunnerDef proto. - Status Init(const QueueRunnerDef& queue_runner_def); - // Starts the queue runner with the given session. Status Start(Session* sess); @@ -61,6 +57,11 @@ class QueueRunner { Status GetStatus(); private: + QueueRunner() {} + + // Initializes the instance with the QueueRunnerDef proto. + Status Init(const QueueRunnerDef& queue_runner_def); + // The Run function for each thread. void Run(Session* sess, const string& enqueue_op); @@ -73,7 +74,6 @@ class QueueRunner { std::unique_ptr<thread::ThreadPool> thread_pool_; std::atomic<bool> should_stop_; - std::atomic<bool> started_; condition_variable wait_to_close_; mutex mu_; // TODO(yuefengz): implement c++ coordinator. diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 29165778c5..0d06c62056 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -113,9 +113,10 @@ TEST(QueueRunnerTest, BasicTest) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {}); - QueueRunner qr(queue_runner_def); - qr.Start(session.get()); - TF_EXPECT_OK(qr.Join()); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); + TF_EXPECT_OK(qr->Join()); std::vector<Tensor> outputs; TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); @@ -131,9 +132,10 @@ TEST(QueueRunnerTest, QueueClosedCode) { BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "", {Code::OUT_OF_RANGE, Code::CANCELLED}); - QueueRunner qr(queue_runner_def); - qr.Start(session.get()); - TF_EXPECT_OK(qr.Join()); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_EXPECT_OK(qr->Start(session.get())); + TF_EXPECT_OK(qr->Join()); std::vector<Tensor> outputs; TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs)); @@ -148,9 +150,10 @@ TEST(QueueRunnerDef, CatchErrorInJoin) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {}); - QueueRunner qr(queue_runner_def); - qr.Start(session.get()); - EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_EXPECT_OK(qr->Start(session.get())); + EXPECT_EQ(qr->Join().code(), Code::NOT_FOUND); } GraphDef BuildDoubleQueueGraph() { @@ -185,16 +188,16 @@ TEST(QueueRunnerTest, RealEnqueueDequeue) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {}); - QueueRunner qr; - qr.Init(queue_runner_def); - TF_CHECK_OK(qr.Start(session.get())); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); // Closing queue 0 would also close the queue runner. TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr)); - TF_EXPECT_OK(qr.Join()); + TF_EXPECT_OK(qr->Join()); std::vector<Tensor> dq1; TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1)); EXPECT_EQ(*dq1[0].scalar<int>().data(), 10); @@ -222,9 +225,9 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); - QueueRunner qr; - qr.Init(queue_runner_def); - TF_CHECK_OK(qr.Start(session.get())); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); @@ -237,7 +240,7 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { bool join_succeeded = false; Notification join_done; Env::Default()->SchedClosure( - std::bind(&JoinThread, &qr, &join_succeeded, &join_done)); + std::bind(&JoinThread, qr.get(), &join_succeeded, &join_done)); Env::Default()->SleepForMicroseconds(10000000); EXPECT_EQ(join_succeeded, false); @@ -256,13 +259,14 @@ TEST(QueueRunnerTest, Stop) { std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); - QueueRunner qr; - qr.Init(queue_runner_def); - TF_CHECK_OK(qr.Start(session.get())); + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, + {Code::OUT_OF_RANGE, Code::CANCELLED}); + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); - TF_EXPECT_OK(qr.Stop(session.get())); + TF_EXPECT_OK(qr->Stop(session.get())); TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); @@ -270,7 +274,7 @@ TEST(QueueRunnerTest, Stop) { Code::OUT_OF_RANGE); // qr is already stopped - TF_EXPECT_OK(qr.Join()); + TF_EXPECT_OK(qr->Join()); } TEST(QueueRunnerTest, StopTwoQueues) { @@ -286,42 +290,31 @@ TEST(QueueRunnerTest, StopTwoQueues) { QueueRunnerDef queue_runner1 = BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {Code::OUT_OF_RANGE, Code::CANCELLED}); - QueueRunner qr0; - qr0.Init(queue_runner0); - TF_CHECK_OK(qr0.Start(session.get())); - QueueRunner qr1; - qr1.Init(queue_runner1); - TF_CHECK_OK(qr1.Start(session.get())); + std::unique_ptr<QueueRunner> qr0; + TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0)); + TF_CHECK_OK(qr0->Start(session.get())); + std::unique_ptr<QueueRunner> qr1; + TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1)); + TF_CHECK_OK(qr1->Start(session.get())); std::vector<Tensor> dq; TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq)); EXPECT_EQ(*dq[0].scalar<int>().data(), 10); - TF_EXPECT_OK(qr0.Stop(session.get())); - TF_EXPECT_OK(qr1.Stop(session.get())); + TF_EXPECT_OK(qr0->Stop(session.get())); + TF_EXPECT_OK(qr1->Stop(session.get())); - TF_EXPECT_OK(qr0.Join()); - TF_EXPECT_OK(qr1.Join()); + TF_EXPECT_OK(qr0->Join()); + TF_EXPECT_OK(qr1->Join()); } TEST(QueueRunnerTest, EmptyEnqueueOps) { QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); - QueueRunner qr; - EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT); -} - -TEST(QueueRunnerTest, InitAfterStart) { - GraphDef graph_def = BuildSimpleGraph(); - auto session = BuildSessionAndInitVariable(graph_def); - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kCountUpToOpName}, kCountUpToOpName, "", {}); - - QueueRunner qr; - TF_EXPECT_OK(qr.Init(queue_runner_def)); - TF_EXPECT_OK(qr.Start(session.get())); - EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::ALREADY_EXISTS); + std::unique_ptr<QueueRunner> qr; + EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(), + Code::INVALID_ARGUMENT); } } // namespace |