about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/transpose_op_gpu.cu.cc
blob: 8c04a6544ef9ab3361b3c9565c24d865bd44c52a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/kernels/transpose_op_functor.h"

namespace tensorflow {
namespace functor {

// GPU specialization of TransposeFunctor.
//
// All of the work is forwarded to the device-generic Transpose<> helper,
// instantiated for Eigen::GpuDevice. Transpose<> is declared outside this
// file (presumably in transpose_op_functor.h); it presumably evaluates a
// dimension shuffle of `in` into `out` on device `d` — TODO confirm.
//
// Args:
//   d:    the Eigen GPU device the computation is issued on.
//   out:  destination tensor of rank NDIMS.
//   in:   source tensor of rank NDIMS.
//   perm: dimension permutation. Presumably it points to NDIMS ints, with
//         output dimension i taken from input dimension perm[i]. Verify
//         against Transpose<>.
template <typename T, int NDIMS>
struct TransposeFunctor<Eigen::GpuDevice, T, NDIMS> {
  void operator()(const Eigen::GpuDevice& d,
                  typename TTypes<T, NDIMS>::Tensor out,
                  typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
    Transpose<Eigen::GpuDevice, T, NDIMS>(d, out, in, perm);
  }
};

// Explicitly instantiate the GPU functor for each listed element type at
// ranks 1 through 8, so that the definitions are emitted in this translation
// unit. Presumably host-side code links against these instantiations instead
// of compiling the GPU path itself — confirm against the caller.
//
// DEFINE supplies the ';' that terminates each explicit-instantiation
// definition. For that reason, neither the DEFINE uses inside DEFINE_DIM nor
// the DEFINE_DIM call sites append another one. A second ';' would leave
// stray empty declarations at namespace scope, which -Wextra-semi and
// -pedantic diagnose.
#define DEFINE(T, N) template struct TransposeFunctor<Eigen::GpuDevice, T, N>;
#define DEFINE_DIM(T) \
  DEFINE(T, 1)        \
  DEFINE(T, 2)        \
  DEFINE(T, 3)        \
  DEFINE(T, 4)        \
  DEFINE(T, 5)        \
  DEFINE(T, 6)        \
  DEFINE(T, 7)        \
  DEFINE(T, 8)
DEFINE_DIM(uint8)
DEFINE_DIM(int8)
DEFINE_DIM(int16)
DEFINE_DIM(int32)
DEFINE_DIM(int64)
DEFINE_DIM(float)
DEFINE_DIM(double)
#undef DEFINE_DIM
#undef DEFINE

}  // end namespace functor
}  // end namespace tensorflow

#endif  // GOOGLE_CUDA