aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/sycl/sycl_device.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.cc')
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device.cc31
1 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc
index e5fe85bcf5..2936b4c5c8 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc
@@ -23,11 +23,38 @@ limitations under the License.
namespace tensorflow {
+static std::unordered_set<SYCLDevice*> live_devices;
+
+void ShutdownSycl() {
+ for (auto device : live_devices) {
+ device->EnterLameDuckMode();
+ }
+ live_devices.clear();
+}
+bool first_time = true;
+
+void SYCLDevice::RegisterDevice() {
+ if (first_time) {
+ first_time = false;
+ atexit(ShutdownSycl);
+ }
+ live_devices.insert(this);
+}
+
SYCLDevice::~SYCLDevice() {
device_context_->Unref();
sycl_allocator_->EnterLameDuckMode();
delete sycl_device_;
delete sycl_queue_;
+ live_devices.erase(this);
+}
+
+void SYCLDevice::EnterLameDuckMode() {
+ sycl_allocator_->EnterLameDuckMode();
+ delete sycl_device_;
+ sycl_device_ = nullptr;
+ delete sycl_queue_;
+ sycl_queue_ = nullptr;
}
void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
@@ -63,8 +90,8 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
device_context_->CopyCPUTensorToDevice(&parsed, this, &copy,
[&status](const Status &s) {
- status = s;
- });
+ status = s;
+ });
*tensor = copy;
}
return status;