From e36c16c6720dd64ae5d8a1f8555102a1323af9ae Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Fri, 3 Aug 2018 10:24:13 -0700
Subject: [XLA:GPU] Use strided batched gemm instead of building pointer
 tables.

This is mostly a huge amount of plumbing just to call into the cublas
functions. cublasGemmStridedBatched has been available since CUDA 8.0.

For autotuning we'd need cublasGemmStridedBatchedEx, which is new in CUDA 9.2,
so I didn't wire that up yet.

PiperOrigin-RevId: 207285707
---
 tensorflow/stream_executor/blas.h            |  66 ++++++++++++++++
 tensorflow/stream_executor/cuda/cuda_blas.cc | 102 ++++++++++++++++++++++++-
 tensorflow/stream_executor/stream.cc         | 109 +++++++++++++++++++++++++++
 tensorflow/stream_executor/stream.h          |  32 ++++++++
 4 files changed, 308 insertions(+), 1 deletion(-)
(limited to 'tensorflow/stream_executor')

diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index ea87744b22..7f851e3646 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1121,6 +1121,40 @@ class BlasSupport {
       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
       int batch_count, ScratchAllocator *scratch_allocator) = 0;
 
+  // Batched gemm with strides instead of pointer arrays.
+  virtual bool DoBlasGemmStridedBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+      int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+      int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+      int64 stride_c, int batch_count) = 0;
+  virtual bool DoBlasGemmStridedBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+      int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+      float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+      int batch_count) = 0;
+  virtual bool DoBlasGemmStridedBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+      int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+      double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+      int batch_count) = 0;
+  virtual bool DoBlasGemmStridedBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, std::complex<float> alpha,
+      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+      int64 stride_c, int batch_count) = 0;
+  virtual bool DoBlasGemmStridedBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, std::complex<double> alpha,
+      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+      int64 stride_c, int batch_count) = 0;
+
   // Computes a matrix-matrix product where one input matrix is Hermitian:
   //
   //   c <- alpha * a * b + beta * c,
@@ -1990,6 +2024,38 @@ class BlasSupport {
                           int ldb, std::complex<double> beta,                \
                           const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \
                           int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+  bool DoBlasGemmStridedBatched(                                             \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,        \
+      uint64 m, uint64 n, uint64 k, float alpha,                             \
+      const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a,           \
+      const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \
+      DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \
+  bool DoBlasGemmStridedBatched(                                             \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,        \
+      uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+      int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb,        \
+      int64 stride_b, float beta, DeviceMemory<float> *c, int ldc,           \
+      int64 stride_c, int batch_count);                                      \
+  bool DoBlasGemmStridedBatched(                                             \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,        \
+      uint64 m, uint64 n, uint64 k, double alpha,                            \
+      const DeviceMemory<double> &a, int lda, int64 stride_a,                \
+      const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta,   \
+      DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count);    \
+  bool DoBlasGemmStridedBatched(                                             \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,        \
+      uint64 m, uint64 n, uint64 k, std::complex<float> alpha,               \
+      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,   \
+      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,   \
+      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+      int64 stride_c, int batch_count);                                      \
+  bool DoBlasGemmStridedBatched(                                             \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,        \
+      uint64 m, uint64 n, uint64 k, std::complex<double> alpha,              \
+      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,  \
+      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,  \
+      std::complex<double> beta, DeviceMemory<std::complex<double>> *c,      \
+      int ldc, int64 stride_c, int batch_count);                             \
   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,    \
                   uint64 m, uint64 n, std::complex<float> alpha,             \
                   const DeviceMemory<std::complex<float>> &a, int lda,       \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 6988389f29..e5b5f7fc94 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
 
 #if CUDA_VERSION >= 8000
 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched)
 #endif
 
 #if CUDA_VERSION >= 9000
@@ -1865,7 +1869,7 @@ bool CUDABlas::DoBlasGemm(
   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                    &cc_minor);
 
-  // GPUs < sm_70 don't support Volta hardware.
+  // GPUs < sm_70 don't support tensor ops.
   if (cc_major >= 7 && TensorOpMathEnabled()) {
     use_tensor_ops = true;
   }
@@ -2623,6 +2627,102 @@ bool CUDABlas::DoBlasGemmBatched(
   return status.ok();
 }
 
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+  int cc_major, cc_minor;
+  stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                                   &cc_minor);
+
+  // GPUs < sm_70 don't support tensor ops.
+  if (cc_major >= 7 && TensorOpMathEnabled()) {
+    use_tensor_ops = true;
+  }
+#endif
+  // We'd need cublasGemmStridedBatchedEx for this, which isn't available before
+  // CUDA 9.1. Fall back to a loop.
+  for (int batch = 0; batch < batch_count; ++batch) {
+    const auto *a_matrix =
+        reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
+    const auto *b_matrix =
+        reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b);
+    auto *c_matrix =
+        reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c);
+    bool ok = DoBlasInternalImpl(
+        wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
+        true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
+        CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
+        lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
+        SE_CUDA_DATA_HALF, ldc);
+    if (!ok) {
+      LOG(ERROR) << "failed BLAS call, see log for details";
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+      CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+      CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+      CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+      CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<float> alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+      CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+      CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<double> alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+      CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+      CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
                           blas::UpperLower uplo, uint64 m, uint64 n,
                           std::complex<float> alpha,
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index b0c061fd74..6439e3992d 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4685,6 +4685,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
                                   scratch_allocator);
 }
 
+Stream &Stream::ThenBlasGemmStridedBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+    int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
+    float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+            PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+            PARAM(stride_c), PARAM(batch_count));
+
+  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+               const DeviceMemory<Eigen::half> &, int, int64,
+               const DeviceMemory<Eigen::half> &, int, int64, float,
+               DeviceMemory<Eigen::half> *, int, int64, int>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+              transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+              c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+            PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+            PARAM(stride_c), PARAM(batch_count));
+
+  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+               const DeviceMemory<float> &, int, int64,
+               const DeviceMemory<float> &, int, int64, float,
+               DeviceMemory<float> *, int, int64, int>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+              transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+              c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+            PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+            PARAM(stride_c), PARAM(batch_count));
+
+  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+               double, const DeviceMemory<double> &, int, int64,
+               const DeviceMemory<double> &, int, int64, double,
+               DeviceMemory<double> *, int, int64, int>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+              transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+              c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, std::complex<float> alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+            PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+            PARAM(stride_c), PARAM(batch_count));
+
+  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+               std::complex<float>, const DeviceMemory<std::complex<float>> &,
+               int, int64, const DeviceMemory<std::complex<float>> &, int,
+               int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
+               int, int64, int>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+              transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+              c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, std::complex<double> alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+            PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+            PARAM(stride_c), PARAM(batch_count));
+
+  ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+               std::complex<double>,
+               const DeviceMemory<std::complex<double>> &, int, int64,
+               const DeviceMemory<std::complex<double>> &, int, int64,
+               std::complex<double>, DeviceMemory<std::complex<double>> *,
+               int, int64, int>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+              transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+              c, ldc, stride_c, batch_count);
+}
+
 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
 
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 706442a666..62d0a2062d 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1557,6 +1557,38 @@ class Stream {
       std::complex<double> beta,
       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);
+  Stream &ThenBlasGemmStridedBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+      int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+      int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+      int64 stride_c, int batch_count);
+  Stream &ThenBlasGemmStridedBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+      int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+      float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+      int batch_count);
+  Stream &ThenBlasGemmStridedBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+      int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+      double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+      int batch_count);
+  Stream &ThenBlasGemmStridedBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, std::complex<float> alpha,
+      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+      int64 stride_c, int batch_count);
+  Stream &ThenBlasGemmStridedBatched(
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, std::complex<double> alpha,
+      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+      int64 stride_c, int batch_count);
 
   // See BlasSupport::DoBlasHemm.
   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
-- 
cgit v1.2.3
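
Usage sketch (editor's note, not part of the commit): the snippet below illustrates how a StreamExecutor client might call the new strided-batched entry point for float matrices packed back to back in one device buffer. Only Stream::ThenBlasGemmStridedBatched comes from the patch above; the helper name, the `se` namespace alias, and the column-major stride choices are assumptions for illustration.

#include <cstdint>

#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;  // assumed alias; adjust to the project's convention

// Hypothetical helper: runs batch_count GEMMs C_i = A_i * B_i, where the i-th
// operand starts stride_* elements after the (i-1)-th one in the same buffer.
void RunStridedBatchedSgemm(se::Stream *stream,
                            const se::DeviceMemory<float> &a,
                            const se::DeviceMemory<float> &b,
                            se::DeviceMemory<float> *c,
                            uint64_t m, uint64_t n, uint64_t k,
                            int batch_count) {
  // Column-major layout with no transposes: the leading dimension equals the
  // row count, and each matrix occupies ld * cols contiguous elements.
  const int lda = m, ldb = k, ldc = m;
  const int64_t stride_a = static_cast<int64_t>(lda) * k;
  const int64_t stride_b = static_cast<int64_t>(ldb) * n;
  const int64_t stride_c = static_cast<int64_t>(ldc) * n;
  stream->ThenBlasGemmStridedBatched(
      se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
      m, n, k, /*alpha=*/1.0f, a, lda, stride_a, b, ldb, stride_b,
      /*beta=*/0.0f, c, ldc, stride_c, batch_count);
}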