aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/control_flow_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/control_flow_ops.cc')
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc112
1 files changed, 108 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index b01263f288..5241a4d916 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -121,8 +121,20 @@ REGISTER_GPU_HOST_REF_KERNEL(string);
SwitchOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_SWITCH(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+REGISTER_SYCL_REF_SWITCH(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
+
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_SWITCH
+
+#endif // TENSORFLOW_USE_SYCL
class RefSelectOp : public OpKernel {
public:
@@ -230,8 +242,18 @@ REGISTER_GPU_REF_KERNEL(bool);
MergeOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefMerge") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("value_index"), \
+ MergeOp)
+REGISTER_SYCL_REF_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_KERNEL
+#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -289,7 +311,15 @@ REGISTER_GPU_REF_KERNEL(bool);
Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
+REGISTER_SYCL_REF_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
+#undef REGISTER_SYCL_REF_KERNEL
#endif
// Special GPU kernels for int32 and string.
@@ -349,8 +379,37 @@ REGISTER_GPU_KERNEL(bool);
Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp)
+REGISTER_SYCL_REF_KERNEL(bool);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
+
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_KERNEL
+
+// Special GPU kernels for int32 and string.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Exit") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ ExitOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefExit") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ ExitOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(string);
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -432,8 +491,39 @@ REGISTER_GPU_HOST_KERNEL(string);
NextIterationOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+#define REGISTER_SYCL_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp)
+ REGISTER_SYCL_REF_KERNEL(bool);
+ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
#undef REGISTER_SYCL_KERNEL
-#endif
+#undef REGISTER_SYCL_REF_KERNEL
+
+// Special GPU kernels for int32 and string.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("NextIteration") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(string);
+#undef REGISTER_SYCL_HOST_KERNEL
+#endif // TENSORFLOW_USE_SYCL
// A LoopCond op has one input and one output. The input is a boolean
// scalar representing the taken branches of the "pivot" Switch that
@@ -461,6 +551,14 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond")
.HostMemory("output"),
LoopCondOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("LoopCond")
+ .Device(DEVICE_SYCL)
+ .HostMemory("input")
+ .HostMemory("output"),
+ LoopCondOp);
+#endif // TENSORFLOW_USE_SYCL
+
// ControlTrigger kernels
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
ControlTriggerOp);
@@ -468,6 +566,11 @@ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
ControlTriggerOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL),
+ ControlTriggerOp);
+#endif // TENSORFLOW_USE_SYCL
+
// When called, abort op will abort the current process. This can be used to
// abort remote PSs when needed.
class AbortOp : public OpKernel {
@@ -493,4 +596,5 @@ class AbortOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);
+
} // namespace tensorflow