diff options
Diffstat (limited to 'tensorflow/core/kernels/reverse_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/reverse_op_gpu.cu.cc | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/reverse_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_op_gpu.cu.cc index e5f5f2fc51..39ab010627 100644 --- a/tensorflow/core/kernels/reverse_op_gpu.cu.cc +++ b/tensorflow/core/kernels/reverse_op_gpu.cu.cc @@ -25,24 +25,30 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_REVERSE(DIM) \ - template struct functor::Reverse<GPUDevice, uint8, DIM>; \ - template struct functor::Reverse<GPUDevice, int8, DIM>; \ - template struct functor::Reverse<GPUDevice, int32, DIM>; \ - template struct functor::Reverse<GPUDevice, bool, DIM>; \ - template struct functor::Reverse<GPUDevice, Eigen::half, DIM>; \ - template struct functor::Reverse<GPUDevice, float, DIM>; \ - template struct functor::Reverse<GPUDevice, double, DIM>; -DEFINE_REVERSE(0) -DEFINE_REVERSE(1) -DEFINE_REVERSE(2) -DEFINE_REVERSE(3) -DEFINE_REVERSE(4) -DEFINE_REVERSE(5) -DEFINE_REVERSE(6) -DEFINE_REVERSE(7) -DEFINE_REVERSE(8) +#define DEFINE_REVERSE(T, DIM) \ + template struct functor::Reverse<GPUDevice, T, DIM>; +#define DEFINE_REVERSE_ALL_DIMS(T) \ + DEFINE_REVERSE(T, 0) \ + DEFINE_REVERSE(T, 1) \ + DEFINE_REVERSE(T, 2) \ + DEFINE_REVERSE(T, 3) \ + DEFINE_REVERSE(T, 4) \ + DEFINE_REVERSE(T, 5) \ + DEFINE_REVERSE(T, 6) \ + DEFINE_REVERSE(T, 7) \ + DEFINE_REVERSE(T, 8) + +TF_CALL_uint8(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_int8(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_int32(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_bool(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_half(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_float(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_double(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_complex64(DEFINE_REVERSE_ALL_DIMS); +TF_CALL_complex128(DEFINE_REVERSE_ALL_DIMS); #undef DEFINE_REVERSE +#undef DEFINE_REVERSE_ALL_DIMS } // namespace tensorflow |