// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/fill_functor.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/stream_executor/stream.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

#if GOOGLE_CUDA

namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  perftools::gputools::DeviceMemory<T> typed(wrapped);
  return typed;
}
}  // namespace

#endif  // GOOGLE_CUDA

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;

// On CPUs, we ignore USE_CUBLAS.
template <typename T>
struct LaunchMatMulCPU {
  static void launch(
      OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      Tensor* out) {
    functor::MatMulFunctor<CPUDevice, T>()(ctx->eigen_device<CPUDevice>(),
                                           out->matrix<T>(), a.matrix<T>(),
                                           b.matrix<T>(), dim_pair);
  }
};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};

#if GOOGLE_CUDA

template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      Tensor* out) {
    perftools::gputools::blas::Transpose trans[] = {
        perftools::gputools::blas::Transpose::kNoTranspose,
        perftools::gputools::blas::Transpose::kTranspose};
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
    auto blas_transpose_a = trans[transpose_a];
    auto blas_transpose_b = trans[transpose_b];

    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data());

    // Cublas does
    //   C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   C' = B' x A'  (' stands for transpose)
    // (See the worked example following this struct.)
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
                           b_ptr, transpose_b ? k : n, a_ptr,
                           transpose_a ? m : k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal(
          "Blas SGEMM launch failed : a.shape=(", a.dim_size(0), ", ",
          a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
          "), m=", m, ", n=", n, ", k=", k));
    }
  }
};
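// ---------------------------------------------------------------------------
// Editorial note (illustrative sketch, not part of the original source): a
// row-major M x N buffer is bit-for-bit identical to a column-major N x M
// buffer holding the transposed matrix, so the row-major product C = A x B
// can be obtained from a column-major BLAS by requesting C' = B' x A' and
// letting it write straight into the output buffer. For example, with m = 2,
// k = 3, n = 4 and both transpose attributes false, the launch above is
// conceptually:
//
//   stream->ThenBlasGemm(kNoTranspose, kNoTranspose,
//                        /* rows of C' */ 4, /* cols of C' */ 2,
//                        /* inner dim  */ 3, 1.0f,
//                        b_ptr, /* leading dim */ 4,
//                        a_ptr, /* leading dim */ 3,
//                        0.0f, &c_ptr, /* leading dim */ 4);
//
// which fills the 2 x 4 row-major output without any extra transposition pass.
// ---------------------------------------------------------------------------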
template <typename T>
struct LaunchMatMul<GPUDevice, T, false /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      Tensor* out) {
    functor::MatMulFunctor<GPUDevice, T>()(ctx->eigen_device<GPUDevice>(),
                                           out->matrix<T>(), a.matrix<T>(),
                                           b.matrix<T>(), dim_pair);
  }
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
                                a.shape().DebugString(), ", In[1]: ",
                                b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
};

namespace functor {

// Partial specialization MatMulFunctor<Device=CPUDevice>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};

}  // end namespace functor

#define REGISTER_CPU(T)                                                        \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),                \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);            \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>)

#define REGISTER_GPU(T)                                                        \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),                \
      MatMulOp<GPUDevice, T, true /* cublas, true by default */>);             \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                                       \
                              .Device(DEVICE_GPU)                              \
                              .TypeConstraint<T>("T")                          \
                              .Label("cublas"),                                \
                          MatMulOp<GPUDevice, T, true /* cublas */>);          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<GPUDevice, T, false /* cublas */>)

REGISTER_CPU(float);
REGISTER_CPU(double);
REGISTER_CPU(int32);
REGISTER_CPU(complex64);
#if GOOGLE_CUDA
REGISTER_GPU(float);
// REGISTER_GPU(double);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow
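// ---------------------------------------------------------------------------
// Editorial note (illustrative sketch, not part of the original source): the
// Label() calls above register alternative kernels under the same op name. As
// far as I can tell, a labeled variant such as "eigen" is only selected when
// the node carries a matching "_kernel" attribute, e.g. in a GraphDef (other
// attrs omitted):
//
//   node {
//     name: "matmul"
//     op: "MatMul"
//     input: "a"
//     input: "b"
//     attr { key: "_kernel" value { s: "eigen" } }
//   }
//
// Nodes without that attribute use the unlabeled registration, i.e. the
// cuBLAS-backed kernel on GPU and the Eigen contraction on CPU.
// ---------------------------------------------------------------------------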