diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2016-11-01 12:49:38 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-01 14:17:00 -0700 |
commit | 99f55f806f426a50c01dd06bd71a478009a84af2 (patch) | |
tree | cb6ecc95f412765fa9be10fcd5dbe7bb8ff82dd3 /tensorflow/cc/training | |
parent | 4863a6074f19e9546e195ab495061a6df7b18ce2 (diff) |
Add C++ Coordinator
Change: 137866409
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/coordinator.cc | 90 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator.h | 109 | ||||
-rw-r--r-- | tensorflow/cc/training/coordinator_test.cc | 183 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 77 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 25 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 84 |
6 files changed, 481 insertions, 87 deletions
diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc new file mode 100644 index 0000000000..254538d778 --- /dev/null +++ b/tensorflow/cc/training/coordinator.cc @@ -0,0 +1,90 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/training/coordinator.h" + +namespace tensorflow { + +Coordinator::Coordinator() : Coordinator(std::vector<error::Code>()) {} + +Coordinator::Coordinator(const std::vector<error::Code>& clean_stop_errors) + : should_stop_(false) { + if (clean_stop_errors.empty()) { + clean_stop_errors_.insert(error::OUT_OF_RANGE); + } else { + for (const auto& code : clean_stop_errors) { + clean_stop_errors_.insert(static_cast<int>(code)); + } + } +} + +Coordinator::~Coordinator() { + RequestStop(); + Join(); +} + +Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) { + runners_.push_back(std::move(runner)); + return Status::OK(); +} + +Status Coordinator::RequestStop() { + mutex_lock l(mu_); + if (should_stop_) { + return Status(error::FAILED_PRECONDITION, + "The Coordinator is not running."); + } + should_stop_ = true; + wait_for_stop_.notify_all(); + return Status::OK(); +} + +bool Coordinator::ShouldStop() { + mutex_lock l(mu_); + return should_stop_; +} + +Status Coordinator::Join() { + // TODO(yuefengz): deal with unexpected calls to Join(). + // TODO(yuefengz): deal with stragglers. + for (const auto& t : runners_) { + ReportStatus(t->Join()); + } + runners_.clear(); + return status_; +} + +void Coordinator::ReportStatus(const Status& status) { + mutex_lock l(status_lock_); + if (status.ok() || !status_.ok() || + clean_stop_errors_.count(static_cast<int>(status.code())) > 0) { + return; + } + status_ = status; +} + +Status Coordinator::GetStatus() { + mutex_lock l(status_lock_); + return status_; +} + +void Coordinator::WaitForStop() { + mutex_lock l(mu_); + while (!should_stop_) { + wait_for_stop_.wait(l); + } +} + +} // namespace diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h new file mode 100644 index 0000000000..987d243fbd --- /dev/null +++ b/tensorflow/cc/training/coordinator.h @@ -0,0 +1,109 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ + +#include <memory> +#include <unordered_set> +#include <vector> + +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// The abstract interface for runners which must implement the Join function. +class RunnerInterface { + public: + virtual ~RunnerInterface() {} + virtual Status Join() = 0; +}; + +// Coordinator class manages the termination of a collection of QueueRunners. +// Without a coordinator, QueueRunners have to be joined in a specific order; +// otherwise the QueueRunner::Join() could sometimes hang. The +// Coordinator::RequestStop() plays the key role which notifies all running +// threads under a coordinator to stop. This function could be called by any +// thread or any client. +// Usage, in the client: +// Coordinator coord; +// std::unique_ptr<QueueRunner> qr(&coord, ...); +// qr.Start(session); +// coord.RegisterRunner(std::move(qr)); +// // do some work +// TF_CHECK_OK(coord.Join()); +// In each thread of QueueRunner, the coordinator needs to be used as: +// void Run() { +// while (!coord->ShouldStop()) { +// // do some work +// if (error) { +// coord->RequestStop(); +// coord->ReportStatus(error_status); +// } +// } +// } +class Coordinator { + public: + Coordinator(); + + // Constructor with a list of error codes which would not be taken as errors + // in status reporting. + Coordinator(const std::vector<error::Code>& clean_stop_errors); + + // In the destructor, RequestStop() and Join() would be called. + ~Coordinator(); + + // Registers a runner, i.e. a unit of running threads which is usually a + // QueueRunner. It takes the ownership of runner to avoid lifecycle-related + // problems. Note, the coordinator would not start these threads; they are + // supposed to be in running state when they are registered here. + Status RegisterRunner(std::unique_ptr<RunnerInterface> runner); + + // Requests all running threads to stop. + Status RequestStop(); + + // Returns true if its RequestStop() has been called. + bool ShouldStop(); + + // Joins all threads, returns OK or the first reported and unexpected status. + Status Join(); + + // Reports status to the coordinator. This is usually called by threads. + void ReportStatus(const Status& status); + + // Returns the latest status. + Status GetStatus(); + + // Returns immediately if the coordinator is stopped or blocks until + // RequestStop() is called. + void WaitForStop(); + + private: + std::vector<std::unique_ptr<RunnerInterface>> runners_; + std::unordered_set<int> clean_stop_errors_; + mutex mu_; + bool should_stop_ GUARDED_BY(mu_); + mutex status_lock_; + Status status_; + condition_variable wait_for_stop_; + TF_DISALLOW_COPY_AND_ASSIGN(Coordinator); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc new file mode 100644 index 0000000000..3bdce5f07f --- /dev/null +++ b/tensorflow/cc/training/coordinator_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/training/coordinator.h" + +#include "tensorflow/cc/training/queue_runner.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { + +using error::Code; + +void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) { + coord->WaitForStop(); + *stopped = true; + done->Notify(); +} + +TEST(CoordinatorTest, TestStopAndWaitOnStop) { + Coordinator coord; + EXPECT_EQ(coord.ShouldStop(), false); + + bool stopped = false; + Notification done; + Env::Default()->SchedClosure( + std::bind(&WaitForStopThread, &coord, &stopped, &done)); + Env::Default()->SleepForMicroseconds(10000000); + EXPECT_EQ(stopped, false); + + coord.RequestStop(); + done.WaitForNotification(); + EXPECT_EQ(stopped, true); + EXPECT_EQ(coord.ShouldStop(), true); +} + +class MockQueueRunner : public RunnerInterface { + public: + MockQueueRunner(Coordinator* coord) { + coord_ = coord; + join_counter_ = nullptr; + thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10)); + } + + MockQueueRunner(Coordinator* coord, int* join_counter) + : MockQueueRunner(coord) { + join_counter_ = join_counter; + } + + void StartCounting(std::atomic<int>* counter, int until) { + thread_pool_->Schedule( + std::bind(&MockQueueRunner::CountThread, this, counter, until)); + } + + void StartSettingStatus(const Status& status, BlockingCounter* counter) { + thread_pool_->Schedule( + std::bind(&MockQueueRunner::SetStatusThread, this, status, counter)); + } + + Status Join() { + if (join_counter_ != nullptr) { + (*join_counter_)++; + } + thread_pool_.reset(); + return status_; + } + + Status GetStatus() { return status_; } + + void SetStatus(const Status& status) { status_ = status; } + + private: + void CountThread(std::atomic<int>* counter, int until) { + while (!coord_->ShouldStop() && counter->load() < until) { + (*counter)++; + Env::Default()->SleepForMicroseconds(100000); + } + coord_->RequestStop(); + } + void SetStatusThread(const Status& status, BlockingCounter* counter) { + Env::Default()->SleepForMicroseconds(100000); + SetStatus(status); + counter->DecrementCount(); + } + std::unique_ptr<thread::ThreadPool> thread_pool_; + Status status_; + Coordinator* coord_; + int* join_counter_; +}; + +TEST(CoordinatorTest, TestRealStop) { + std::atomic<int> counter(0); + Coordinator coord; + + std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord)); + qr1->StartCounting(&counter, 100); + coord.RegisterRunner(std::move(qr1)); + + std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord)); + qr2->StartCounting(&counter, 100); + coord.RegisterRunner(std::move(qr2)); + + // Wait until the counting has started + while (counter.load() == 0) + ; + coord.RequestStop(); + + int temp_counter = counter.load(); + Env::Default()->SleepForMicroseconds(10000000); + EXPECT_EQ(temp_counter, counter.load()); + TF_EXPECT_OK(coord.Join()); +} + +TEST(CoordinatorTest, TestRequestStop) { + Coordinator coord; + std::atomic<int> counter(0); + std::unique_ptr<MockQueueRunner> qr; + for (int i = 0; i < 10; i++) { + qr.reset(new MockQueueRunner(&coord)); + qr->StartCounting(&counter, 10); + coord.RegisterRunner(std::move(qr)); + } + + coord.WaitForStop(); + EXPECT_EQ(coord.ShouldStop(), true); + EXPECT_EQ(counter.load(), 10); + TF_EXPECT_OK(coord.Join()); +} + +TEST(CoordinatorTest, TestJoin) { + Coordinator coord; + int join_counter = 0; + std::unique_ptr<MockQueueRunner> qr1( + new MockQueueRunner(&coord, &join_counter)); + coord.RegisterRunner(std::move(qr1)); + std::unique_ptr<MockQueueRunner> qr2( + new MockQueueRunner(&coord, &join_counter)); + coord.RegisterRunner(std::move(qr2)); + + TF_EXPECT_OK(coord.Join()); + EXPECT_EQ(join_counter, 2); +} + +TEST(CoordinatorTest, StatusReporting) { + Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE}); + BlockingCounter counter(3); + + std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord)); + qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter); + coord.RegisterRunner(std::move(qr1)); + + std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord)); + qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter); + coord.RegisterRunner(std::move(qr2)); + + std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord)); + qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter); + coord.RegisterRunner(std::move(qr3)); + + counter.Wait(); + EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 6435e88d48..bc48a41ff5 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -25,6 +25,14 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, return (*result)->Init(queue_runner_def); } +Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, + Coordinator* coord, + std::unique_ptr<QueueRunner>* result) { + result->reset(new QueueRunner()); + (*result)->coord_ = coord; + return (*result)->Init(queue_runner_def); +} + Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { queue_name_ = queue_runner_def.queue_name(); enqueue_op_names_.clear(); @@ -46,8 +54,7 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { } thread_pool_.reset(new thread::ThreadPool( - Env::Default(), SanitizeThreadSuffix(queue_name_), runs_)); - should_stop_ = false; + Env::Default(), SanitizeThreadSuffix(queue_name_), runs_ + 1)); return Status::OK(); } @@ -66,6 +73,9 @@ Status QueueRunner::Start(Session* sess, int wait_for) { thread_pool_->Schedule( std::bind(&QueueRunner::Run, this, sess, enqueue_op)); } + if (coord_) { + thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess)); + } // Wait for up to 'wait_for' milliseconds. if (wait_for > 0) { if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) { @@ -84,12 +94,13 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return Status::OK(); } -Status QueueRunner::Stop(Session* sess) { - should_stop_ = true; +void QueueRunner::Stop(Session* sess) { if (cancel_op_name_.empty()) { - return Status::OK(); + return; } else { - return sess->Run({}, {}, {cancel_op_name_}, nullptr); + CHECK(coord_ != nullptr); + coord_->WaitForStop(); + UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr)); } } @@ -99,10 +110,28 @@ Status QueueRunner::Join() { return status_; } +void QueueRunner::UpdateStatus(const Status& status) { + { + mutex_lock l(mu_); + if (!status_.ok() || status.ok() || + queue_closed_exception_types_.count(static_cast<int>(status.code())) > + 0) { + return; + } + status_ = status; + } + if (coord_) { + coord_->ReportStatus(status); + } +} + void QueueRunner::Run(Session* sess, const string& enqueue_op) { bool decremented = false; bool first_iteration = true; - while (!should_stop_.load()) { + while (true) { + if (coord_ && coord_->ShouldStop()) { + break; + } auto status = sess->Run({}, {}, {enqueue_op}, nullptr); if (first_iteration) { if (!status.ok()) { @@ -116,32 +145,26 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { continue; } else if (queue_closed_exception_types_.count( static_cast<int>(status.code())) > 0) { - mutex_lock l(mu_); - runs_--; - decremented = true; - should_stop_ = true; + { + mutex_lock l(mu_); + runs_--; + decremented = true; + } // If all enqueue ops have finished, run the close op. - if (runs_ == 0 && !close_op_name_.empty()) { - auto s = sess->Run({}, {}, {close_op_name_}, nullptr); - if (!s.ok() && status_.ok() && - queue_closed_exception_types_.count(static_cast<int>(s.code())) == - 0) { - status_ = s; + if (runs_ == 0) { + if (!close_op_name_.empty()) { + auto s = sess->Run({}, {}, {close_op_name_}, nullptr); + UpdateStatus(status); } + break; } } else { - { - mutex_lock l(mu_); - should_stop_ = true; - // Only record the first failure status. - if (status_.ok()) { - status_ = status; - } + UpdateStatus(status); + if (coord_) { + coord_->RequestStop(); } - // Stop the queue runner immediately to propagate the error to - // subsequent queues. - Stop(sess); + break; } first_iteration = false; } diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 3affe45e55..01dd745951 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -21,6 +21,7 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "tensorflow/cc/training/coordinator.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -33,7 +34,7 @@ namespace tensorflow { // QueueRunner class imitates the behavior of the python version of QueueRunner // which creates a thread for each enqueue op, runs close op on completion. -class QueueRunner { +class QueueRunner : public RunnerInterface { public: // Creates a new QueueRunner from proto. // TODO(yuefengz): we may want to initialize from queues and ops in the @@ -41,6 +42,10 @@ class QueueRunner { static Status New(const QueueRunnerDef& queue_runner_def, std::unique_ptr<QueueRunner>* result); + // Creates a new QueueRunner with a coordinator, see coordinator.h for usage. + static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord, + std::unique_ptr<QueueRunner>* result); + // The destructor would join all the threads. ~QueueRunner(); @@ -51,18 +56,15 @@ class QueueRunner { // specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for); - // Requests to stop and runs the cancel op. - Status Stop(Session* sess); - // Joins all the threads. Returns okay if all threads run successfully; // otherwise returns the first captured failure status. - Status Join(); + Status Join() final; // Returns the lastest status. Status GetStatus(); private: - QueueRunner() {} + QueueRunner() : coord_(nullptr) {} // Initializes the instance with the QueueRunnerDef proto. Status Init(const QueueRunnerDef& queue_runner_def); @@ -70,6 +72,14 @@ class QueueRunner { // The Run function for each thread. void Run(Session* sess, const string& enqueue_op); + // Requests to stop and runs the cancel op. It would be called in a separate + // thread when coordinator is set. + void Stop(Session* sess); + + // Updates the internal status; it only keeps OK or the first unexpected error + // status. + void UpdateStatus(const Status& status); + string queue_name_; std::vector<string> enqueue_op_names_; string close_op_name_; @@ -78,7 +88,6 @@ class QueueRunner { std::unordered_set<int> queue_closed_exception_types_; std::unique_ptr<thread::ThreadPool> thread_pool_; - std::atomic<bool> should_stop_; condition_variable wait_to_close_; mutex mu_; // TODO(yuefengz): implement c++ coordinator. @@ -86,6 +95,8 @@ class QueueRunner { Status status_ GUARDED_BY(mu_); Status enqueue_status_ GUARDED_BY(mu_); std::unique_ptr<BlockingCounter> counter_; + + Coordinator* coord_; }; } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index df48b25543..73ea5a307f 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/training/coordinator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -111,7 +112,7 @@ TEST(QueueRunnerTest, BasicTest) { auto session = BuildSessionAndInitVariable(graph_def); QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kCountUpToOpName, kCountUpToOpName}, kSquareOpName, "", {}); + kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); std::unique_ptr<QueueRunner> qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); @@ -164,7 +165,8 @@ GraphDef BuildDoubleQueueGraph() { auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, QueueClose::CancelPendingEnqueues(true)); - auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32}); + auto q1 = FIFOQueue(root.WithOpName(kQueueName1), {DataType::DT_INT32}, + FIFOQueue::Capacity(3)); auto dequeue0 = QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_INT32}); auto enqueue1 = QueueEnqueue(root.WithOpName(kEnqueueOp1), q1, {dequeue0[0]}); @@ -252,34 +254,34 @@ TEST(QueueRunnerTest, SessionCloseCancelPendingEnqueue) { EXPECT_EQ(join_succeeded, true); } -TEST(QueueRunnerTest, Stop) { - auto graph_def = BuildDoubleQueueGraph(); +TEST(QueueRunnerTest, EmptyEnqueueOps) { + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); + std::unique_ptr<QueueRunner> qr; + EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(), + Code::INVALID_ARGUMENT); +} + +TEST(QueueRunnerTest, StartTimeout) { + GraphDef graph_def = BuildDoubleQueueGraph(); SessionOptions options; std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); - QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, - {Code::OUT_OF_RANGE, Code::CANCELLED}); + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); + std::unique_ptr<QueueRunner> qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); - TF_CHECK_OK(qr->Start(session.get())); - - TF_EXPECT_OK(qr->Stop(session.get())); - - TF_EXPECT_OK(session->Run({}, {}, {kEnqueueOp0}, nullptr)); - - EXPECT_EQ(session->Run({}, {kDequeueOp1}, {}, nullptr).code(), - Code::OUT_OF_RANGE); - - // qr is already stopped - TF_EXPECT_OK(qr->Join()); + // This will timeout since queue0 is not fed and queue1 is fetching data from + // queue0. + EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED); + session->Close(); } -TEST(QueueRunnerTest, StopTwoQueues) { +TEST(QueueRunnerTest, TestCoordinatorStop) { auto graph_def = BuildDoubleQueueGraph(); - SessionOptions options; std::unique_ptr<Session> session(NewSession(options)); TF_CHECK_OK(session->Create(graph_def)); @@ -290,48 +292,24 @@ TEST(QueueRunnerTest, StopTwoQueues) { QueueRunnerDef queue_runner1 = BuildQueueRunnerDef(kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {Code::OUT_OF_RANGE, Code::CANCELLED}); + + Coordinator coord; std::unique_ptr<QueueRunner> qr0; - TF_EXPECT_OK(QueueRunner::New(queue_runner0, &qr0)); + TF_EXPECT_OK(QueueRunner::New(queue_runner0, &coord, &qr0)); TF_CHECK_OK(qr0->Start(session.get())); std::unique_ptr<QueueRunner> qr1; - TF_EXPECT_OK(QueueRunner::New(queue_runner1, &qr1)); + TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1)); TF_CHECK_OK(qr1->Start(session.get())); + coord.RegisterRunner(std::move(qr0)); + coord.RegisterRunner(std::move(qr1)); + std::vector<Tensor> dq; TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq)); EXPECT_EQ(*dq[0].scalar<int>().data(), 10); - TF_EXPECT_OK(qr0->Stop(session.get())); - TF_EXPECT_OK(qr1->Stop(session.get())); - - TF_EXPECT_OK(qr0->Join()); - TF_EXPECT_OK(qr1->Join()); -} - -TEST(QueueRunnerTest, EmptyEnqueueOps) { - QueueRunnerDef queue_runner_def = - BuildQueueRunnerDef(kQueueName, {}, kCountUpToOpName, "", {}); - - std::unique_ptr<QueueRunner> qr; - EXPECT_EQ(QueueRunner::New(queue_runner_def, &qr).code(), - Code::INVALID_ARGUMENT); -} - -TEST(QueueRunnerTest, StartTimeout) { - GraphDef graph_def = BuildDoubleQueueGraph(); - SessionOptions options; - std::unique_ptr<Session> session(NewSession(options)); - TF_CHECK_OK(session->Create(graph_def)); - - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName1, {kEnqueueOp1}, kCloseOp1, kCancelOp1, {}); - - std::unique_ptr<QueueRunner> qr; - TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); - // This will timeout since queue0 is not fed and queue1 is fetching data from - // queue0. - EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED); - session->Close(); + TF_EXPECT_OK(coord.RequestStop()); + TF_EXPECT_OK(coord.Join()); } } // namespace |