diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-08-03 10:24:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 10:28:08 -0700 |
commit | e36c16c6720dd64ae5d8a1f8555102a1323af9ae (patch) | |
tree | 9076cc7fa753542490bac8b1fcd2bc3a3c2ac62a | |
parent | 7935c176118f0e50aa657a1c68a85430b70d2245 (diff) |
[XLA:GPU] Use strided batched gemm instead of building pointer tables.
This is mostly a huge amount of plumbing just to call into the cublas functions.
blasGemmStridedBatched has been available since CUDA 8.0.
For autotuning we'd need cublasGemmStridedBatchedEx, which is new in CUDA 9.1
so I didn't wire that up yet.
PiperOrigin-RevId: 207285707
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gemm_thunk.cc | 52 | ||||
-rw-r--r-- | tensorflow/stream_executor/blas.h | 66 | ||||
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 102 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 109 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream.h | 32 |
5 files changed, 319 insertions, 42 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index e9ba1f13eb..a300d5f3fe 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -47,12 +47,6 @@ struct MatrixDescriptor { int64 batch_size; }; -template <typename T> -se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) { - se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory)); - return se::DeviceMemory<T>(wrapped); -} - // Performs a gemm call without an explicit algorithm on lhs_matrix and // rhs_matrix, and stores the result to output_matrix. template <typename Element> @@ -84,43 +78,19 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, .ok(); } - // Create the buffers for batched gemm. - // TODO(b/112111608): We could avoid all of this and also make it faster by - // using cuBLAS 8's strided batched gemm. - using DeviceMemoryType = se::DeviceMemory<Element>; - std::vector<DeviceMemoryType> a_device_memory; - std::vector<DeviceMemoryType> b_device_memory; - std::vector<DeviceMemoryType> c_device_memory; - std::vector<DeviceMemoryType*> a_ptrs; - std::vector<DeviceMemoryType*> b_ptrs; - std::vector<DeviceMemoryType*> c_ptrs; - a_device_memory.reserve(batch_size); - b_device_memory.reserve(batch_size); - c_device_memory.reserve(batch_size); - a_ptrs.reserve(batch_size); - b_ptrs.reserve(batch_size); - c_ptrs.reserve(batch_size); - auto* a_base_ptr = static_cast<Element*>(lhs_data.opaque()); - auto* b_base_ptr = static_cast<Element*>(rhs_data.opaque()); - auto* c_base_ptr = static_cast<Element*>(output_data.opaque()); - for (int64 i = 0; i < batch_size; ++i) { - a_device_memory.push_back(AsDeviceMemory( - a_base_ptr + i * lhs_matrix.num_rows * lhs_matrix.num_cols)); - b_device_memory.push_back(AsDeviceMemory( - b_base_ptr + i * rhs_matrix.num_rows * rhs_matrix.num_cols)); - c_device_memory.push_back(AsDeviceMemory( - c_base_ptr + i * 
output_matrix.num_rows * output_matrix.num_cols)); - a_ptrs.push_back(&a_device_memory.back()); - b_ptrs.push_back(&b_device_memory.back()); - c_ptrs.push_back(&c_device_memory.back()); - } + int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols; + int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; + int64 output_stride = output_matrix.num_rows * output_matrix.num_cols; return stream - ->ThenBlasGemmBatched( + ->ThenBlasGemmStridedBatched( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, - a_ptrs, /*leading dim of LHS=*/lhs_matrix.num_rows, b_ptrs, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, c_ptrs, - /*leading dim of output=*/output_matrix.num_rows, batch_size) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/alpha, lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, + /*beta=*/0.0, &output_data, + /*leading dim of output=*/output_matrix.num_rows, output_stride, + batch_size) .ok(); } diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index ea87744b22..7f851e3646 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -1121,6 +1121,40 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + // Batched gemm with strides instead of pointer arrays. 
+ virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + // Computes a matrix-matrix product where one input matrix is 
Hermitian: // // c <- alpha * a * b + beta * c, @@ -1990,6 +2024,38 @@ class BlasSupport { int ldb, std::complex<double> beta, \ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, \ + const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \ + const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \ + DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ + int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \ + int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, double alpha, \ + const DeviceMemory<double> &a, int lda, int64 stride_a, \ + const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \ + DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \ + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \ + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ + const DeviceMemory<std::complex<double>> &a, int lda, int64 
stride_a, \ + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \ + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ + int ldc, int64 stride_c, int batch_count); \ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex<float> alpha, \ const DeviceMemory<std::complex<float>> &a, int lda, \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 6988389f29..e5b5f7fc94 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx) #if CUDA_VERSION >= 8000 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched) #endif #if CUDA_VERSION >= 9000 @@ -1865,7 +1869,7 @@ bool CUDABlas::DoBlasGemm( stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); - // GPUs < sm_70 don't support Volta hardware. + // GPUs < sm_70 don't support tensor ops. if (cc_major >= 7 && TensorOpMathEnabled()) { use_tensor_ops = true; } @@ -2623,6 +2627,102 @@ bool CUDABlas::DoBlasGemmBatched( return status.ok(); } +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) { + bool use_tensor_ops = false; +#if CUDA_VERSION >= 9000 + int cc_major, cc_minor; + stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + + // GPUs < sm_70 don't support tensor ops. 
+ if (cc_major >= 7 && TensorOpMathEnabled()) { + use_tensor_ops = true; + } +#endif + // We'd need cublasgemmStridedBatchedEx for this, which isn't available before + // CUDA 9.1. Fall back to a loop. + for (int batch = 0; batch < batch_count; ++batch) { + const auto *a_matrix = + reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a); + const auto *b_matrix = + reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b); + auto *c_matrix = + reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c); + bool ok = DoBlasInternalImpl( + wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */, + true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa), + CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF, + lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix, + SE_CUDA_DATA_HALF, ldc); + if (!ok) { + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } + } + return true; +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + return DoBlasInternal( + wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + return 
DoBlasInternal( + wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex<float> 
alpha, diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index b0c061fd74..6439e3992d 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4685,6 +4685,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( scratch_allocator); } +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const DeviceMemory<Eigen::half> &, int, int64, + const DeviceMemory<Eigen::half> &, int, int64, float, + DeviceMemory<Eigen::half> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + 
const DeviceMemory<float> &, int, int64, + const DeviceMemory<float> &, int, int64, float, + DeviceMemory<float> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, + const DeviceMemory<double> &, int, int64, + const DeviceMemory<double> &, int, int64, double, + DeviceMemory<double> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + 
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<float>, const DeviceMemory<std::complex<float>> &, + int, int64, const DeviceMemory<std::complex<float>> &, int, + int64, std::complex<float>, DeviceMemory<std::complex<float>> *, + int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<double>, const DeviceMemory<std::complex<double>> &, + int, int64, const DeviceMemory<std::complex<double>> &, int, + int64, std::complex<double>, + DeviceMemory<std::complex<double>> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 706442a666..62d0a2062d 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1557,6 +1557,38 @@ class Stream { std::complex<double> beta, 
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count); // See BlasSupport::DoBlasHemm. 
Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, |