// Source: tensorflow/core/kernels/matmul_op_gpu.cu.cc
// (blob 17107ce5df0c05c9f7efe3860baadcd1351f629c)
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/matmul_op.h"

namespace tensorflow {
namespace functor {

// Local shorthand for Eigen's GPU device type.
using GPUDevice = Eigen::GpuDevice;

// Partial specialization of MatMulFunctor for Device = GPUDevice (any T).
//
// Passes each tensor argument through To32Bit() and hands the results to the
// MatMul<GPUDevice> helper.
// NOTE(review): To32Bit and MatMul are declared outside this file; presumably
// To32Bit switches the tensors to 32-bit indexing -- confirm in their headers.
template <typename T>
struct MatMulFunctor<GPUDevice, T> {
  // d:        GPU device, passed through to MatMul.
  // out:      destination tensor.
  // in0, in1: operand tensors.
  // dim_pair: single index pair, forwarded to MatMul unchanged.
  void operator()(
      const GPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    auto out32 = To32Bit(out);
    auto lhs32 = To32Bit(in0);
    auto rhs32 = To32Bit(in1);
    MatMul<GPUDevice>(d, out32, lhs32, rhs32, dim_pair);
  }
};

// Explicitly instantiate the GPU functor for the element types listed here
// (currently only float).
// The macro body deliberately has no trailing ';': the ';' written at each use
// site terminates the instantiation, so no stray empty declaration (";;") is
// left behind for -Wextra-semi / -pedantic to flag.
#define DEFINE(T) template struct MatMulFunctor<GPUDevice, T>
DEFINE(float);
// DEFINE(double);  // Does not compile 1/2015.
#undef DEFINE

}  // end namespace functor
}  // end namespace tensorflow

#endif  // GOOGLE_CUDA