path: root/tensorflow/core/common_runtime/sycl/sycl_device.cc
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.cc')
-rw-r--r--  tensorflow/core/common_runtime/sycl/sycl_device.cc  58
1 file changed, 19 insertions(+), 39 deletions(-)
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc
index 2c2185b2c0..17f5edd572 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc
@@ -22,50 +22,18 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
-
-static std::unordered_set<SYCLDevice *> live_devices;
-static bool first_time = true;
+std::mutex GSYCLInterface::mutex_;
+GSYCLInterface *GSYCLInterface::s_instance = 0;
void ShutdownSycl() {
- for (auto device : live_devices) {
- device->EnterLameDuckMode();
- }
- live_devices.clear();
+ GSYCLInterface::Reset();
}
void SYCLDevice::RegisterDevice() {
- if (first_time) {
- first_time = false;
atexit(ShutdownSycl);
- }
- live_devices.insert(this);
}
-SYCLDevice::~SYCLDevice() {
- device_context_->Unref();
- sycl_allocator_->EnterLameDuckMode();
- if (sycl_device_) {
- sycl_device_->synchronize();
- delete sycl_device_;
- }
- if (sycl_queue_) {
- delete sycl_queue_;
- }
- live_devices.erase(this);
-}
-
-void SYCLDevice::EnterLameDuckMode() {
- sycl_allocator_->EnterLameDuckMode();
- if (sycl_device_) {
- sycl_device_->synchronize();
- delete sycl_device_;
- sycl_device_ = nullptr;
- }
- if (sycl_queue_) {
- delete sycl_queue_;
- sycl_queue_ = nullptr;
- }
-}
+SYCLDevice::~SYCLDevice() {}
void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
assert(context);
@@ -88,8 +56,12 @@ Allocator *SYCLDevice::GetAllocator(AllocatorAttributes attr) {
Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor *tensor) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ Allocator* host_alloc = GetAllocator(attr);
+
Tensor parsed(tensor_proto.dtype());
- if (!parsed.FromProto(cpu_allocator_, tensor_proto)) {
+ if (!parsed.FromProto(host_alloc, tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
}
@@ -98,6 +70,14 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
*tensor = parsed;
} else {
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
+
+ // If the tensor is not initialized, we likely ran out of memory.
+ if (!copy.IsInitialized()) {
+ return errors::ResourceExhausted(
+ "OOM when allocating tensor of shape ", parsed.shape().DebugString(),
+ " and type ", DataTypeString(parsed.dtype()));
+ }
+
device_context_->CopyCPUTensorToDevice(
&parsed, this, &copy, [&status](const Status &s) { status = s; });
*tensor = copy;
@@ -119,8 +99,8 @@ Status SYCLDevice::FillContextMap(const Graph *graph,
}
Status SYCLDevice::Sync() {
- sycl_device_->synchronize();
- if (sycl_device_->ok()) {
+ sycl_allocator_->Synchronize();
+ if (sycl_allocator_->Ok()) {
return Status::OK();
} else {
return errors::Internal("Unknown error detected on device ", name());
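Taken together, the MakeTensorFromProto hunks parse the proto into host memory obtained from the device's own host allocator, bail out with ResourceExhausted when the device-side copy could not be allocated, and only then issue the host-to-device copy. A condensed restatement of that flow, stitched from the hunks above, is sketched below; the on_host() branch condition, the Status declaration, and the trailing return are not visible in the diff and are assumptions.

Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
                                       const AllocatorAttributes alloc_attrs,
                                       Tensor *tensor) {
  // Parse into host memory via the device's host-visible allocator.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  Allocator *host_alloc = GetAllocator(attr);

  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(host_alloc, tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {  // assumed branch condition (not in the diff)
    *tensor = parsed;
  } else {
    Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());

    // If the tensor is not initialized, we likely ran out of memory.
    if (!copy.IsInitialized()) {
      return errors::ResourceExhausted(
          "OOM when allocating tensor of shape ", parsed.shape().DebugString(),
          " and type ", DataTypeString(parsed.dtype()));
    }

    device_context_->CopyCPUTensorToDevice(
        &parsed, this, &copy, [&status](const Status &s) { status = s; });
    *tensor = copy;
  }
  return status;  // assumed; the closing lines fall outside the hunk
}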