diff options
author | Benoit Steiner <bsteiner@google.com> | 2016-11-29 07:47:38 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-29 08:03:56 -0800 |
commit | 99700a09c632ad14d99a54d8f1db64928e32d8c6 (patch) | |
tree | dd0062477b0acbb62e33f9425e1c002689ab31df /tensorflow/cc/training | |
parent | bf606b66da986e34672a82f26a3e9f962421ff45 (diff) |
Added the ability to register a callback with a queue runner to be notified in
case of an error
Change: 140478026
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- | tensorflow/cc/training/queue_runner.cc | 14 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner.h | 9 | ||||
-rw-r--r-- | tensorflow/cc/training/queue_runner_test.cc | 16 |
3 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 234e902957..5d6710ea5c 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -33,6 +33,16 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, return (*result)->Init(queue_runner_def); } +void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) { + mutex_lock l(cb_mu_); + callbacks_.push_back(cb); +} + +void QueueRunner::ClearErrorCallbacks() { + mutex_lock l(cb_mu_); + callbacks_.clear(); +} + Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { queue_name_ = queue_runner_def.queue_name(); enqueue_op_names_.clear(); @@ -126,6 +136,10 @@ void QueueRunner::UpdateStatus(const Status& status) { if (coord_) { coord_->ReportStatus(status); } + mutex_lock l(cb_mu_); + for (auto& cb : callbacks_) { + cb(status); + } } void QueueRunner::Run(Session* sess, const string& enqueue_op) { diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 50213fa81b..fd9f97a958 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -46,6 +46,12 @@ class QueueRunner : public RunnerInterface { static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord, std::unique_ptr<QueueRunner>* result); + // Adds a callback that the queue runner will call when it detects an error. + void AddErrorCallback(const std::function<void(Status)>& cb); + + // Delete the previously registered callbacks. + void ClearErrorCallbacks(); + // The destructor would join all the threads. ~QueueRunner(); @@ -101,6 +107,9 @@ class QueueRunner : public RunnerInterface { std::unique_ptr<BlockingCounter> counter_; Coordinator* coord_; + + mutex cb_mu_; + std::vector<std::function<void(Status)>> callbacks_; }; } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 0e7e94b40f..1661c5c91b 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -328,5 +328,21 @@ TEST(QueueRunnerTest, TestCoordinatorStop) { TF_EXPECT_OK(coord.Join()); } +TEST(QueueRunnerTest, CallbackCalledOnError) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {}); + + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + bool error_caught = false; + qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; }); + TF_EXPECT_OK(qr->Start(session.get())); + qr->Join(); + EXPECT_TRUE(error_caught); +} + } // namespace } // namespace tensorflow |