Diffstat (limited to 'tensorflow/core/kernels/batch_matmul_op_impl.h')
-rw-r--r--  tensorflow/core/kernels/batch_matmul_op_impl.h  106
1 file changed, 100 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index a1c03f9918..475bda848d 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -329,6 +329,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       c_ptrs.push_back(&c_device_memory.back());
     }
 
+    typedef Scalar Coefficient;
+
     // Cublas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
@@ -352,9 +354,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       bool blas_launch_status =
           stream
               ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
-                             static_cast<Scalar>(1.0), *(a_ptrs[0]),
+                             static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                              adj_x ? m : k, *(b_ptrs[0]), 1,
-                             static_cast<Scalar>(0.0), c_ptrs[0], 1)
+                             static_cast<Coefficient>(0.0), c_ptrs[0], 1)
               .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
@@ -366,9 +368,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       bool blas_launch_status =
           stream
               ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
-                             static_cast<Scalar>(1.0), *(b_ptrs[0]),
+                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
                              adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
-                             static_cast<Scalar>(0.0), c_ptrs[0], n)
+                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
               .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
@@ -383,8 +385,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
           stream
               ->ThenBlasGemmBatchedWithScratch(
                   blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
-                  adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
+                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
                   batch_size, &scratch_allocator)
               .ok();
       if (!blas_launch_status) {
@@ -398,6 +400,98 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
   }
 };
 
+template <>
+struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
+  static void Launch(OpKernelContext* context, const Tensor& in_x,
+                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+    typedef Eigen::half Scalar;
+    constexpr perftools::gputools::blas::Transpose kTranspose =
+        is_complex<Scalar>::value
+            ? perftools::gputools::blas::Transpose::kConjugateTranspose
+            : perftools::gputools::blas::Transpose::kTranspose;
+    perftools::gputools::blas::Transpose trans[] = {
+        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
+    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
+    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
+    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+    const uint64 batch_size = in_x.dim_size(0);
+    auto blas_transpose_a = trans[adj_x];
+    auto blas_transpose_b = trans[adj_y];
+
+    auto* stream = context->op_device_context()->stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+    std::vector<DeviceMemoryType> a_device_memory;
+    std::vector<DeviceMemoryType> b_device_memory;
+    std::vector<DeviceMemoryType> c_device_memory;
+    std::vector<DeviceMemoryType*> a_ptrs;
+    std::vector<DeviceMemoryType*> b_ptrs;
+    std::vector<DeviceMemoryType*> c_ptrs;
+    a_device_memory.reserve(batch_size);
+    b_device_memory.reserve(batch_size);
+    c_device_memory.reserve(batch_size);
+    a_ptrs.reserve(batch_size);
+    b_ptrs.reserve(batch_size);
+    c_ptrs.reserve(batch_size);
+    auto* a_base_ptr = in_x.template flat<Scalar>().data();
+    auto* b_base_ptr = in_y.template flat<Scalar>().data();
+    auto* c_base_ptr = out->template flat<Scalar>().data();
+    for (int64 i = 0; i < batch_size; ++i) {
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+    }
+
+    typedef float Coefficient;
+
+    // Cublas does
+    // C = A x B
+    // where A, B and C are assumed to be in column major.
+    // We want the output to be in row-major, so we can compute
+    // C' = B' x A', where ' stands for transpose (not adjoint).
+    // TODO(yangzihao): Choose the best of the three strategies using autotune.
+    if (batch_size == 1) {
+      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
+      // overhead of the scratch allocator and the batch interface.
+      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
+                             adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k));
+      }
+    } else {
+      CublasScratchAllocator scratch_allocator(context);
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmBatchedWithScratch(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
+                  batch_size, &scratch_allocator)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(
+            errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
+                             in_x.shape().DebugString(), ", b.shape=",
+                             in_y.shape().DebugString(), ", m=", m, ", n=", n,
+                             ", k=", k, ", batch_size=", batch_size));
+      }
+    }
+  }
+};
+
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
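The comment repeated in both specializations ("We want the output to be in row-major, so we can compute C' = B' x A'") is the reason every call above passes b_ptrs before a_ptrs: cuBLAS assumes column-major storage, TensorFlow tensors are row-major, and a row-major matrix reinterpreted as column-major is its transpose. A minimal standalone sketch of the same trick against the raw cuBLAS API, outside StreamExecutor; the function name and device buffers here are illustrative assumptions, not part of the patch:

#include <cublas_v2.h>

// Row-major C (m x n) = A (m x k) * B (k x n) on column-major cuBLAS.
// Passing (d_b, d_a) with dimensions (n, m, k) makes cuBLAS compute
// C' = B' x A' in column-major, which is exactly C in row-major.
cublasStatus_t RowMajorGemm(cublasHandle_t handle, int m, int n, int k,
                            const float* d_a, const float* d_b, float* d_c) {
  const float alpha = 1.0f, beta = 0.0f;
  // No CUBLAS_OP_T needed: the row-major/column-major reinterpretation
  // supplies the transposes. Leading dimensions are the row-major widths.
  return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
                     d_b, /*leading dim of B=*/n, d_a, /*of A=*/k, &beta,
                     d_c, /*of C=*/n);
}

This is the same operand swap the patch expresses as ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, ...).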
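Note also that the new Eigen::half specialization sets typedef float Coefficient where the generic template uses typedef Scalar Coefficient: the half-precision GEMM path takes its alpha/beta scaling factors as fp32, and since cuBLAS has no fp16 Gemv (hence the TODO), the batch_size == 1 case falls through to plain GEMM. A hedged sketch of what the equivalent direct cuBLAS call looks like, using cublasSgemmEx, which accepts fp16 operands with float coefficients; the buffer names are assumptions:

#include <cublas_v2.h>
#include <cuda_fp16.h>

// fp16 in/out GEMM with fp32 scaling, mirroring "typedef float Coefficient"
// in the Eigen::half specialization. Same operand swap as above so the
// result comes out row-major.
cublasStatus_t HalfGemmRowMajor(cublasHandle_t handle, int m, int n, int k,
                                const __half* d_a, const __half* d_b,
                                __half* d_c) {
  const float alpha = 1.0f, beta = 0.0f;  // Coefficient is float, not half.
  return cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
                       d_b, CUDA_R_16F, n, d_a, CUDA_R_16F, k, &beta,
                       d_c, CUDA_R_16F, n);
}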