about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <nobody@tensorflow.org>, 2016-01-19 11:11:48 -0800
committer: Manjunath Kudlur <keveman@gmail.com>, 2016-01-20 07:48:27 -0800
commit: ee274fdd0af515a09a3a45a1d31065b3edd06087 (patch)
tree: 9084c8ea21da984f1c659c9ee309b01fc9fc4eb0
parent: 757eaa78fe3de8193f330829a6d0ec8f7220d5e0 (diff)
Clean up GPU kernel registrations for control flow ops.
Change: 112492676
-rw-r--r-- tensorflow/core/kernels/control_flow_ops.cc | 163
1 files changed, 102 insertions, 61 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 978161329a..8b539a7751 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -86,12 +86,9 @@ class SwitchOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SWITCH);
-REGISTER_GPU_SWITCH(int64);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
REGISTER_GPU_SWITCH(bool);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_REF_SWITCH);
-REGISTER_GPU_REF_SWITCH(int32);
-REGISTER_GPU_REF_SWITCH(int64);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_REF_SWITCH(bool);
#undef REGISTER_CPU_SWITCH
@@ -99,17 +96,36 @@ REGISTER_GPU_REF_SWITCH(bool);
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH
-// A special GPU kernel for int32.
+// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Switch")
- .Device(DEVICE_GPU)
- .HostMemory("data")
- .HostMemory("pred")
- .HostMemory("output_false")
- .HostMemory("output_true")
- .TypeConstraint<int32>("T"),
- SwitchOp);
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Switch") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("pred") \
+ .HostMemory("output_false") \
+ .HostMemory("output_true") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+#define REGISTER_GPU_HOST_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("pred") \
+ .HostMemory("output_false") \
+ .HostMemory("output_true") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_REF_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(string);
+REGISTER_GPU_HOST_REF_KERNEL(string);
+
+#undef REGISTER_GPU_HOST_KERNEL
+#undef REGISTER_GPU_HOST_REF_KERNEL
class RefSelectOp : public OpKernel {
public:
@@ -199,22 +215,29 @@ REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
- MergeOp);
+ MergeOp)
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
-// A special GPU kernel for int32.
+// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Merge")
- .Device(DEVICE_GPU)
- .HostMemory("inputs")
- .HostMemory("output")
- .HostMemory("value_index")
- .TypeConstraint<int32>("T"),
- MergeOp);
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Merge") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("inputs") \
+ .HostMemory("output") \
+ .HostMemory("value_index") \
+ .TypeConstraint<type>("T"), \
+ MergeOp)
+
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(string);
+
+#undef REGISTER_GPU_HOST_KERNEL
// An enter op has one input and one output. It creates or finds
// the child frame that is uniquely identified by the frame_name,
@@ -243,41 +266,45 @@ REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
- Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp);
+ Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
#define REGISTER_GPU_REF_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
- Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp);
+ Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
-TF_CALL_NUMBER_TYPES(REGISTER_GPU_REF_KERNEL);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
+REGISTER_GPU_KERNEL(bool);
+REGISTER_GPU_REF_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
-// A special GPU kernel for int32.
+// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Enter")
- .Device(DEVICE_GPU)
- .HostMemory("data")
- .HostMemory("output")
- .TypeConstraint<int32>("T"),
- EnterOp);
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Enter") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnterOp)
-// Special GPU kernels for string.
-REGISTER_KERNEL_BUILDER(Name("Enter")
- .Device(DEVICE_GPU)
- .HostMemory("data")
- .HostMemory("output")
- .TypeConstraint<string>("T"),
- EnterOp);
+#define REGISTER_GPU_HOST_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefEnter") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnterOp)
-REGISTER_KERNEL_BUILDER(Name("RefEnter")
- .Device(DEVICE_GPU)
- .HostMemory("data")
- .HostMemory("output")
- .TypeConstraint<string>("T"),
- EnterOp);
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_REF_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(string);
+REGISTER_GPU_HOST_REF_KERNEL(string);
+
+#undef REGISTER_GPU_HOST_KERNEL
+#undef REGISTER_GPU_HOST_REF_KERNEL
// An exit op has one input and one output. It exits the current
// frame to its parent frame, and makes its input available to the
@@ -304,18 +331,25 @@ REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
-// A special GPU kernel for int32.
+// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Exit")
- .Device(DEVICE_GPU)
- .HostMemory("data")
- .HostMemory("output")
- .TypeConstraint<int32>("T"),
- ExitOp);
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Exit") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ ExitOp)
+
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(string);
+
+#undef REGISTER_GPU_HOST_KERNEL
// A next_iteration op has one input and one output. It makes its input
// available to the next iteration.
@@ -340,21 +374,28 @@ REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- NextIterationOp);
+ NextIterationOp)
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(bool);
#undef REGISTER_GPU_KERNEL
-// A special GPU kernel for int32.
+// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("NextIteration")
- .Device(DEVICE_GPU)
- .HostMemory("data")
- .HostMemory("output")
- .TypeConstraint<int32>("T"),
- NextIterationOp);
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("NextIteration") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("data") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ NextIterationOp)
+
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(string);
+
+#undef REGISTER_GPU_HOST_KERNEL
// A LoopCond op has one input and one output. The input is a boolean
// scalar representing the taken branches of the "pivot" Switch that