aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/control_flow_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/control_flow_ops.cc')
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc26
1 files changed, 18 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index e92b11efc6..b01263f288 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -113,9 +113,12 @@ REGISTER_GPU_HOST_REF_KERNEL(string);
#undef REGISTER_GPU_HOST_REF_KERNEL
#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Switch").Device(DEVICE_SYCL).TypeConstraint<type>("T"), SwitchOp)
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Switch") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("pred"), \
+ SwitchOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
@@ -219,9 +222,12 @@ REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_REF_KERNEL
#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Merge").Device(DEVICE_SYCL).TypeConstraint<type>("T"), MergeOp)
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Merge") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("value_index"), \
+ MergeOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
@@ -418,8 +424,12 @@ REGISTER_GPU_HOST_KERNEL(string);
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), NextIterationOp)
+ REGISTER_KERNEL_BUILDER(Name("NextIteration") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL