aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/training
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2016-11-29 07:47:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-29 08:03:56 -0800
commit99700a09c632ad14d99a54d8f1db64928e32d8c6 (patch)
treedd0062477b0acbb62e33f9425e1c002689ab31df /tensorflow/cc/training
parentbf606b66da986e34672a82f26a3e9f962421ff45 (diff)
Added the ability to register a callback with a queue runner to be notified in
case of an error. Change: 140478026
Diffstat (limited to 'tensorflow/cc/training')
-rw-r--r-- tensorflow/cc/training/queue_runner.cc | 14
-rw-r--r-- tensorflow/cc/training/queue_runner.h | 9
-rw-r--r-- tensorflow/cc/training/queue_runner_test.cc | 16
3 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 234e902957..5d6710ea5c 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -33,6 +33,16 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
return (*result)->Init(queue_runner_def);
}
+void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) {
+ // callbacks_ is guarded by cb_mu_; UpdateStatus() iterates the list under
+ // the same mutex, so the append must happen while it is held.
+ mutex_lock lock(cb_mu_);
+ // Stores a copy of `cb` at the end of the list.
+ callbacks_.emplace_back(cb);
+}
+
+void QueueRunner::ClearErrorCallbacks() {
+ // Take cb_mu_, the mutex AddErrorCallback() and UpdateStatus() also use for
+ // callbacks_, before dropping every registered callback.
+ mutex_lock lock(cb_mu_);
+ callbacks_.clear();
+}
+
Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
queue_name_ = queue_runner_def.queue_name();
enqueue_op_names_.clear();
@@ -126,6 +136,10 @@ void QueueRunner::UpdateStatus(const Status& status) {
if (coord_) {
coord_->ReportStatus(status);
}
+ mutex_lock l(cb_mu_);
+ for (auto& cb : callbacks_) {
+ cb(status);
+ }
}
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 50213fa81b..fd9f97a958 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -46,6 +46,12 @@ class QueueRunner : public RunnerInterface {
static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
std::unique_ptr<QueueRunner>* result);
+ // Adds a callback that the queue runner will call when it detects an error.
+ void AddErrorCallback(const std::function<void(Status)>& cb);
+
+ // Deletes all previously registered error callbacks.
+ void ClearErrorCallbacks();
+
// The destructor would join all the threads.
~QueueRunner();
@@ -101,6 +107,9 @@ class QueueRunner : public RunnerInterface {
std::unique_ptr<BlockingCounter> counter_;
Coordinator* coord_;
+
+ mutex cb_mu_;
+ std::vector<std::function<void(Status)>> callbacks_;
};
} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 0e7e94b40f..1661c5c91b 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -328,5 +328,21 @@ TEST(QueueRunnerTest, TestCoordinatorStop) {
TF_EXPECT_OK(coord.Join());
}
+TEST(QueueRunnerTest, CallbackCalledOnError) {
+ // Expects a registered error callback to fire when the runner is started
+ // with the kIllegalOpName1 / kIllegalOpName2 enqueue ops.
+ GraphDef graph = BuildSimpleGraph();
+ auto session = BuildSessionAndInitVariable(graph);
+
+ const QueueRunnerDef runner_def = BuildQueueRunnerDef(
+ kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
+
+ std::unique_ptr<QueueRunner> runner;
+ TF_EXPECT_OK(QueueRunner::New(runner_def, &runner));
+
+ // The callback only records that it ran; the Status it receives is ignored.
+ bool callback_fired = false;
+ const auto mark_fired = [&callback_fired](const Status&) {
+ callback_fired = true;
+ };
+ runner->AddErrorCallback(mark_fired);
+
+ // The flag is checked only after Join() returns.
+ TF_EXPECT_OK(runner->Start(session.get()));
+ runner->Join();
+ EXPECT_TRUE(callback_fired);
+}
+
} // namespace
} // namespace tensorflow