diff options
Diffstat (limited to 'tensorflow/core/kernels/control_flow_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.cc | 112 |
1 files changed, 108 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index b01263f288..5241a4d916 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -121,8 +121,20 @@ REGISTER_GPU_HOST_REF_KERNEL(string); SwitchOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("pred") \ + .TypeConstraint<type>("T"), \ + SwitchOp) +REGISTER_SYCL_REF_SWITCH(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH); + #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_SWITCH + +#endif // TENSORFLOW_USE_SYCL class RefSelectOp : public OpKernel { public: @@ -230,8 +242,18 @@ REGISTER_GPU_REF_KERNEL(bool); MergeOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefMerge") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .HostMemory("value_index"), \ + MergeOp) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_KERNEL +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -289,7 +311,15 @@ REGISTER_GPU_REF_KERNEL(bool); Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + #undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL #endif // Special GPU kernels for int32 and string. @@ -349,8 +379,37 @@ REGISTER_GPU_KERNEL(bool); Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_KERNEL + +// Special GPU kernels for int32 and string. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Exit") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("RefExit") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ExitOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -432,8 +491,39 @@ REGISTER_GPU_HOST_KERNEL(string); NextIterationOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp) + REGISTER_SYCL_REF_KERNEL(bool); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); #undef REGISTER_SYCL_KERNEL -#endif +#undef REGISTER_SYCL_REF_KERNEL + +// Special GPU kernels for int32 and string. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("NextIteration") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL // A LoopCond op has one input and one output. The input is a boolean // scalar representing the taken branches of the "pivot" Switch that @@ -461,6 +551,14 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond") .HostMemory("output"), LoopCondOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("LoopCond") + .Device(DEVICE_SYCL) + .HostMemory("input") + .HostMemory("output"), + LoopCondOp); +#endif // TENSORFLOW_USE_SYCL + // ControlTrigger kernels REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU), ControlTriggerOp); @@ -468,6 +566,11 @@ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU), ControlTriggerOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL), + ControlTriggerOp); +#endif // TENSORFLOW_USE_SYCL + // When called, abort op will abort the current process. This can be used to // abort remote PSs when needed. class AbortOp : public OpKernel { @@ -493,4 +596,5 @@ class AbortOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp); + } // namespace tensorflow |