diff options
Diffstat (limited to 'tensorflow/core/kernels/function_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/function_ops.cc | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 4a08f98b33..7cb9a3a657 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -87,6 +87,29 @@ class RetvalOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); +#if TENSORFLOW_USE_SYCL +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) + TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") + .Device(DEVICE_GPU) + .HostMemory("output") + .TypeConstraint<int32>("T"), + ArgOp); +#undef REGISTER +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) + TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") + .Device(DEVICE_GPU) + .HostMemory("input") + .TypeConstraint<int32>("T"), + RetvalOp); +#undef REGISTER +#endif + #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \ Name("_Arg").Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp); |