aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/sycl/sycl_device.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.h')
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device.h7
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.h b/tensorflow/core/common_runtime/sycl/sycl_device.h
index 2759053df5..db208984f6 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device.h
+++ b/tensorflow/core/common_runtime/sycl/sycl_device.h
@@ -20,8 +20,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
-#define EIGEN_USE_SYCL
-
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/sycl/sycl_allocator.h"
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
@@ -45,10 +43,13 @@ public:
sycl_allocator_(new SYCLAllocator(sycl_queue_)),
device_context_(new SYCLDeviceContext()) {
set_eigen_sycl_device(sycl_device_);
+ RegisterDevice();
}
~SYCLDevice() override;
+ void EnterLameDuckMode();
+
void Compute(OpKernel *op_kernel, OpKernelContext *context) override;
Allocator *GetAllocator(AllocatorAttributes attr) override;
Status MakeTensorFromProto(const TensorProto &tensor_proto,
@@ -65,6 +66,8 @@ public:
}
private:
+ void RegisterDevice();
+
Allocator *cpu_allocator_; // owned
Eigen::QueueInterface* sycl_queue_; // owned
Eigen::SyclDevice* sycl_device_; // owned