diff options
Diffstat (limited to 'tensorflow/core/kernels/dense_update_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/dense_update_ops.cc | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index 767f143727..33991fa1f9 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -126,6 +126,9 @@ class DenseUpdateOp : public OpKernel { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL #define REGISTER_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ @@ -136,26 +139,6 @@ TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS -#if TENSORFLOW_USE_SYCL -typedef Eigen::SyclDevice SYCLDevice; -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Assign") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint<type>("T"), \ - AssignOpT<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \ - REGISTER_KERNEL_BUILDER( \ - Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>); - -REGISTER_SYCL_KERNEL(float); -REGISTER_SYCL_KERNEL(double); -#undef REGISTER_SYCL_KERNEL -#endif - #if GOOGLE_CUDA // Only register 'Assign' on GPU for the subset of types also supported by // 'Variable' (see variable_ops.cc.) @@ -175,6 +158,16 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ +REGISTER_KERNEL_BUILDER( \ + Name("Assign").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + AssignOpT<SYCLDevice, type>); + +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); +#undef REGISTER_SYCL_KERNELS +#endif // TENSORFLOW_USE_SYCL + #define REGISTER_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ @@ -214,4 +207,16 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // end GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>); + +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); +#undef REGISTER_SYCL_KERNELS +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow |