diff options
Diffstat (limited to 'tensorflow/core/kernels/dense_update_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/dense_update_ops.cc | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index baa8f83091..5216a4b5d0 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -97,13 +97,20 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); #if TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Assign") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint<type>("T"), \ - AssignOpT<SYCLDevice, type>); -TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Assign") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T"), \ + AssignOpT<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>); + +REGISTER_SYCL_KERNEL(float); #undef REGISTER_SYCL_KERNEL #endif |