diff options
Diffstat (limited to 'tensorflow/core/kernels/identity_op.cc')
-rw-r--r-- | tensorflow/core/kernels/identity_op.cc | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc index b482099c69..0e56a27c84 100644 --- a/tensorflow/core/kernels/identity_op.cc +++ b/tensorflow/core/kernels/identity_op.cc @@ -54,10 +54,29 @@ REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp); IdentityOp) TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); -REGISTER_SYCL_KERNEL(bfloat16); #undef REGISTER_SYCL_KERNEL -#endif + +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Identity") \ + .Device(DEVICE_SYCL) \ + .HostMemory("input") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + IdentityOp); \ + REGISTER_KERNEL_BUILDER(Name("RefIdentity") \ + .Device(DEVICE_SYCL) \ + .HostMemory("input") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + IdentityOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(bool); + +#undef REGISTER_SYCL_HOST_KERNEL + +#endif // TENSORFLOW_USE_SYCL #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ |