diff options
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.cc')
-rw-r--r-- | tensorflow/core/common_runtime/sycl/sycl_device.cc | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc index e5fe85bcf5..2936b4c5c8 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc @@ -23,11 +23,38 @@ limitations under the License. namespace tensorflow { +static std::unordered_set<SYCLDevice*> live_devices; + +void ShutdownSycl() { + for (auto device : live_devices) { + device->EnterLameDuckMode(); + } + live_devices.clear(); +} +bool first_time = true; + +void SYCLDevice::RegisterDevice() { + if (first_time) { + first_time = false; + atexit(ShutdownSycl); + } + live_devices.insert(this); +} + SYCLDevice::~SYCLDevice() { device_context_->Unref(); sycl_allocator_->EnterLameDuckMode(); delete sycl_device_; delete sycl_queue_; + live_devices.erase(this); +} + +void SYCLDevice::EnterLameDuckMode() { + sycl_allocator_->EnterLameDuckMode(); + delete sycl_device_; + sycl_device_ = nullptr; + delete sycl_queue_; + sycl_queue_ = nullptr; } void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) { @@ -63,8 +90,8 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); device_context_->CopyCPUTensorToDevice(&parsed, this, ©, [&status](const Status &s) { - status = s; - }); + status = s; + }); *tensor = copy; } return status; |