aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2016-10-31 13:26:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 14:31:31 -0700
commit10351d64fb59cdd693561b7bc847f3f221eb498a (patch)
tree7f0b3e5a266503597a9f53feaab14f3a209e434c /tensorflow/cc/training
parent41734d78d3facf652c25b2a2761aadd978b3f2ef (diff)
Replace the C++ QueueRunner constructor with a static NewQueueRunner function.
Change: 137749204
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r--tensorflow/cc/training/queue_runner.cc20
-rw-r--r--tensorflow/cc/training/queue_runner.h16
-rw-r--r--tensorflow/cc/training/queue_runner_test.cc89
3 files changed, 54 insertions, 71 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 371d528143..ed1d0a5da0 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -19,18 +19,15 @@ limitations under the License.
namespace tensorflow {
-QueueRunner::QueueRunner() : started_(false) {}
-
-QueueRunner::QueueRunner(const QueueRunnerDef& queue_runner_def)
- : started_(false) {
- TF_CHECK_OK(Init(queue_runner_def));
+Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
+ std::unique_ptr<QueueRunner>* result) {
+ result->reset(new QueueRunner());
+ return (*result)->Init(queue_runner_def);
}
Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
- if (started_.load()) {
- return Status(error::ALREADY_EXISTS, "QueueRunner is already running.");
- }
queue_name_ = queue_runner_def.queue_name();
+ enqueue_op_names_.clear();
enqueue_op_names_.insert(enqueue_op_names_.end(),
queue_runner_def.enqueue_op_name().begin(),
queue_runner_def.enqueue_op_name().end());
@@ -61,12 +58,6 @@ QueueRunner::~QueueRunner() {
}
Status QueueRunner::Start(Session* sess) {
- if (runs_ == 0) {
- return Status(
- error::INVALID_ARGUMENT,
- "No enqueue ops to run. You may want to Init the QueueRunner first.");
- }
- started_ = true;
for (const string& enqueue_op : enqueue_op_names_) {
thread_pool_->Schedule(
std::bind(&QueueRunner::Run, this, sess, enqueue_op));
@@ -85,7 +76,6 @@ Status QueueRunner::Stop(Session* sess) {
Status QueueRunner::Join() {
thread_pool_.reset();
- started_ = false;
return status_;
}
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index fd0cbebe34..9374fe3605 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -34,19 +34,15 @@ namespace tensorflow {
// which creates a thread for each enqueue op, runs close op on completion.
class QueueRunner {
public:
- QueueRunner();
-
- // The constructor initializes the class from the proto.
+ // Creates a new QueueRunner from proto.
// TODO(yuefengz): we may want to initialize from queues and ops in the
// future.
- explicit QueueRunner(const QueueRunnerDef& queue_runner_def);
+ static Status New(const QueueRunnerDef& queue_runner_def,
+ std::unique_ptr<QueueRunner>* result);
// The destructor would join all the threads.
~QueueRunner();
- // Initializes the instance with the QueueRunnerDef proto.
- Status Init(const QueueRunnerDef& queue_runner_def);
-
// Starts the queue runner with the given session.
Status Start(Session* sess);
@@ -61,6 +57,11 @@ class QueueRunner {
Status GetStatus();
private:
+ QueueRunner() {}
+
+ // Initializes the instance with the QueueRunnerDef proto.
+ Status Init(const QueueRunnerDef& queue_runner_def);
+
// The Run function for each thread.
void Run(Session* sess, const string& enqueue_op);
@@ -73,7 +74,6 @@ class QueueRunner {
std::unique_ptr<thread::ThreadPool> thread_pool_;
std::atomic<bool> should_stop_;
- std::atomic<bool> started_;
condition_variable wait_to_close_;
mutex mu_;
// TODO(yuefengz): implement c++ coordinator.
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 29165778c5..0d06c62056 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -113,9 +113,10 @@ TEST(QueueRunnerTest, BasicTest) {
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
- QueueRunner qr(queue_runner_def);
- qr.Start(session.get());
- TF_EXPECT_OK(qr.Join());
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_CHECK_OK(qr->Start(session.get()));
+ TF_EXPECT_OK(qr->Join());
std::vector<Tensor> outputs;
TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
@@ -131,9 +132,10 @@ TEST(QueueRunnerTest, QueueClosedCode) {
BuildQueueRunnerDef(kQueueName, {kCountUpToOpName}, kSquareOpName, "",
{Code::OUT_OF_RANGE, Code::CANCELLED});
- QueueRunner qr(queue_runner_def);
- qr.Start(session.get());
- TF_EXPECT_OK(qr.Join());
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_EXPECT_OK(qr->Start(session.get()));
+ TF_EXPECT_OK(qr->Join());
std::vector<Tensor> outputs;
TF_EXPECT_OK(session->Run({}, {kSquareOpName}, {}, &outputs));
@@ -148,9 +150,10 @@ TEST(QueueRunnerDef, CatchErrorInJoin) {
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
- QueueRunner qr(queue_runner_def);
- qr.Start(session.get());
- EXPECT_EQ(qr.Join().code(), Code::NOT_FOUND);
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_EXPECT_OK(qr->Start(session.get()));
+ EXPECT_EQ(qr->Join().code(), Code::NOT_FOUND);
}
GraphDef BuildDoubleQueueGraph() {
@@ -185,16 +188,16 @@ TEST(QueueRunnerTest, RealEnqueueDequeue) {
QueueRunnerDef queue_runner_def =
BuildQueueRunnerDef(kQueueName, {kEnqueueOp1}, kCloseOp1, "", {});
- QueueRunner qr;
- qr.Init(queue_runner_def);
- TF_CHECK_OK(qr.Start(session.get()));
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_CHECK_OK(qr->Start(session.get()));
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
// Closing queue 0 would also close the queue runner.
TF_EXPECT_OK(session->Run({}, {}, {kCloseOp0}, nullptr));
- TF_EXPECT_OK(qr.Join());
+ TF_EXPECT_OK(qr->Join());
std::vector<Tensor> dq1;
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq1));
EXPECT_EQ(*dq1[0].scalar<int>().data(), 10);
@@ -222,9 +225,9 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
- QueueRunner qr;
- qr.Init(queue_runner_def);
- TF_CHECK_OK(qr.Start(session.get()));
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_CHECK_OK(qr->Start(session.get()));
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
@@ -237,7 +240,7 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
bool join_succeeded = false;
Notification join_done;
Env::Default()->SchedClosure(
- std::bind(&JoinThread, &qr, &join_succeeded, &join_done));
+ std::bind(&JoinThread, qr.get(), &join_succeeded, &join_done));
Env::Default()->SleepForMicroseconds(10000000);
EXPECT_EQ(join_succeeded, false);
@@ -256,13 +259,14 @@ TEST(QueueRunnerTest, Stop) {
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
- QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
- QueueRunner qr;
- qr.Init(queue_runner_def);
- TF_CHECK_OK(qr.Start(session.get()));
+ QueueRunnerDef queue_runner_def =
+ BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
+ {Code::OUT_OF_RANGE, Code::CANCELLED});
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ TF_CHECK_OK(qr->Start(session.get()));
- TF_EXPECT_OK(qr.Stop(session.get()));
+ TF_EXPECT_OK(qr->Stop(session.get()));
TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
@@ -270,7 +274,7 @@ TEST(QueueRunnerTest, Stop) {
Code::OUT_OF_RANGE);
// qr is already stopped
- TF_EXPECT_OK(qr.Join());
+ TF_EXPECT_OK(qr->Join());
}
TEST(QueueRunnerTest, StopTwoQueues) {
@@ -286,42 +290,31 @@ TEST(QueueRunnerTest, StopTwoQueues) {
QueueRunnerDef queue_runner1 =
BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
{Code::OUT_OF_RANGE, Code::CANCELLED});
- QueueRunner qr0;
- qr0.Init(queue_runner0);
- TF_CHECK_OK(qr0.Start(session.get()));
- QueueRunner qr1;
- qr1.Init(queue_runner1);
- TF_CHECK_OK(qr1.Start(session.get()));
+ std::unique_ptr<QueueRunner> qr0;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0));
+ TF_CHECK_OK(qr0->Start(session.get()));
+ std::unique_ptr<QueueRunner> qr1;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1));
+ TF_CHECK_OK(qr1->Start(session.get()));
std::vector<Tensor> dq;
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
- TF_EXPECT_OK(qr0.Stop(session.get()));
- TF_EXPECT_OK(qr1.Stop(session.get()));
+ TF_EXPECT_OK(qr0->Stop(session.get()));
+ TF_EXPECT_OK(qr1->Stop(session.get()));
- TF_EXPECT_OK(qr0.Join());
- TF_EXPECT_OK(qr1.Join());
+ TF_EXPECT_OK(qr0->Join());
+ TF_EXPECT_OK(qr1->Join());
}
TEST(QueueRunnerTest, EmptyEnqueueOps) {
QueueRunnerDef queue_runner_def =
BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
- QueueRunner qr;
- EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::INVALID_ARGUMENT);
-}
-
-TEST(QueueRunnerTest, InitAfterStart) {
- GraphDef graph_def = BuildSimpleGraph();
- auto session = BuildSessionAndInitVariable(graph_def);
- QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName, {kCountUpToOpName}, kCountUpToOpName, "", {});
-
- QueueRunner qr;
- TF_EXPECT_OK(qr.Init(queue_runner_def));
- TF_EXPECT_OK(qr.Start(session.get()));
- EXPECT_EQ(qr.Init(queue_runner_def).code(), Code::ALREADY_EXISTS);
+ std::unique_ptr<QueueRunner> qr;
+ EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
+ Code::INVALID_ARGUMENT);
}
} // namespace