diff options
author | Yangzihao Wang <yangzihao@google.com> | 2017-07-21 09:22:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-21 09:38:31 -0700 |
commit | 3e3306ef0009b5b21050139f9b8e5f4868c4c0c7 (patch) | |
tree | c7e25f278d93e9ce1ab9e2984df7b97c0f27c6d0 /tensorflow/core/kernels/matmul_op.h | |
parent | 4729180d24af3126d736a7045c43fcbf031b5bef (diff) |
Let GetBlasGemmAlgorithms() always return true.
PiperOrigin-RevId: 162748507
Diffstat (limited to 'tensorflow/core/kernels/matmul_op.h')
-rw-r--r-- | tensorflow/core/kernels/matmul_op.h | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h index 5a8db6da19..6398da2fb9 100644 --- a/tensorflow/core/kernels/matmul_op.h +++ b/tensorflow/core/kernels/matmul_op.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_KERNELS_MATMUL_OP_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { namespace functor { @@ -50,6 +52,68 @@ struct MatMulFunctor { }; } // end namespace functor + +#if GOOGLE_CUDA +// Encapsulate all the shape information that is used in matmul operations. +class MatmulParameters { + public: + MatmulParameters(bool transa, bool transb, uint64 m, uint64 n, uint64 k, + DataType dtype, int device_id) + : transa_(transa), + transb_(transb), + m_(m), + n_(n), + k_(k), + dtype_(dtype), + device_id_(device_id) { + hash_code_ = transa; + hash_code_ = Hash64Combine(hash_code_, transb); + hash_code_ = Hash64Combine(hash_code_, m); + hash_code_ = Hash64Combine(hash_code_, n); + hash_code_ = Hash64Combine(hash_code_, k); + hash_code_ = Hash64Combine(hash_code_, dtype); + hash_code_ = Hash64Combine(hash_code_, device_id); + } + bool operator==(const MatmulParameters& other) const { + return this->get_data_as_tuple() == other.get_data_as_tuple(); + } + + bool operator!=(const MatmulParameters& other) const { + return !(*this == other); + } + uint64 hash() const { return hash_code_; } + + string ToString() const { + // clang-format off + return strings::StrCat( + transa_, ", ", transb_, ", ", + m_, ", ", n_, ", ", k_, + dtype_, ", ", device_id_); + // clang-format on + } + + private: + typedef std::tuple<bool, bool, int64, int64, int64, DataType, int> + ParameterDataType; + + ParameterDataType get_data_as_tuple() const { + return std::make_tuple(transa_, transb_, m_, n_, k_, dtype_, device_id_); + } + + bool transa_; + bool transb_; + uint64 m_; + uint64 n_; + uint64 k_; + DataType dtype_; + int device_id_; + uint64 hash_code_; +}; + +typedef Eigen::GpuDevice GPUDevice; + +#endif // GOOGLE_CUDA + } // end namespace tensorflow #endif // TENSORFLOW_KERNELS_MATMUL_OP_H_ |