diff options
Diffstat (limited to 'tensorflow/core/kernels/reshape_op.cc')
-rw-r--r-- | tensorflow/core/kernels/reshape_op.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc index cd6875eeb2..245b324a38 100644 --- a/tensorflow/core/kernels/reshape_op.cc +++ b/tensorflow/core/kernels/reshape_op.cc @@ -31,6 +31,27 @@ REGISTER_KERNEL_BUILDER(Name("Reshape").Device(DEVICE_CPU).HostMemory("shape"), TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Reshape") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("Tshape"), \ + ReshapeOp); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); +#undef REGISTER_SYCL_KERNEL + +REGISTER_KERNEL_BUILDER(Name("Reshape") + .Device(DEVICE_SYCL) + .HostMemory("tensor") + .HostMemory("shape") + .HostMemory("output") + .TypeConstraint<int32>("T") + .TypeConstraint<int32>("Tshape"), + ReshapeOp); +#endif // TENSORFLOW_USE_SYCL + #if GOOGLE_CUDA // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel |