aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-20 01:43:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 01:46:55 -0700
commita54310b1faa39df94dcef9ad1b5aaa0acc691e35 (patch)
treeb8f1a60490cd697e008b89569f775dd5aede5799 /tensorflow/core/common_runtime
parentda3357ecbdd6772413e8bbceeab8238971be11ce (diff)
Internal change.
PiperOrigin-RevId: 213770000
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/device.h4
-rw-r--r--tensorflow/core/common_runtime/executor.cc6
2 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 81d68e3be4..fb76d6ac29 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -106,6 +106,10 @@ class Device : public DeviceBase {
// at completion.
virtual Status Sync() = 0;
+ // Override this to return true for devices that require a Sync() call before
+ // session completion.
+ virtual bool RequiresSyncOnCompletion() const { return false; }
+
// Optionally modify the device's GraphDef before execution.
//
// This method should be considered experimental and is supplied to enable
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index d0a0767d6b..98719542c0 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -2301,13 +2301,15 @@ void ExecutorState::Finish() {
auto done_cb = std::move(done_cb_);
auto runner = std::move(runner_);
mu_.unlock();
- if (sync_on_finish_ && status.ok()) {
+ Device* device = impl_->params_.device;
+ if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
// Block until the device has finished all queued operations. For
// devices like GPUs that continue to execute Ops after their Compute
// methods have completed, this ensures that control is not returned to
// the user until the step (and its side-effects) has actually completed.
- status = impl_->params_.device->Sync();
+ status.Update(device->Sync());
}
+
delete this;
CHECK(done_cb != nullptr);
runner([=]() { done_cb(status); });