aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/batch_matmul_op_impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/batch_matmul_op_impl.h')
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h106
1 files changed, 100 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index a1c03f9918..475bda848d 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -329,6 +329,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
c_ptrs.push_back(&c_device_memory.back());
}
+ typedef Scalar Coefficient;
+
// Cublas does
// C = A x B
// where A, B and C are assumed to be in column major.
@@ -352,9 +354,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
bool blas_launch_status =
stream
->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
- static_cast<Scalar>(1.0), *(a_ptrs[0]),
+ static_cast<Coefficient>(1.0), *(a_ptrs[0]),
adj_x ? m : k, *(b_ptrs[0]), 1,
- static_cast<Scalar>(0.0), c_ptrs[0], 1)
+ static_cast<Coefficient>(0.0), c_ptrs[0], 1)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -366,9 +368,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
- static_cast<Scalar>(1.0), *(b_ptrs[0]),
+ static_cast<Coefficient>(1.0), *(b_ptrs[0]),
adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
- static_cast<Scalar>(0.0), c_ptrs[0], n)
+ static_cast<Coefficient>(0.0), c_ptrs[0], n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -383,8 +385,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
- static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
- adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
+ static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+ adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
batch_size, &scratch_allocator)
.ok();
if (!blas_launch_status) {
@@ -398,6 +400,98 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
}
};
+template <>
+struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
+ static void Launch(OpKernelContext* context, const Tensor& in_x,
+ const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+ typedef Eigen::half Scalar;
+ constexpr perftools::gputools::blas::Transpose kTranspose =
+ is_complex<Scalar>::value
+ ? perftools::gputools::blas::Transpose::kConjugateTranspose
+ : perftools::gputools::blas::Transpose::kTranspose;
+ perftools::gputools::blas::Transpose trans[] = {
+ perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
+ const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
+ const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
+ const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+ const uint64 batch_size = in_x.dim_size(0);
+ auto blas_transpose_a = trans[adj_x];
+ auto blas_transpose_b = trans[adj_y];
+
+ auto* stream = context->op_device_context()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+ std::vector<DeviceMemoryType> a_device_memory;
+ std::vector<DeviceMemoryType> b_device_memory;
+ std::vector<DeviceMemoryType> c_device_memory;
+ std::vector<DeviceMemoryType*> a_ptrs;
+ std::vector<DeviceMemoryType*> b_ptrs;
+ std::vector<DeviceMemoryType*> c_ptrs;
+ a_device_memory.reserve(batch_size);
+ b_device_memory.reserve(batch_size);
+ c_device_memory.reserve(batch_size);
+ a_ptrs.reserve(batch_size);
+ b_ptrs.reserve(batch_size);
+ c_ptrs.reserve(batch_size);
+ auto* a_base_ptr = in_x.template flat<Scalar>().data();
+ auto* b_base_ptr = in_y.template flat<Scalar>().data();
+ auto* c_base_ptr = out->template flat<Scalar>().data();
+ for (int64 i = 0; i < batch_size; ++i) {
+ a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+ b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+ c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+ a_ptrs.push_back(&a_device_memory.back());
+ b_ptrs.push_back(&b_device_memory.back());
+ c_ptrs.push_back(&c_device_memory.back());
+ }
+
+ typedef float Coefficient;
+
+ // Cublas does
+ // C = A x B
+ // where A, B and C are assumed to be in column major.
+ // We want the output to be in row-major, so we can compute
+ // C' = B' x A', where ' stands for transpose (not adjoint).
+ // TODO(yangzihao): Choose the best of the three strategies using autotune.
+ if (batch_size == 1) {
+ // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
+ // overhead of the scratch allocator and the batch interface.
+ // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Coefficient>(1.0), *(b_ptrs[0]),
+ adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+ static_cast<Coefficient>(0.0), c_ptrs[0], n)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal(
+ "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
+ ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+ ", k=", k));
+ }
+ } else {
+ CublasScratchAllocator scratch_allocator(context);
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemmBatchedWithScratch(
+ blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+ adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
+ batch_size, &scratch_allocator)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(
+ errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
+ in_x.shape().DebugString(), ", b.shape=",
+ in_y.shape().DebugString(), ", m=", m, ", n=", n,
+ ", k=", k, ", batch_size=", batch_size));
+ }
+ }
+ }
+};
+
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL