about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
author: Yuefeng Zhou <yuefengz@google.com> 2016-11-01 12:49:38 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-11-01 14:17:00 -0700
commit: 99f55f806f426a50c01dd06bd71a478009a84af2 (patch)
tree: cb6ecc95f412765fa9be10fcd5dbe7bb8ff82dd3 /tensorflow/cc/training
parent: 4863a6074f19e9546e195ab495061a6df7b18ce2 (diff)
Add C++ Coordinator
Change: 137866409
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- tensorflow/cc/training/coordinator.cc 90
-rw-r--r-- tensorflow/cc/training/coordinator.h 109
-rw-r--r-- tensorflow/cc/training/coordinator_test.cc 183
-rw-r--r-- tensorflow/cc/training/queue_runner.cc 77
-rw-r--r-- tensorflow/cc/training/queue_runner.h 25
-rw-r--r-- tensorflow/cc/training/queue_runner_test.cc 84
6 files changed, 481 insertions, 87 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc
new file mode 100644
index 0000000000..254538d778
--- /dev/null
+++ b/tensorflow/cc/training/coordinator.cc
@@ -0,0 +1,90 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/training/coordinator.h"
+
+namespace tensorflow {
+
+// Default constructor: delegates with an empty list, which selects the default
+// clean-stop code (OUT_OF_RANGE).
+Coordinator::Coordinator() : Coordinator(std::vector<error::Code>()) {}
+
+// Builds a coordinator whose clean-stop set is `clean_stop_errors`. Statuses
+// carrying one of these codes are ignored by ReportStatus(). An empty list
+// falls back to {OUT_OF_RANGE}.
+Coordinator::Coordinator(const std::vector<error::Code>& clean_stop_errors)
+    : should_stop_(false) {
+  if (clean_stop_errors.empty()) {
+    clean_stop_errors_.insert(error::OUT_OF_RANGE);
+    return;
+  }
+  for (const error::Code& clean_code : clean_stop_errors) {
+    clean_stop_errors_.insert(static_cast<int>(clean_code));
+  }
+}
+
+Coordinator::~Coordinator() {
+ RequestStop();  // Status discarded; it is FAILED_PRECONDITION when a stop was already requested.
+ Join();  // Joins every registered runner; the aggregated Status is discarded.
+}
+
+Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
+ runners_.push_back(std::move(runner));  // Takes ownership; note that no lock is held around runners_.
+ return Status::OK();  // Registration itself cannot fail.
+}
+
+// Sets the stop flag and wakes every thread blocked in WaitForStop().
+// Returns FAILED_PRECONDITION if a stop had already been requested.
+Status Coordinator::RequestStop() {
+  mutex_lock lock(mu_);
+  if (!should_stop_) {
+    should_stop_ = true;
+    wait_for_stop_.notify_all();
+    return Status::OK();
+  }
+  return Status(error::FAILED_PRECONDITION, "The Coordinator is not running.");
+}
+
+bool Coordinator::ShouldStop() {  // True once RequestStop() has set the stop flag.
+ mutex_lock l(mu_);  // should_stop_ is GUARDED_BY(mu_).
+ return should_stop_;
+}
+
+// Joins every registered runner, feeding each runner's Join() result into
+// ReportStatus(), then releases the runners.
+// Returns OK, or the first reported status whose code is not a clean-stop code.
+Status Coordinator::Join() {
+  // TODO(yuefengz): deal with unexpected calls to Join().
+  // TODO(yuefengz): deal with stragglers.
+  for (const auto& t : runners_) {
+    ReportStatus(t->Join());
+  }
+  runners_.clear();
+  // Go through GetStatus() so status_ is copied while holding status_lock_;
+  // ReportStatus() is public and may be running on another thread.
+  return GetStatus();
+}
+
+void Coordinator::ReportStatus(const Status& status) {
+ mutex_lock l(status_lock_);
+ if (status.ok() || !status_.ok() ||
+ clean_stop_errors_.count(static_cast<int>(status.code())) > 0) {
+ return;
+ }
+ status_ = status;
+}
+
+Status Coordinator::GetStatus() {  // Returns a copy of status_ taken under status_lock_.
+ mutex_lock l(status_lock_);
+ return status_;
+}
+
+void Coordinator::WaitForStop() {  // Blocks the caller until the stop flag is set.
+ mutex_lock l(mu_);
+ while (!should_stop_) {  // Re-check the flag after every wakeup.
+ wait_for_stop_.wait(l);  // Woken by notify_all() in RequestStop().
+ }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h
new file mode 100644
index 0000000000..987d243fbd
--- /dev/null
+++ b/tensorflow/cc/training/coordinator.h
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
+
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Abstract interface for runners that can be registered with a Coordinator; implementations must provide Join().
+class RunnerInterface {
+ public:
+ virtual ~RunnerInterface() {}  // Virtual: the Coordinator owns runners through unique_ptr<RunnerInterface>.
+ virtual Status Join() = 0;  // Called from Coordinator::Join(); the result is passed to ReportStatus().
+};
+
+// Coordinator class manages the termination of a collection of QueueRunners.
+// Without a coordinator, QueueRunners have to be joined in a specific order;
+// otherwise the QueueRunner::Join() could sometimes hang. The
+// Coordinator::RequestStop() plays the key role which notifies all running
+// threads under a coordinator to stop. This function could be called by any
+// thread or any client.
+// Usage, in the client:
+// Coordinator coord;
+// std::unique_ptr<QueueRunner> qr;
+// TF_CHECK_OK(QueueRunner::New(queue_runner_def, &coord, &qr));
+// TF_CHECK_OK(qr->Start(session));
+// coord.RegisterRunner(std::move(qr));
+// // do some work
+// TF_CHECK_OK(coord.Join());
+// In each thread of QueueRunner, the coordinator needs to be used as:
+// void Run() {
+// while (!coord->ShouldStop()) {
+// // do some work
+// if (error) {
+// coord->RequestStop();
+// coord->ReportStatus(error_status);
+// }
+// }
+// }
+class Coordinator {
+ public:
+ Coordinator();  // Same as passing an empty list: OUT_OF_RANGE is the only clean-stop code.
+
+ // Constructor taking the error codes that count as clean stops and are not
+ // recorded by ReportStatus(). An empty list defaults to OUT_OF_RANGE.
+ Coordinator(const std::vector<error::Code>& clean_stop_errors);
+
+ // The destructor calls RequestStop() and then Join(), discarding both results.
+ ~Coordinator();
+
+ // Registers a runner, i.e. a unit of running threads which is usually a
+ // QueueRunner. It takes the ownership of runner to avoid lifecycle-related
+ // problems. Note, the coordinator would not start these threads; they are
+ // supposed to be in running state when they are registered here.
+ Status RegisterRunner(std::unique_ptr<RunnerInterface> runner);
+
+ // Requests all running threads to stop; FAILED_PRECONDITION if already requested.
+ Status RequestStop();
+
+ // Returns true once RequestStop() has been called.
+ bool ShouldStop();
+
+ // Joins all runners; returns OK or the first reported unexpected status.
+ Status Join();
+
+ // Reports status to the coordinator; only the first unexpected error is kept.
+ void ReportStatus(const Status& status);
+
+ // Returns OK or the first unexpected error recorded by ReportStatus().
+ Status GetStatus();
+
+ // Returns immediately if the coordinator is stopped or blocks until
+ // RequestStop() is called.
+ void WaitForStop();
+
+ private:
+ std::vector<std::unique_ptr<RunnerInterface>> runners_;  // Owned runners, joined by Join().
+ std::unordered_set<int> clean_stop_errors_;  // Error codes (as int) that ReportStatus() ignores.
+ mutex mu_;
+ bool should_stop_ GUARDED_BY(mu_);
+ mutex status_lock_;  // Held by ReportStatus() and GetStatus() around status_.
+ Status status_;
+ condition_variable wait_for_stop_;  // Notified by RequestStop(); waited on under mu_.
+ TF_DISALLOW_COPY_AND_ASSIGN(Coordinator);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc
new file mode 100644
index 0000000000..3bdce5f07f
--- /dev/null
+++ b/tensorflow/cc/training/coordinator_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/training/coordinator.h"
+
+#include "tensorflow/cc/training/queue_runner.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace {
+
+using error::Code;
+
+void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) {
+ coord->WaitForStop();  // Blocks until RequestStop() is called on coord.
+ *stopped = true;  // Records that the wait returned.
+ done->Notify();  // Signals the test thread that *stopped has been written.
+}
+
+TEST(CoordinatorTest, TestStopAndWaitOnStop) {  // WaitForStop() must block until RequestStop().
+ Coordinator coord;
+ EXPECT_EQ(coord.ShouldStop(), false);
+
+ bool stopped = false;
+ Notification done;
+ Env::Default()->SchedClosure(
+ std::bind(&WaitForStopThread, &coord, &stopped, &done));
+ Env::Default()->SleepForMicroseconds(10000000);  // 10 s during which the waiter must stay blocked.
+ EXPECT_EQ(stopped, false);
+
+ coord.RequestStop();
+ done.WaitForNotification();  // The waiter has written `stopped` by the time this returns.
+ EXPECT_EQ(stopped, true);
+ EXPECT_EQ(coord.ShouldStop(), true);
+}
+
+// Test double for a runner managed by a Coordinator. Work is scheduled on a
+// private thread pool; Join() destroys that pool and returns the status set via
+// SetStatus().
+// NOTE(review): Join() presumably relies on the ThreadPool destructor waiting
+// for scheduled closures -- confirm against thread::ThreadPool.
+class MockQueueRunner : public RunnerInterface {
+ public:
+  explicit MockQueueRunner(Coordinator* coord) {
+    coord_ = coord;
+    join_counter_ = nullptr;
+    thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10));
+  }
+
+  // As above, and additionally increments *join_counter on every Join().
+  MockQueueRunner(Coordinator* coord, int* join_counter)
+      : MockQueueRunner(coord) {
+    join_counter_ = join_counter;
+  }
+
+  // Schedules a closure that bumps *counter until it reaches `until` or the
+  // coordinator is asked to stop, and then requests a stop itself.
+  void StartCounting(std::atomic<int>* counter, int until) {
+    thread_pool_->Schedule(
+        std::bind(&MockQueueRunner::CountThread, this, counter, until));
+  }
+
+  // Schedules a closure that stores `status` in this runner and then
+  // decrements `counter`.
+  void StartSettingStatus(const Status& status, BlockingCounter* counter) {
+    thread_pool_->Schedule(
+        std::bind(&MockQueueRunner::SetStatusThread, this, status, counter));
+  }
+
+  // Counts the call (when a join counter was supplied), destroys the thread
+  // pool and returns the stored status.
+  Status Join() override {
+    if (join_counter_ != nullptr) {
+      (*join_counter_)++;
+    }
+    thread_pool_.reset();
+    return status_;
+  }
+
+  Status GetStatus() { return status_; }
+
+  void SetStatus(const Status& status) { status_ = status; }
+
+ private:
+  void CountThread(std::atomic<int>* counter, int until) {
+    while (!coord_->ShouldStop()) {
+      // Claim the next count with a compare-exchange so that several runners
+      // sharing `counter` can never push it past `until`; a separate load
+      // followed by an increment can overshoot.
+      int seen = counter->load();
+      while (seen < until &&
+             !counter->compare_exchange_weak(seen, seen + 1)) {
+      }
+      if (seen >= until) break;
+      Env::Default()->SleepForMicroseconds(100000);
+    }
+    coord_->RequestStop();
+  }
+  void SetStatusThread(const Status& status, BlockingCounter* counter) {
+    Env::Default()->SleepForMicroseconds(100000);
+    SetStatus(status);
+    counter->DecrementCount();
+  }
+  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  Status status_;
+  Coordinator* coord_;
+  int* join_counter_;
+};
+
+TEST(CoordinatorTest, TestRealStop) {  // After RequestStop() the runners must stop counting.
+ std::atomic<int> counter(0);
+ Coordinator coord;
+
+ std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
+ qr1->StartCounting(&counter, 100);
+ coord.RegisterRunner(std::move(qr1));
+
+ std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
+ qr2->StartCounting(&counter, 100);
+ coord.RegisterRunner(std::move(qr2));
+
+ // Busy-wait until at least one runner has started counting.
+ while (counter.load() == 0)
+ ;
+ coord.RequestStop();
+
+ int temp_counter = counter.load();
+ Env::Default()->SleepForMicroseconds(10000000);  // 10 s in which the counter must not move.
+ EXPECT_EQ(temp_counter, counter.load());
+ TF_EXPECT_OK(coord.Join());
+}
+
+// Ten runners share one counter capped at 10. The first runner to see the cap
+// requests a stop, which releases WaitForStop() below.
+TEST(CoordinatorTest, TestRequestStop) {
+  Coordinator coord;
+  std::atomic<int> counter(0);
+  for (int i = 0; i < 10; i++) {
+    std::unique_ptr<MockQueueRunner> runner(new MockQueueRunner(&coord));
+    runner->StartCounting(&counter, 10);
+    coord.RegisterRunner(std::move(runner));
+  }
+
+  coord.WaitForStop();
+  EXPECT_EQ(coord.ShouldStop(), true);
+  EXPECT_EQ(counter.load(), 10);
+  TF_EXPECT_OK(coord.Join());
+}
+
+TEST(CoordinatorTest, TestJoin) {  // Coordinator::Join() must join each registered runner once.
+ Coordinator coord;
+ int join_counter = 0;
+ std::unique_ptr<MockQueueRunner> qr1(
+ new MockQueueRunner(&coord, &join_counter));
+ coord.RegisterRunner(std::move(qr1));
+ std::unique_ptr<MockQueueRunner> qr2(
+ new MockQueueRunner(&coord, &join_counter));
+ coord.RegisterRunner(std::move(qr2));
+
+ TF_EXPECT_OK(coord.Join());
+ EXPECT_EQ(join_counter, 2);  // One increment per runner's Join().
+}
+
+TEST(CoordinatorTest, StatusReporting) {  // Only the non-clean-stop error must surface from Join().
+ Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});  // Both codes count as clean stops.
+ BlockingCounter counter(3);  // One count per runner below.
+
+ std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
+ qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter);
+ coord.RegisterRunner(std::move(qr1));
+
+ std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
+ qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter);
+ coord.RegisterRunner(std::move(qr2));
+
+ std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord));
+ qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter);
+ coord.RegisterRunner(std::move(qr3));
+
+ counter.Wait();  // All three runners have stored their status.
+ EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 6435e88d48..bc48a41ff5 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -25,6 +25,14 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
return (*result)->Init(queue_runner_def);
}
+// Creates a QueueRunner attached to `coord` and initializes it from
+// `queue_runner_def`. `*result` receives the new runner even if Init() fails;
+// the returned Status is Init()'s result.
+Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
+                        Coordinator* coord,
+                        std::unique_ptr<QueueRunner>* result) {
+  std::unique_ptr<QueueRunner>& runner = *result;
+  runner.reset(new QueueRunner());
+  runner->coord_ = coord;
+  return runner->Init(queue_runner_def);
+}
+
Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
queue_name_ = queue_runner_def.queue_name();
enqueue_op_names_.clear();
@@ -46,8 +54,7 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
}
thread_pool_.reset(new thread::ThreadPool(
- Env::Default(), SanitizeThreadSuffix(queue_name_), runs_));
- should_stop_ = false;
+ Env::Default(), SanitizeThreadSuffix(queue_name_), runs_ + 1));
return Status::OK();
}
@@ -66,6 +73,9 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
thread_pool_->Schedule(
std::bind(&QueueRunner::Run, this, sess, enqueue_op));
}
+ if (coord_) {
+ thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
+ }
// Wait for up to 'wait_for' milliseconds.
if (wait_for > 0) {
if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
@@ -84,12 +94,13 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
return Status::OK();
}
-Status QueueRunner::Stop(Session* sess) {
- should_stop_ = true;
+void QueueRunner::Stop(Session* sess) {  // Scheduled on the pool by Start() when a coordinator is set.
if (cancel_op_name_.empty()) {
- return Status::OK();
+ return;  // No cancel op configured, so there is nothing to run.
} else {
- return sess->Run({}, {}, {cancel_op_name_}, nullptr);
+ CHECK(coord_ != nullptr);  // This path needs a coordinator to wait on.
+ coord_->WaitForStop();  // Blocks until the coordinator is asked to stop.
+ UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr));  // Run the cancel op and record its status.
}
}
@@ -99,10 +110,28 @@ Status QueueRunner::Join() {
return status_;
}
+void QueueRunner::UpdateStatus(const Status& status) {
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok() || status.ok() ||
+ queue_closed_exception_types_.count(static_cast<int>(status.code())) >
+ 0) {
+ return;
+ }
+ status_ = status;
+ }
+ if (coord_) {
+ coord_->ReportStatus(status);
+ }
+}
+
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
bool decremented = false;
bool first_iteration = true;
- while (!should_stop_.load()) {
+ while (true) {
+ if (coord_ && coord_->ShouldStop()) {
+ break;
+ }
auto status = sess->Run({}, {}, {enqueue_op}, nullptr);
if (first_iteration) {
if (!status.ok()) {
@@ -116,32 +145,26 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) {
continue;
} else if (queue_closed_exception_types_.count(
static_cast<int>(status.code())) > 0) {
- mutex_lock l(mu_);
- runs_--;
- decremented = true;
- should_stop_ = true;
+ {
+ mutex_lock l(mu_);
+ runs_--;
+ decremented = true;
+ }
// If all enqueue ops have finished, run the close op.
- if (runs_ == 0 && !close_op_name_.empty()) {
- auto s = sess->Run({}, {}, {close_op_name_}, nullptr);
- if (!s.ok() && status_.ok() &&
- queue_closed_exception_types_.count(static_cast<int>(s.code())) ==
- 0) {
- status_ = s;
+      // NOTE(review): runs_ is decremented under mu_ just above but is read
+      // here without the lock. Also, when runs_ != 0 this branch does not
+      // break, so the thread loops and runs the enqueue op again -- confirm
+      // both are intended.
+      if (runs_ == 0) {
+        if (!close_op_name_.empty()) {
+          // Record the close op's own result. The enqueue `status` carries a
+          // queue-closed code here, which UpdateStatus() always ignores, so
+          // passing it would silently drop a close-op failure.
+          UpdateStatus(sess->Run({}, {}, {close_op_name_}, nullptr));
+        }
+        break;
+      }
} else {
- {
- mutex_lock l(mu_);
- should_stop_ = true;
- // Only record the first failure status.
- if (status_.ok()) {
- status_ = status;
- }
+ UpdateStatus(status);
+ if (coord_) {
+ coord_->RequestStop();
}
- // Stop the queue runner immediately to propagate the error to
- // subsequent queues.
- Stop(sess);
+ break;
}
first_iteration = false;
}
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 3affe45e55..01dd745951 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -33,7 +34,7 @@ namespace tensorflow {
// QueueRunner class imitates the behavior of the python version of QueueRunner
// which creates a thread for each enqueue op, runs close op on completion.
-class QueueRunner {
+class QueueRunner : public RunnerInterface {
public:
// Creates a new QueueRunner from proto.
// TODO(yuefengz): we may want to initialize from queues and ops in the
@@ -41,6 +42,10 @@ class QueueRunner {
static Status New(const QueueRunnerDef& queue_runner_def,
std::unique_ptr<QueueRunner>* result);
+ // Creates a new QueueRunner with a coordinator, see coordinator.h for usage.
+ static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
+ std::unique_ptr<QueueRunner>* result);
+
// The destructor would join all the threads.
~QueueRunner();
@@ -51,18 +56,15 @@ class QueueRunner {
// specified time (in milliseconds) for the queues to start to fill up.
Status Start(Session* sess, int wait_for);
- // Requests to stop and runs the cancel op.
- Status Stop(Session* sess);
-
// Joins all the threads. Returns okay if all threads run successfully;
// otherwise returns the first captured failure status.
- Status Join();
+ Status Join() final;
// Returns the latest status.
Status GetStatus();
private:
- QueueRunner() {}
+ QueueRunner() : coord_(nullptr) {}
// Initializes the instance with the QueueRunnerDef proto.
Status Init(const QueueRunnerDef& queue_runner_def);
@@ -70,6 +72,14 @@ class QueueRunner {
// The Run function for each thread.
void Run(Session* sess, const string& enqueue_op);
+ // Requests to stop and runs the cancel op. It would be called in a separate
+ // thread when coordinator is set.
+ void Stop(Session* sess);
+
+ // Updates the internal status; it only keeps OK or the first unexpected error
+ // status.
+ void UpdateStatus(const Status& status);
+
string queue_name_;
std::vector<string> enqueue_op_names_;
string close_op_name_;
@@ -78,7 +88,6 @@ class QueueRunner {
std::unordered_set<int> queue_closed_exception_types_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::atomic<bool> should_stop_;
condition_variable wait_to_close_;
mutex mu_;
// TODO(yuefengz): implement c++ coordinator.
@@ -86,6 +95,8 @@ class QueueRunner {
Status status_ GUARDED_BY(mu_);
Status enqueue_status_ GUARDED_BY(mu_);
std::unique_ptr<BlockingCounter> counter_;
+
+ Coordinator* coord_;
};
} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index df48b25543..73ea5a307f 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -111,7 +112,7 @@ TEST(QueueRunnerTest, BasicTest) {
auto session = BuildSessionAndInitVariable(graph_def);
QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {});
+ kQueueName, {kCountUpToOpName}, kSquareOpName, "", {});
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
@@ -164,7 +165,8 @@ GraphDef BuildDoubleQueueGraph() {
auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0);
auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0,
QueueClose::CancelPendingEnqueues(true));
- auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32});
+ auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32},
+ FIFOQueue::Capacity(3));
auto dequeue0 =
QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32});
auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]});
@@ -252,34 +254,34 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) {
EXPECT_EQ(join_succeeded, true);
}
-TEST(QueueRunnerTest, Stop) {
- auto graph_def = BuildDoubleQueueGraph();
+TEST(QueueRunnerTest, EmptyEnqueueOps) {  // New() must reject a def with no enqueue ops.
+ QueueRunnerDef queue_runner_def =
+ BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});  // The empty second argument is presumably the enqueue-op list.
+ std::unique_ptr<QueueRunner> qr;
+ EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
+ Code::INVALID_ARGUMENT);
+}
+
+TEST(QueueRunnerTest, StartTimeout) {  // Start() with a wait must report DEADLINE_EXCEEDED when nothing is enqueued.
+ GraphDef graph_def = BuildDoubleQueueGraph();
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
- QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
- {Code::OUT_OF_RANGE, Code::CANCELLED});
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
+
std::unique_ptr<QueueRunner> qr;
TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
- TF_CHECK_OK(qr->Start(session.get()));
-
- TF_EXPECT_OK(qr->Stop(session.get()));
-
- TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr));
-
- EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(),
- Code::OUT_OF_RANGE);
-
- // qr is already stopped
- TF_EXPECT_OK(qr->Join());
+ // This will time out since queue0 is not fed and queue1 is fetching data from
+ // queue0.
+ EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);  // wait_for is 1 millisecond.
+ session->Close();
}
-TEST(QueueRunnerTest, StopTwoQueues) {
+TEST(QueueRunnerTest, TestCoordinatorStop) {
auto graph_def = BuildDoubleQueueGraph();
-
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));
@@ -290,48 +292,24 @@ TEST(QueueRunnerTest, StopTwoQueues) {
QueueRunnerDef queue_runner1 =
BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1,
{Code::OUT_OF_RANGE, Code::CANCELLED});
+
+ Coordinator coord;
std::unique_ptr<QueueRunner> qr0;
- TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0));
+ TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0));
TF_CHECK_OK(qr0->Start(session.get()));
std::unique_ptr<QueueRunner> qr1;
- TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1));
+ TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1));
TF_CHECK_OK(qr1->Start(session.get()));
+ coord.RegisterRunner(std::move(qr0));
+ coord.RegisterRunner(std::move(qr1));
+
std::vector<Tensor> dq;
TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq));
EXPECT_EQ(*dq[0].scalar<int>().data(), 10);
- TF_EXPECT_OK(qr0->Stop(session.get()));
- TF_EXPECT_OK(qr1->Stop(session.get()));
-
- TF_EXPECT_OK(qr0->Join());
- TF_EXPECT_OK(qr1->Join());
-}
-
-TEST(QueueRunnerTest, EmptyEnqueueOps) {
- QueueRunnerDef queue_runner_def =
- BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {});
-
- std::unique_ptr<QueueRunner> qr;
- EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(),
- Code::INVALID_ARGUMENT);
-}
-
-TEST(QueueRunnerTest, StartTimeout) {
- GraphDef graph_def = BuildDoubleQueueGraph();
- SessionOptions options;
- std::unique_ptr<Session> session(NewSession(options));
- TF_CHECK_OK(session->Create(graph_def));
-
- QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
- kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {});
-
- std::unique_ptr<QueueRunner> qr;
- TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
- // This will timeout since queue0 is not fed and queue1 is fetching data from
- // queue0.
- EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED);
- session->Close();
+ TF_EXPECT_OK(coord.RequestStop());
+ TF_EXPECT_OK(coord.Join());
}
} // namespace