path: root/tensorflow/core/kernels/matmul_op.h
author    Yangzihao Wang <yangzihao@google.com>    2017-07-21 09:22:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-07-21 09:38:31 -0700
commit 3e3306ef0009b5b21050139f9b8e5f4868c4c0c7 (patch)
tree   c7e25f278d93e9ce1ab9e2984df7b97c0f27c6d0 /tensorflow/core/kernels/matmul_op.h
parent 4729180d24af3126d736a7045c43fcbf031b5bef (diff)
Let GetBlasGemmAlgorithms() always return true.
PiperOrigin-RevId: 162748507
Diffstat (limited to 'tensorflow/core/kernels/matmul_op.h')
-rw-r--r--  tensorflow/core/kernels/matmul_op.h  |  64
1 file changed, 64 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h
index 5a8db6da19..6398da2fb9 100644
--- a/tensorflow/core/kernels/matmul_op.h
+++ b/tensorflow/core/kernels/matmul_op.h
@@ -17,7 +17,9 @@ limitations under the License.
#define TENSORFLOW_KERNELS_MATMUL_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace functor {
@@ -50,6 +52,68 @@ struct MatMulFunctor {
};
} // end namespace functor
+
+#if GOOGLE_CUDA
+// Encapsulate all the shape information that is used in matmul operations.
+class MatmulParameters {
+ public:
+ MatmulParameters(bool transa, bool transb, uint64 m, uint64 n, uint64 k,
+ DataType dtype, int device_id)
+ : transa_(transa),
+ transb_(transb),
+ m_(m),
+ n_(n),
+ k_(k),
+ dtype_(dtype),
+ device_id_(device_id) {
+ hash_code_ = transa;
+ hash_code_ = Hash64Combine(hash_code_, transb);
+ hash_code_ = Hash64Combine(hash_code_, m);
+ hash_code_ = Hash64Combine(hash_code_, n);
+ hash_code_ = Hash64Combine(hash_code_, k);
+ hash_code_ = Hash64Combine(hash_code_, dtype);
+ hash_code_ = Hash64Combine(hash_code_, device_id);
+ }
+ bool operator==(const MatmulParameters& other) const {
+ return this->get_data_as_tuple() == other.get_data_as_tuple();
+ }
+
+ bool operator!=(const MatmulParameters& other) const {
+ return !(*this == other);
+ }
+ uint64 hash() const { return hash_code_; }
+
+ string ToString() const {
+ // clang-format off
+ return strings::StrCat(
+ transa_, ", ", transb_, ", ",
+ m_, ", ", n_, ", ", k_,
+ dtype_, ", ", device_id_);
+ // clang-format on
+ }
+
+ private:
+ typedef std::tuple<bool, bool, int64, int64, int64, DataType, int>
+ ParameterDataType;
+
+ ParameterDataType get_data_as_tuple() const {
+ return std::make_tuple(transa_, transb_, m_, n_, k_, dtype_, device_id_);
+ }
+
+ bool transa_;
+ bool transb_;
+ uint64 m_;
+ uint64 n_;
+ uint64 k_;
+ DataType dtype_;
+ int device_id_;
+ uint64 hash_code_;
+};
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#endif // GOOGLE_CUDA
+
} // end namespace tensorflow
#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_
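
The MatmulParameters class added above supplies hash() and operator==, which is exactly what is needed to key an algorithm-selection cache by transpose flags, shape, data type, and device. Below is a minimal sketch of that usage, assuming a plain std::unordered_map; the MatmulParametersHasher struct, the AlgorithmType alias, and LookupOrInsert are hypothetical names introduced here for illustration and are not part of this commit.

// A minimal sketch (not the actual matmul_op.cc code) of using
// MatmulParameters as the key of an autotuning cache.
#include <cstddef>
#include <unordered_map>

#include "tensorflow/core/kernels/matmul_op.h"

namespace tensorflow {
#if GOOGLE_CUDA

// Defers to MatmulParameters::hash() so the precomputed hash_code_ is reused.
struct MatmulParametersHasher {
  std::size_t operator()(const MatmulParameters& params) const {
    return static_cast<std::size_t>(params.hash());
  }
};

// Hypothetical cache mapping matmul shape/type/device parameters to the id of
// the GEMM algorithm chosen for them.
typedef int64 AlgorithmType;
typedef std::unordered_map<MatmulParameters, AlgorithmType,
                           MatmulParametersHasher>
    MatmulAlgorithmCache;

// Returns the cached algorithm for `params`, inserting `fallback` on a miss.
// In a real kernel the miss path would run autotuning over candidate GEMM
// algorithms; here it only records a placeholder value.
AlgorithmType LookupOrInsert(MatmulAlgorithmCache* cache,
                             const MatmulParameters& params,
                             AlgorithmType fallback) {
  auto it = cache->find(params);
  if (it != cache->end()) return it->second;  // Hit: equality via operator==.
  (*cache)[params] = fallback;
  return fallback;
}

#endif  // GOOGLE_CUDA
}  // end namespace tensorflow

Precomputing hash_code_ in the constructor keeps repeated lookups cheap, while operator== still compares the full parameter tuple, so hash collisions are resolved correctly.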