aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/reverse_op_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/reverse_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/reverse_op_gpu.cu.cc40
1 files changed, 23 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/reverse_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_op_gpu.cu.cc
index e5f5f2fc51..39ab010627 100644
--- a/tensorflow/core/kernels/reverse_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/reverse_op_gpu.cu.cc
@@ -25,24 +25,30 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_REVERSE(DIM) \
- template struct functor::Reverse<GPUDevice, uint8, DIM>; \
- template struct functor::Reverse<GPUDevice, int8, DIM>; \
- template struct functor::Reverse<GPUDevice, int32, DIM>; \
- template struct functor::Reverse<GPUDevice, bool, DIM>; \
- template struct functor::Reverse<GPUDevice, Eigen::half, DIM>; \
- template struct functor::Reverse<GPUDevice, float, DIM>; \
- template struct functor::Reverse<GPUDevice, double, DIM>;
-DEFINE_REVERSE(0)
-DEFINE_REVERSE(1)
-DEFINE_REVERSE(2)
-DEFINE_REVERSE(3)
-DEFINE_REVERSE(4)
-DEFINE_REVERSE(5)
-DEFINE_REVERSE(6)
-DEFINE_REVERSE(7)
-DEFINE_REVERSE(8)
+#define DEFINE_REVERSE(T, DIM) \
+ template struct functor::Reverse<GPUDevice, T, DIM>;
+#define DEFINE_REVERSE_ALL_DIMS(T) \
+ DEFINE_REVERSE(T, 0) \
+ DEFINE_REVERSE(T, 1) \
+ DEFINE_REVERSE(T, 2) \
+ DEFINE_REVERSE(T, 3) \
+ DEFINE_REVERSE(T, 4) \
+ DEFINE_REVERSE(T, 5) \
+ DEFINE_REVERSE(T, 6) \
+ DEFINE_REVERSE(T, 7) \
+ DEFINE_REVERSE(T, 8)
+
+TF_CALL_uint8(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_int8(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_int32(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_bool(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_half(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_float(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_double(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_complex64(DEFINE_REVERSE_ALL_DIMS);
+TF_CALL_complex128(DEFINE_REVERSE_ALL_DIMS);
#undef DEFINE_REVERSE
+#undef DEFINE_REVERSE_ALL_DIMS
} // namespace tensorflow