diff options
Diffstat (limited to 'tensorflow/core/kernels/matmul_op.h')
-rw-r--r-- | tensorflow/core/kernels/matmul_op.h | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h new file mode 100644 index 0000000000..f75b0ded1b --- /dev/null +++ b/tensorflow/core/kernels/matmul_op.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_KERNELS_MATMUL_OP_H_ +#define TENSORFLOW_KERNELS_MATMUL_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +// Helpers to define tensor<T> needed by MatMul op. +template <typename T> +struct MatMulTypes { + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned> + out_type; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Aligned> in_type; +}; + +template <typename Device, typename In0, typename In1, typename Out, + typename DimPair> +void MatMul(const Device& d, Out out, In0 in0, In1 in1, + const DimPair& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); +} + +template <typename Device, typename T> +struct MatMulFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, typename MatMulTypes<T>::out_type out, + typename MatMulTypes<T>::in_type in0, + typename MatMulTypes<T>::in_type in1, + const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair); +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_ |