diff options
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device.h')
-rw-r--r-- | tensorflow/core/common_runtime/sycl/sycl_device.h | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.h b/tensorflow/core/common_runtime/sycl/sycl_device.h index 2759053df5..db208984f6 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.h +++ b/tensorflow/core/common_runtime/sycl/sycl_device.h @@ -20,8 +20,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_ -#define EIGEN_USE_SYCL - #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/sycl/sycl_allocator.h" #include "tensorflow/core/common_runtime/sycl/sycl_device_context.h" @@ -45,10 +43,13 @@ public: sycl_allocator_(new SYCLAllocator(sycl_queue_)), device_context_(new SYCLDeviceContext()) { set_eigen_sycl_device(sycl_device_); + RegisterDevice(); } ~SYCLDevice() override; + void EnterLameDuckMode(); + void Compute(OpKernel *op_kernel, OpKernelContext *context) override; Allocator *GetAllocator(AllocatorAttributes attr) override; Status MakeTensorFromProto(const TensorProto &tensor_proto, @@ -65,6 +66,8 @@ public: } private: + void RegisterDevice(); + Allocator *cpu_allocator_; // owned Eigen::QueueInterface* sycl_queue_; // owned Eigen::SyclDevice* sycl_device_; // owned |