aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/control_flow_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/control_flow_ops.cc')
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc136
1 files changed, 86 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 203a9a9f24..64c06786bc 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -112,15 +112,14 @@ REGISTER_GPU_HOST_REF_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .HostMemory("pred"), \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"),\
SwitchOp)
-REGISTER_SYCL_KERNEL(bool);
-TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH);
#define REGISTER_SYCL_REF_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
@@ -128,12 +127,41 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
.HostMemory("pred") \
.TypeConstraint<type>("T"), \
SwitchOp)
-REGISTER_SYCL_REF_SWITCH(bool);
-TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
+TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
-#undef REGISTER_SYCL_KERNEL
+#undef REGISTER_SYCL_SWITCH
#undef REGISTER_SYCL_REF_SWITCH
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Switch") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("pred") \
+ .HostMemory("output_false")\
+ .HostMemory("output_true") \
+ .TypeConstraint<type>("T"),\
+ SwitchOp)
+
+REGISTER_SYCL_HOST_KERNEL(bool);
+REGISTER_SYCL_HOST_KERNEL(string);
+REGISTER_SYCL_HOST_KERNEL(int32);
+
+#define REGISTER_SYCL_HOST_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("pred") \
+ .HostMemory("output_false") \
+ .HostMemory("output_true") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+REGISTER_SYCL_HOST_REF_KERNEL(int32);
+REGISTER_SYCL_HOST_REF_KERNEL(bool);
+REGISTER_SYCL_HOST_REF_KERNEL(string);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif // TENSORFLOW_USE_SYCL
class RefSelectOp : public OpKernel {
@@ -233,13 +261,13 @@ REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
+#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Merge") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
- MergeOp)
+ MergeOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
@@ -248,9 +276,10 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
- MergeOp)
+ MergeOp);
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#endif // TENSORFLOW_USE_SYCL
@@ -280,6 +309,30 @@ REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Merge") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("inputs") \
+ .HostMemory("output") \
+ .HostMemory("value_index") \
+ .TypeConstraint<type>("T"), \
+ MergeOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefMerge") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("inputs") \
+ .HostMemory("output") \
+ .HostMemory("value_index") \
+ .TypeConstraint<type>("T"), \
+ MergeOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(string);
+REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
+
void EnterOp::Compute(OpKernelContext* context) {
if (IsRefType(context->input_dtype(0))) {
context->forward_ref_input_to_ref_output(0, 0);
@@ -306,7 +359,7 @@ REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
+#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
@@ -345,7 +398,7 @@ REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
#undef REGISTER_SYCL_HOST_KERNEL
#undef REGISTER_SYCL_HOST_REF_KERNEL
-#endif
+#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -394,30 +447,25 @@ REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
+REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
-#define REGISTER_SYCL_REF_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
-REGISTER_SYCL_REF_KERNEL(bool);
-TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
-
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
-// Special GPU kernels for int32 and string.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Exit") \
.Device(DEVICE_SYCL) \
@@ -507,31 +555,19 @@ REGISTER_GPU_HOST_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
-#if TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNEL(type) \
- REGISTER_KERNEL_BUILDER(Name("NextIteration") \
- .Device(DEVICE_SYCL) \
- .HostMemory("data") \
- .HostMemory("output") \
- .TypeConstraint<type>("T"), \
- NextIterationOp)
- REGISTER_SYCL_KERNEL(bool);
- TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
-#define REGISTER_SYCL_REF_KERNEL(type) \
- REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
- .Device(DEVICE_SYCL) \
- .HostMemory("data") \
- .HostMemory("output") \
- .TypeConstraint<type>("T"), \
- NextIterationOp)
- REGISTER_SYCL_REF_KERNEL(bool);
- TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\
+ NextIterationOp)
+REGISTER_SYCL_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
-#undef REGISTER_SYCL_REF_KERNEL
-// Special GPU kernels for int32 and string.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("NextIteration") \
.Device(DEVICE_SYCL) \