aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/control_flow_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/control_flow_ops.cc')
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 65413a09b2..1a8c17b1ef 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -112,6 +112,15 @@ REGISTER_GPU_HOST_REF_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Switch").Device(DEVICE_SYCL).TypeConstraint<type>("T"), SwitchOp)
+REGISTER_SYCL_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif
+
class RefSelectOp : public OpKernel {
public:
explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -209,6 +218,15 @@ REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Merge").Device(DEVICE_SYCL).TypeConstraint<type>("T"), MergeOp)
+REGISTER_SYCL_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif
+
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -259,6 +277,15 @@ REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
+REGISTER_SYCL_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif
+
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -310,6 +337,15 @@ REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
+REGISTER_SYCL_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif
+
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
@@ -380,6 +416,15 @@ REGISTER_GPU_HOST_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), NextIterationOp)
+ REGISTER_SYCL_KERNEL(bool);
+ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#undef REGISTER_SYCL_KERNEL
+#endif
+
// A LoopCond op has one input and one output. The input is a boolean
// scalar representing the taken branches of the "pivot" Switch that
// determines loop termination. As a contract, any high-level front-end