diff options
Diffstat (limited to 'tensorflow/core/kernels/control_flow_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.cc | 136 |
1 files changed, 86 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 203a9a9f24..64c06786bc 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -112,15 +112,14 @@ REGISTER_GPU_HOST_REF_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL #undef REGISTER_GPU_HOST_REF_KERNEL -#if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_SWITCH(type) \ REGISTER_KERNEL_BUILDER(Name("Switch") \ .Device(DEVICE_SYCL) \ - .TypeConstraint<type>("T") \ - .HostMemory("pred"), \ + .HostMemory("pred") \ + .TypeConstraint<type>("T"),\ SwitchOp) -REGISTER_SYCL_KERNEL(bool); -TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); +TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH); #define REGISTER_SYCL_REF_SWITCH(type) \ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ @@ -128,12 +127,41 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); .HostMemory("pred") \ .TypeConstraint<type>("T"), \ SwitchOp) -REGISTER_SYCL_REF_SWITCH(bool); -TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH); +TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH); -#undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_SWITCH #undef REGISTER_SYCL_REF_SWITCH +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Switch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("pred") \ + .HostMemory("output_false")\ + .HostMemory("output_true") \ + .TypeConstraint<type>("T"),\ + SwitchOp) + +REGISTER_SYCL_HOST_KERNEL(bool); +REGISTER_SYCL_HOST_KERNEL(string); +REGISTER_SYCL_HOST_KERNEL(int32); + +#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("pred") \ + .HostMemory("output_false") \ + .HostMemory("output_true") \ + .TypeConstraint<type>("T"), \ + SwitchOp) + +REGISTER_SYCL_HOST_REF_KERNEL(int32); +REGISTER_SYCL_HOST_REF_KERNEL(bool); +REGISTER_SYCL_HOST_REF_KERNEL(string); + +#undef REGISTER_SYCL_HOST_KERNEL +#undef REGISTER_SYCL_HOST_REF_KERNEL #endif // TENSORFLOW_USE_SYCL class RefSelectOp : public OpKernel { @@ -233,13 +261,13 @@ REGISTER_GPU_REF_KERNEL(bool); #undef REGISTER_GPU_KERNEL #undef REGISTER_GPU_REF_KERNEL -#if TENSORFLOW_USE_SYCL +#ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("Merge") \ .Device(DEVICE_SYCL) \ .TypeConstraint<type>("T") \ .HostMemory("value_index"), \ - MergeOp) + MergeOp); REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); @@ -248,9 +276,10 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); .Device(DEVICE_SYCL) \ .TypeConstraint<type>("T") \ .HostMemory("value_index"), \ - MergeOp) + MergeOp); REGISTER_SYCL_REF_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + #undef REGISTER_SYCL_KERNEL #undef REGISTER_SYCL_REF_KERNEL #endif // TENSORFLOW_USE_SYCL @@ -280,6 +309,30 @@ REGISTER_GPU_HOST_KERNEL(ResourceHandle); #undef REGISTER_GPU_HOST_KERNEL +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Merge") \ + .Device(DEVICE_SYCL) \ + .HostMemory("inputs") \ + .HostMemory("output") \ + .HostMemory("value_index") \ + .TypeConstraint<type>("T"), \ + MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("RefMerge") \ + .Device(DEVICE_SYCL) \ + .HostMemory("inputs") \ + .HostMemory("output") \ + .HostMemory("value_index") \ + .TypeConstraint<type>("T"), \ + MergeOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +REGISTER_SYCL_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL + void EnterOp::Compute(OpKernelContext* context) { if (IsRefType(context->input_dtype(0))) { context->forward_ref_input_to_ref_output(0, 0); @@ -306,7 +359,7 @@ REGISTER_GPU_REF_KERNEL(bool); #undef REGISTER_GPU_KERNEL #undef REGISTER_GPU_REF_KERNEL -#if TENSORFLOW_USE_SYCL +#ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) @@ -345,7 +398,7 @@ REGISTER_SYCL_HOST_KERNEL(ResourceHandle); #undef REGISTER_SYCL_HOST_KERNEL #undef REGISTER_SYCL_HOST_REF_KERNEL -#endif +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -394,30 +447,25 @@ REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp); Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); #undef REGISTER_GPU_KERNEL #undef REGISTER_GPU_REF_KERNEL -#if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp) +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); -#define REGISTER_SYCL_REF_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp) -REGISTER_SYCL_REF_KERNEL(bool); -TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); - #undef REGISTER_SYCL_KERNEL #undef REGISTER_SYCL_REF_KERNEL -// Special GPU kernels for int32 and string. -// TODO(b/25387198): Also enable int32 in device memory. This kernel -// registration requires all int32 inputs and outputs to be in host memory. #define REGISTER_SYCL_HOST_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("Exit") \ .Device(DEVICE_SYCL) \ @@ -507,31 +555,19 @@ REGISTER_GPU_HOST_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL -#if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("NextIteration") \ - .Device(DEVICE_SYCL) \ - .HostMemory("data") \ - .HostMemory("output") \ - .TypeConstraint<type>("T"), \ - NextIterationOp) - REGISTER_SYCL_KERNEL(bool); - TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); -#define REGISTER_SYCL_REF_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \ - .Device(DEVICE_SYCL) \ - .HostMemory("data") \ - .HostMemory("output") \ - .TypeConstraint<type>("T"), \ - NextIterationOp) - REGISTER_SYCL_REF_KERNEL(bool); - TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\ + NextIterationOp) +REGISTER_SYCL_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + #undef REGISTER_SYCL_KERNEL -#undef REGISTER_SYCL_REF_KERNEL -// Special GPU kernels for int32 and string. -// TODO(b/25387198): Also enable int32 in device memory. This kernel -// registration requires all int32 inputs and outputs to be in host memory. #define REGISTER_SYCL_HOST_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("NextIteration") \ .Device(DEVICE_SYCL) \ |