aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dense_update_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/dense_update_ops.cc')
-rw-r--r--tensorflow/core/kernels/dense_update_ops.cc21
1 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc
index baa8f83091..5216a4b5d0 100644
--- a/tensorflow/core/kernels/dense_update_ops.cc
+++ b/tensorflow/core/kernels/dense_update_ops.cc
@@ -97,13 +97,20 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#if TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Assign") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T"), \
- AssignOpT<SYCLDevice, type>);
-TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Assign") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T"), \
+ AssignOpT<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>);
+
+REGISTER_SYCL_KERNEL(float);
#undef REGISTER_SYCL_KERNEL
#endif