diff options
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.cc')
-rw-r--r-- | tensorflow/core/common_runtime/sycl/sycl_device.cc | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc index 19d39056ff..0abe25c373 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc @@ -23,7 +23,8 @@ limitations under the License. namespace tensorflow { -static std::unordered_set<SYCLDevice *> live_devices; +static std::unordered_set<SYCLDevice*> live_devices; +static bool first_time = true; void ShutdownSycl() { for (auto device : live_devices) { @@ -31,7 +32,6 @@ void ShutdownSycl() { } live_devices.clear(); } -bool first_time = true; void SYCLDevice::RegisterDevice() { if (first_time) { @@ -44,17 +44,27 @@ void SYCLDevice::RegisterDevice() { SYCLDevice::~SYCLDevice() { device_context_->Unref(); sycl_allocator_->EnterLameDuckMode(); - delete sycl_device_; - delete sycl_queue_; + if (sycl_device_) { + sycl_device_->synchronize(); + delete sycl_device_; + } + if (sycl_queue_) { + delete sycl_queue_; + } live_devices.erase(this); } void SYCLDevice::EnterLameDuckMode() { sycl_allocator_->EnterLameDuckMode(); - delete sycl_device_; - sycl_device_ = nullptr; - delete sycl_queue_; - sycl_queue_ = nullptr; + if (sycl_device_) { + sycl_device_->synchronize(); + delete sycl_device_; + sycl_device_ = nullptr; + } + if (sycl_queue_) { + delete sycl_queue_; + sycl_queue_ = nullptr; + } } void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) { @@ -110,7 +120,11 @@ Status SYCLDevice::FillContextMap(const Graph *graph, Status SYCLDevice::Sync() { sycl_device_->synchronize(); - return Status::OK(); + if (sycl_device_->ok()) { + return Status::OK(); + } else { + return errors::Internal("Unknown error detected on device ", name()); + } } } // namespace tensorflow |