diff options
author | 2016-01-19 11:11:48 -0800 | |
---|---|---|
committer | 2016-01-20 07:48:27 -0800 | |
commit | ee274fdd0af515a09a3a45a1d31065b3edd06087 (patch) | |
tree | 9084c8ea21da984f1c659c9ee309b01fc9fc4eb0 | |
parent | 757eaa78fe3de8193f330829a6d0ec8f7220d5e0 (diff) |
Clean up GPU kernel registrations for control flow ops.
Change: 112492676
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.cc | 163 |
1 files changed, 102 insertions, 61 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 978161329a..8b539a7751 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -86,12 +86,9 @@ class SwitchOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH); TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SWITCH); -REGISTER_GPU_SWITCH(int64); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH); REGISTER_GPU_SWITCH(bool); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_REF_SWITCH); -REGISTER_GPU_REF_SWITCH(int32); -REGISTER_GPU_REF_SWITCH(int64); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH); REGISTER_GPU_REF_SWITCH(bool); #undef REGISTER_CPU_SWITCH @@ -99,17 +96,36 @@ REGISTER_GPU_REF_SWITCH(bool); #undef REGISTER_GPU_SWITCH #undef REGISTER_GPU_REF_SWITCH -// A special GPU kernel for int32. +// Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Switch") - .Device(DEVICE_GPU) - .HostMemory("data") - .HostMemory("pred") - .HostMemory("output_false") - .HostMemory("output_true") - .TypeConstraint<int32>("T"), - SwitchOp); +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Switch") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("pred") \ + .HostMemory("output_false") \ + .HostMemory("output_true") \ + .TypeConstraint<type>("T"), \ + SwitchOp) + +#define REGISTER_GPU_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("pred") \ + .HostMemory("output_false") \ + .HostMemory("output_true") \ + .TypeConstraint<type>("T"), \ + SwitchOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_REF_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); +REGISTER_GPU_HOST_REF_KERNEL(string); + +#undef REGISTER_GPU_HOST_KERNEL +#undef REGISTER_GPU_HOST_REF_KERNEL class RefSelectOp : public OpKernel { public: @@ -199,22 +215,29 @@ REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp); .Device(DEVICE_GPU) \ .TypeConstraint<type>("T") \ .HostMemory("value_index"), \ - MergeOp); + MergeOp) TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +REGISTER_GPU_KERNEL(bool); #undef REGISTER_GPU_KERNEL -// A special GPU kernel for int32. +// Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Merge") - .Device(DEVICE_GPU) - .HostMemory("inputs") - .HostMemory("output") - .HostMemory("value_index") - .TypeConstraint<int32>("T"), - MergeOp); +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Merge") \ + .Device(DEVICE_GPU) \ + .HostMemory("inputs") \ + .HostMemory("output") \ + .HostMemory("value_index") \ + .TypeConstraint<type>("T"), \ + MergeOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); + +#undef REGISTER_GPU_HOST_KERNEL // An enter op has one input and one output. It creates or finds // the child frame that is uniquely identified by the frame_name, @@ -243,41 +266,45 @@ REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp); #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ - Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp); + Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp) #define REGISTER_GPU_REF_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ - Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp); + Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp) TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); -TF_CALL_NUMBER_TYPES(REGISTER_GPU_REF_KERNEL); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); #undef REGISTER_GPU_KERNEL #undef REGISTER_GPU_REF_KERNEL -// A special GPU kernel for int32. +// Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Enter") - .Device(DEVICE_GPU) - .HostMemory("data") - .HostMemory("output") - .TypeConstraint<int32>("T"), - EnterOp); +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Enter") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + EnterOp) -// Special GPU kernels for string. -REGISTER_KERNEL_BUILDER(Name("Enter") - .Device(DEVICE_GPU) - .HostMemory("data") - .HostMemory("output") - .TypeConstraint<string>("T"), - EnterOp); +#define REGISTER_GPU_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefEnter") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + EnterOp) -REGISTER_KERNEL_BUILDER(Name("RefEnter") - .Device(DEVICE_GPU) - .HostMemory("data") - .HostMemory("output") - .TypeConstraint<string>("T"), - EnterOp); +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_REF_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); +REGISTER_GPU_HOST_REF_KERNEL(string); + +#undef REGISTER_GPU_HOST_KERNEL +#undef REGISTER_GPU_HOST_REF_KERNEL // An exit op has one input and one output. It exits the current // frame to its parent frame, and makes its input available to the @@ -304,18 +331,25 @@ REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp); Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +REGISTER_GPU_KERNEL(bool); #undef REGISTER_GPU_KERNEL -// A special GPU kernel for int32. +// Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("Exit") - .Device(DEVICE_GPU) - .HostMemory("data") - .HostMemory("output") - .TypeConstraint<int32>("T"), - ExitOp); +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Exit") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ExitOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); + +#undef REGISTER_GPU_HOST_KERNEL // A next_iteration op has one input and one output. It makes its input // available to the next iteration. @@ -340,21 +374,28 @@ REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU), #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - NextIterationOp); + NextIterationOp) TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +REGISTER_GPU_KERNEL(bool); #undef REGISTER_GPU_KERNEL -// A special GPU kernel for int32. +// Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. -REGISTER_KERNEL_BUILDER(Name("NextIteration") - .Device(DEVICE_GPU) - .HostMemory("data") - .HostMemory("output") - .TypeConstraint<int32>("T"), - NextIterationOp); +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("NextIteration") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); + +#undef REGISTER_GPU_HOST_KERNEL // A LoopCond op has one input and one output. The input is a boolean // scalar representing the taken branches of the "pivot" Switch that |