aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/matmul_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/matmul_op.h')
-rw-r--r--tensorflow/core/kernels/matmul_op.h40
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h
new file mode 100644
index 0000000000..f75b0ded1b
--- /dev/null
+++ b/tensorflow/core/kernels/matmul_op.h
@@ -0,0 +1,40 @@
+#ifndef TENSORFLOW_KERNELS_MATMUL_OP_H_
+#define TENSORFLOW_KERNELS_MATMUL_OP_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Helpers to define tensor<T> needed by MatMul op.
+template <typename T>
+struct MatMulTypes {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>
+ out_type;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Aligned> in_type;
+};
+
+template <typename Device, typename In0, typename In1, typename Out,
+ typename DimPair>
+void MatMul(const Device& d, Out out, In0 in0, In1 in1,
+ const DimPair& dim_pair) {
+ out.device(d) = in0.contract(in1, dim_pair);
+}
+
+template <typename Device, typename T>
+struct MatMulFunctor {
+ // Computes on device "d": out = in0 * in1, where * is matrix
+ // multiplication.
+ void operator()(
+ const Device& d, typename MatMulTypes<T>::out_type out,
+ typename MatMulTypes<T>::in_type in0,
+ typename MatMulTypes<T>::in_type in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair);
+};
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_