aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/function_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/function_ops.cc')
-rw-r--r--tensorflow/core/kernels/function_ops.cc23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 4a08f98b33..7cb9a3a657 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -87,6 +87,29 @@ class RetvalOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
+#if TENSORFLOW_USE_SYCL
+#define REGISTER(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
+ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
+ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
+ .Device(DEVICE_GPU)
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ ArgOp);
+#undef REGISTER
+#define REGISTER(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
+ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
+ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
+ .Device(DEVICE_GPU)
+ .HostMemory("input")
+ .TypeConstraint<int32>("T"),
+ RetvalOp);
+#undef REGISTER
+#endif
+
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("_Arg").Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);