aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_op_sub.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_sub.cc')
-rw-r--r--tensorflow/core/kernels/cwise_op_sub.cc31
1 files changed, 10 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
index eab1e2a09c..eb173c7040 100644
--- a/tensorflow/core/kernels/cwise_op_sub.cc
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -24,28 +24,7 @@ REGISTER7(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32,
// int32 version of this op is needed, so explicitly include it.
REGISTER(BinaryOp, CPU, "Sub", functor::sub, int32);
#endif // __ANDROID_TYPES_SLIM__
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("Sub") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<TYPE>("T"), \
- BinaryOp<SYCLDevice, functor::sub<TYPE>>);
- REGISTER_SYCL_KERNEL(float);
- REGISTER_SYCL_KERNEL(double);
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Sub")
- .Device(DEVICE_SYCL)
- .HostMemory("x")
- .HostMemory("y")
- .HostMemory("z")
- .TypeConstraint<int32>("T"),
- BinaryOp<CPUDevice, functor::sub<int32>>);
-#undef REGISTER_SYCL_KERNEL
-#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA
REGISTER6(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
complex64, complex128);
@@ -62,4 +41,14 @@ REGISTER_KERNEL_BUILDER(Name("Sub")
BinaryOp<CPUDevice, functor::sub<int32>>);
#endif
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER3(BinaryOp, SYCL, "Sub", functor::sub, float, double, int64);
+REGISTER_KERNEL_BUILDER(Name("Sub")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::sub<int32>>);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow