diff options
Diffstat (limited to 'tensorflow/core/kernels/constant_op.cc')
-rw-r--r-- | tensorflow/core/kernels/constant_op.cc | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 115a842d1c..d444ddec1d 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -57,7 +57,10 @@ REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp); REGISTER_KERNEL_BUILDER( \ Name("Const").Device(DEVICE_SYCL).TypeConstraint<TYPE>("dtype"), \ ConstantOp); -TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +REGISTER_SYCL_KERNEL(bool); +REGISTER_SYCL_KERNEL(int64); #undef REGISTER_SYCL_KERNEL #endif @@ -112,6 +115,17 @@ REGISTER_KERNEL_BUILDER(Name("Const") HostConstantOp); #endif +#ifdef TENSORFLOW_USE_SYCL +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Const") + .Device(DEVICE_SYCL) + .HostMemory("output") + .TypeConstraint<int32>("dtype"), + HostConstantOp); +#endif // TENSORFLOW_USE_SYCL + typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL @@ -186,6 +200,7 @@ REGISTER_KERNEL(CPU, quint8); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL(SYCL, float) +REGISTER_KERNEL(SYCL, double) REGISTER_KERNEL_BUILDER(Name("Fill") .Device(DEVICE_SYCL) .TypeConstraint<int32>("T") @@ -246,6 +261,7 @@ TF_CALL_POD_STRING_TYPES(REGISTER_CPU); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL(float, SYCL); +REGISTER_KERNEL(bool, SYCL); REGISTER_KERNEL_BUILDER(Name("ZerosLike") .Device(DEVICE_SYCL) .TypeConstraint<int32>("T") |