aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/direct_session.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2016-11-18 09:57:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-18 10:02:44 -0800
commit1970f1838b62950e682c83457b5eff405c96b2a9 (patch)
treee2bd6e09b3d8f5dd66f7cc18360c95793c11ddcd /tensorflow/core/common_runtime/direct_session.cc
parentb11798488dec834a0d2d4eede4d24c39a55ef898 (diff)
Updated DirectSession::RecvOutputs to take into account the session timeouts.
Change: 139591752
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc30
1 files changed, 20 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0c48aeec63..2ab0f9aa53 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -782,7 +782,8 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
s = Rendezvous::ParseKey(output_key, &parsed);
if (s.ok()) {
// Fetch data from the Rendezvous.
- s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead);
+ s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
+ operation_timeout_in_ms_);
if (is_dead && s.ok()) {
s = errors::InvalidArgument("The tensor returned for ", output_name,
" was not valid.");
@@ -1193,20 +1194,29 @@ DirectSession::RunState::~RunState() {
void DirectSession::WaitForNotification(RunState* run_state,
CancellationManager* cm,
int64 timeout_in_ms) {
+ Status status =
+ WaitForNotification(&run_state->executors_done, timeout_in_ms);
+ if (!status.ok()) {
+ {
+ mutex_lock l(run_state->mu_);
+ run_state->status.Update(status);
+ }
+ cm->StartCancel();
+ }
+}
+
+::tensorflow::Status DirectSession::WaitForNotification(
+ Notification* notification, int64 timeout_in_ms) {
if (timeout_in_ms > 0) {
- bool notified = WaitForNotificationWithTimeout(&run_state->executors_done,
- timeout_in_ms);
+ bool notified = WaitForNotificationWithTimeout(notification, timeout_in_ms);
if (!notified) {
- {
- mutex_lock l(run_state->mu_);
- run_state->status.Update(Status(error::DEADLINE_EXCEEDED,
- "Timed out waiting for notification"));
- }
- cm->StartCancel();
+ return Status(error::DEADLINE_EXCEEDED,
+ "Timed out waiting for notification");
}
} else {
- run_state->executors_done.WaitForNotification();
+ notification->WaitForNotification();
}
+ return Status::OK();
}
} // namespace tensorflow