diff options
Diffstat (limited to 'tensorflow/core/kernels/identity_op.cc')
-rw-r--r-- | tensorflow/core/kernels/identity_op.cc | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc index 459d329ba4..45d27dd19e 100644 --- a/tensorflow/core/kernels/identity_op.cc +++ b/tensorflow/core/kernels/identity_op.cc @@ -34,6 +34,24 @@ REGISTER_KERNEL_BUILDER(Name("PlaceholderWithDefault").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp); +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Identity").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefIdentity").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("StopGradient").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\ + IdentityOp) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); +REGISTER_SYCL_KERNEL(bfloat16); + +#undef REGISTER_SYCL_KERNEL +#endif + #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ @@ -50,6 +68,7 @@ REGISTER_GPU_KERNEL(bfloat16); #undef REGISTER_GPU_KERNEL + #if GOOGLE_CUDA // A special GPU kernel for int32 and bool. // TODO(b/25387198): Also enable int32 in device memory. This kernel |