diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-08-03 10:24:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 10:28:08 -0700 |
commit | e36c16c6720dd64ae5d8a1f8555102a1323af9ae (patch) | |
tree | 9076cc7fa753542490bac8b1fcd2bc3a3c2ac62a /tensorflow/stream_executor/blas.h | |
parent | 7935c176118f0e50aa657a1c68a85430b70d2245 (diff) |
[XLA:GPU] Use strided batched gemm instead of building pointer tables.
This is mostly a huge amount of plumbing just to call into the cublas functions.
blasGemmStridedBatched has been available since CUDA 8.0.
For autotuning we'd need cublasGemmStridedBatchedEx, which is new in CUDA 9.2
so I didn't wire that up yet.
PiperOrigin-RevId: 207285707
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r-- | tensorflow/stream_executor/blas.h | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index ea87744b22..7f851e3646 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -1121,6 +1121,40 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + // Batched gemm with strides instead of pointer arrays. + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) = 0; + // Computes a matrix-matrix product where one input matrix is Hermitian: // // c <- alpha * a * b + beta * c, @@ -1990,6 +2024,38 @@ class BlasSupport { int ldb, std::complex<double> beta, \ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, \ + const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \ + const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \ + DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ + int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \ + int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, double alpha, \ + const DeviceMemory<double> &a, int lda, int64 stride_a, \ + const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \ + DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \ + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \ + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \ + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \ + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ + int ldc, int64 stride_c, int batch_count); \ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex<float> alpha, \ const DeviceMemory<std::complex<float>> &a, int lda, \ |