diff options
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.cc')
-rw-r--r-- | tensorflow/core/common_runtime/sycl/sycl_device.cc | 58 |
1 files changed, 19 insertions, 39 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc index 2c2185b2c0..17f5edd572 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc @@ -22,50 +22,18 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" namespace tensorflow { - -static std::unordered_set<SYCLDevice *> live_devices; -static bool first_time = true; +std::mutex GSYCLInterface::mutex_; +GSYCLInterface *GSYCLInterface::s_instance = 0; void ShutdownSycl() { - for (auto device : live_devices) { - device->EnterLameDuckMode(); - } - live_devices.clear(); + GSYCLInterface::Reset(); } void SYCLDevice::RegisterDevice() { - if (first_time) { - first_time = false; atexit(ShutdownSycl); - } - live_devices.insert(this); } -SYCLDevice::~SYCLDevice() { - device_context_->Unref(); - sycl_allocator_->EnterLameDuckMode(); - if (sycl_device_) { - sycl_device_->synchronize(); - delete sycl_device_; - } - if (sycl_queue_) { - delete sycl_queue_; - } - live_devices.erase(this); -} - -void SYCLDevice::EnterLameDuckMode() { - sycl_allocator_->EnterLameDuckMode(); - if (sycl_device_) { - sycl_device_->synchronize(); - delete sycl_device_; - sycl_device_ = nullptr; - } - if (sycl_queue_) { - delete sycl_queue_; - sycl_queue_ = nullptr; - } -} +SYCLDevice::~SYCLDevice() {} void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) { assert(context); @@ -88,8 +56,12 @@ Allocator *SYCLDevice::GetAllocator(AllocatorAttributes attr) { Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto, const AllocatorAttributes alloc_attrs, Tensor *tensor) { + AllocatorAttributes attr; + attr.set_on_host(true); + Allocator* host_alloc = GetAllocator(attr); + Tensor parsed(tensor_proto.dtype()); - if (!parsed.FromProto(cpu_allocator_, tensor_proto)) { + if (!parsed.FromProto(host_alloc, tensor_proto)) { return errors::InvalidArgument("Cannot parse tensor from 
proto: ", tensor_proto.DebugString()); } @@ -98,6 +70,14 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto, *tensor = parsed; } else { Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + + // If the tensor is not initialized, we likely ran out of memory. + if (!copy.IsInitialized()) { + return errors::ResourceExhausted( + "OOM when allocating tensor of shape ", parsed.shape().DebugString(), + " and type ", DataTypeString(parsed.dtype())); + } + device_context_->CopyCPUTensorToDevice( &parsed, this, &copy, [&status](const Status &s) { status = s; }); *tensor = copy; @@ -119,8 +99,8 @@ Status SYCLDevice::FillContextMap(const Graph *graph, } Status SYCLDevice::Sync() { - sycl_device_->synchronize(); - if (sycl_device_->ok()) { + sycl_allocator_->Synchronize(); + if (sycl_allocator_->Ok()) { return Status::OK(); } else { return errors::Internal("Unknown error detected on device ", name()); |