aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/sycl/sycl_device.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.cc')
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device.cc32
1 files changed, 23 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc
index 19d39056ff..0abe25c373 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc
@@ -23,7 +23,8 @@ limitations under the License.
namespace tensorflow {
-static std::unordered_set<SYCLDevice *> live_devices;
+static std::unordered_set<SYCLDevice*> live_devices;
+static bool first_time = true;
void ShutdownSycl() {
for (auto device : live_devices) {
@@ -31,7 +32,6 @@ void ShutdownSycl() {
}
live_devices.clear();
}
-bool first_time = true;
void SYCLDevice::RegisterDevice() {
if (first_time) {
@@ -44,17 +44,27 @@ void SYCLDevice::RegisterDevice() {
SYCLDevice::~SYCLDevice() {
device_context_->Unref();
sycl_allocator_->EnterLameDuckMode();
- delete sycl_device_;
- delete sycl_queue_;
+ if (sycl_device_) {
+ sycl_device_->synchronize();
+ delete sycl_device_;
+ }
+ if (sycl_queue_) {
+ delete sycl_queue_;
+ }
live_devices.erase(this);
}
void SYCLDevice::EnterLameDuckMode() {
sycl_allocator_->EnterLameDuckMode();
- delete sycl_device_;
- sycl_device_ = nullptr;
- delete sycl_queue_;
- sycl_queue_ = nullptr;
+ if (sycl_device_) {
+ sycl_device_->synchronize();
+ delete sycl_device_;
+ sycl_device_ = nullptr;
+ }
+ if (sycl_queue_) {
+ delete sycl_queue_;
+ sycl_queue_ = nullptr;
+ }
}
void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
@@ -110,7 +120,11 @@ Status SYCLDevice::FillContextMap(const Graph *graph,
Status SYCLDevice::Sync() {
sycl_device_->synchronize();
- return Status::OK();
+ if (sycl_device_->ok()) {
+ return Status::OK();
+ } else {
+ return errors::Internal("Unknown error detected on device ", name());
+ }
}
} // namespace tensorflow