diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_matmul_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_matmul_op.cc | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc new file mode 100644 index 0000000000..3ba28c13ed --- /dev/null +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -0,0 +1,217 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/math_ops.cc. + +// This file uses MKL CBLAS xGEMM for acceleration of TF Matrix-Matrix +// Multiplication (MatMul) operations. +// We currently register this kernel only for MKL supported data +// types (float, double, complex64, complex128). The macro INTEL_MKL is defined +// by the build system only when MKL is chosen as an option at configure stage +// and when it is undefined at build time, this file becomes an empty +// compilation unit + +#if defined(INTEL_MKL) + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "third_party/mkl/include/mkl_cblas.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template <typename Device, typename T, bool USE_CUBLAS> +class MklMatMulOp : public OpKernel { + public: + explicit MklMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + + // Check that the dimensions of the two matrices are valid. + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()), + errors::InvalidArgument("In[0] is not a matrix")); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), + errors::InvalidArgument("In[1] is not a matrix")); + Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; + dim_pair[0].first = transpose_a_ ? 0 : 1; + dim_pair[0].second = transpose_b_ ? 1 : 0; + + OP_REQUIRES(ctx, + a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), + errors::InvalidArgument("Matrix size-incompatible: In[0]: ", + a.shape().DebugString(), ", In[1]: ", + b.shape().DebugString())); + int a_dim_remaining = 1 - dim_pair[0].first; + int b_dim_remaining = 1 - dim_pair[0].second; + TensorShape out_shape( + {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)}); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + + if (out->NumElements() == 0) { + // If a has shape [0, x] or b has shape [x, 0], the output shape + // is a 0-element matrix, so there is nothing to do. + return; + } + + if (a.NumElements() == 0 || b.NumElements() == 0) { + // If a has shape [x, 0] and b has shape [0, y], the + // output shape is [x, y] where x and y are non-zero, so we fill + // the output with zeros. + functor::SetZeroFunctor<Device, T> f; + f(ctx->eigen_device<Device>(), out->flat<T>()); + return; + } + + const int m = a.dim_size(1 - dim_pair[0].first); + const int k = a.dim_size(dim_pair[0].first); + const int n = b.dim_size(1 - dim_pair[0].second); + bool transpose_a = dim_pair[0].first == 0; + bool transpose_b = dim_pair[0].second == 1; + + auto a_ptr = (a.template flat<T>().data()); + auto b_ptr = (b.template flat<T>().data()); + auto c_ptr = (out->template flat<T>().data()); + + MklBlasGemm(transpose_a, transpose_b, m, n, k, a_ptr, transpose_a ? m : k, + b_ptr, transpose_b ? k : n, c_ptr, n); + } + + private: + bool transpose_a_; + bool transpose_b_; + + // -------------------------------------------------------------------------- + // + // @brief Matrix-Matrix Multiplication with FP32 tensors, a, b, c using CBLAS + // interface. c = op(a) * op(b) + // + // @param transa Specifies the form of op(a) used in MatMul. If transa is + // true, then op(a) = a^T, otherwise op(a) = a + // + // @param transb Specifies the form of op(b) used in MatMul. If transb is + // true, then op(b) = b^T, otherwise op(b) = b + // + // @param m Specifies the number of rows of the matrix op(a) and of the + // matrix c. The value of m must be at least zero. + // + // @param n Specifies the number of columns of the matrix op(b) and the + // number of columns of the matrix c. The value of n must be at least zero. + // + // @param k Specifies the number of columns of the matrix op(a) and the + // number of rows of the matrix op(b) + // + // @param a Address of matrix a + // + // @param lda Leading dimension of 'a' matrix. This is set at calling site + // depending on transa parameter. Since TF uses row-major + // layout, leading dimension is the stride between consecutive rows + // lda = max(1,k) when transa is false, otherwise lda = max(1,m) + // + // @param b Address of matrix b + // + // @param ldb Leading dimension of 'b' matrix. This is set at calling site + // depending on transb parameter. Since TF uses row-major + // layout, leading dimension is the stride between consecutive rows + // ldb = max(1,n) when transb is false, otherwise ldb = max(1,k) + // + // @param c Address of matrix c + // + // @param ldc Leading dimension of 'c' matrix. Since TF uses row-major + // layout, leading dimension is the stride between consecutive rows, max(1,n) + // + // -------------------------------------------------------------------------- + void MklBlasGemm(bool transa, bool transb, const int m, const int n, + const int k, const float* a, const int lda, const float* b, + const int ldb, float* c, const int ldc) { + // BLAS GEMM API defines Matrix Multiplication as c = alpha * op(a) * op(b) + // + beta * c. + // Since TF MatMul does not have parameters for alpha, beta, we set them to + // 1.0 and 0.0 respectively. + const float alpha = 1.0f; + const float beta = 0.0f; + cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, + transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); + } + + // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about + // parameters, look at FP32 function description. + void MklBlasGemm(bool transa, bool transb, const int m, const int n, + const int k, const double* a, const int lda, const double* b, + const int ldb, double* c, const int ldc) { + const double alpha = 1.0; + const double beta = 0.0; + cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, + transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); + } + + // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors. + // For detailed info about parameters, look at FP32 function description. + void MklBlasGemm(bool transa, bool transb, const int m, const int n, + const int k, const std::complex<float>* a, const int lda, + const std::complex<float>* b, const int ldb, + std::complex<float>* c, int const ldc) { + const MKL_Complex8 alpha = {1.0f, 0.0f}; + const MKL_Complex8 beta = {0.0f, 0.0f}; + cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, + transb ? CblasTrans : CblasNoTrans, m, n, k, + static_cast<const void*>(&alpha), static_cast<const void*>(a), + lda, static_cast<const void*>(b), ldb, + static_cast<const void*>(&beta), static_cast<void*>(c), ldc); + } + + // Matrix-Matrix Multiplication with Complex128 (std::complex<double>) + // tensors. For detailed info about parameters, look at FP32 function + // description. + void MklBlasGemm(bool transa, bool transb, const int m, const int n, + const int k, const std::complex<double>* a, const int lda, + const std::complex<double>* b, const int ldb, + std::complex<double>* c, const int ldc) { + const MKL_Complex16 alpha = {1.0, 0.0}; + const MKL_Complex16 beta = {0.0, 0.0}; + cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, + transb ? CblasTrans : CblasNoTrans, m, n, k, + static_cast<const void*>(&alpha), static_cast<const void*>(a), + lda, static_cast<const void*>(b), ldb, + static_cast<const void*>(&beta), static_cast<void*>(c), ldc); + } +}; + +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("MKL"), \ + MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>) + +// TODO:Consider template specialization when adding/removing additional types +TF_CALL_float(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); +TF_CALL_complex64(REGISTER_CPU); +TF_CALL_complex128(REGISTER_CPU); + +} // namespace tensorflow +#endif // INTEL_MKL |