aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/function_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/function_ops.cc')
-rw-r--r--tensorflow/core/kernels/function_ops.cc8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 504d9eeab4..b1e6c90ff2 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -93,7 +93,7 @@ REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
- .Device(DEVICE_GPU)
+ .Device(DEVICE_SYCL)
.HostMemory("output")
.TypeConstraint<int32>("T"),
ArgOp);
@@ -104,7 +104,7 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
- .Device(DEVICE_GPU)
+ .Device(DEVICE_SYCL)
.HostMemory("input")
.TypeConstraint<int32>("T"),
RetvalOp);
@@ -238,5 +238,9 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
SymbolicGradientOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL),
+ SymbolicGradientOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow