diff options
author | Ben Barsdell <benbarsdell@gmail.com> | 2018-05-10 11:06:01 -0700 |
---|---|---|
committer | Jonathan Hseu <vomjom@vomjom.net> | 2018-05-10 11:06:01 -0700 |
commit | f08f24cd559b5824a1874a0e76d339875e43f366 (patch) | |
tree | ade423df2e77815bcc246064124fbd0ecbe8e286 /tensorflow/stream_executor/blas.h | |
parent | 9201e2c002667047b1807745c4a7d6a8e5f2e9da (diff) |
Add GPU support for float16 batched matmul (#18436)
* Add GPU support for float16 batched matmul
- Uses cublasGemmBatchedEx introduced in CUDA 9.1.
- Includes support for Tensor Op math.
- Falls back to a loop over non-batched gemm calls on older CUDA
versions or GPU architectures.
* Refactor GPU batched gemm into one internal func
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r-- | tensorflow/stream_executor/blas.h | 14 |
1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index be0b0bf5fb..ea87744b22 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1086,6 +1086,13 @@ class BlasSupport {
   virtual bool DoBlasGemmBatched(
       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
       uint64 n, uint64 k, float alpha,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+      int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+  virtual bool DoBlasGemmBatched(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, float alpha,
       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
@@ -1948,6 +1955,13 @@ class BlasSupport {
   bool DoBlasGemmBatched(                                                     \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,         \
       uint64 m, uint64 n, uint64 k, float alpha,                              \
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,        \
+      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,        \
+      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,     \
+      int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+  bool DoBlasGemmBatched(                                                     \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,         \
+      uint64 m, uint64 n, uint64 k, float alpha,                              \
       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,              \
       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,  \
       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,              \