path: root/tensorflow/stream_executor/blas.h
author Ben Barsdell <benbarsdell@gmail.com> 2018-05-10 11:06:01 -0700
committer Jonathan Hseu <vomjom@vomjom.net> 2018-05-10 11:06:01 -0700
commit f08f24cd559b5824a1874a0e76d339875e43f366 (patch)
tree ade423df2e77815bcc246064124fbd0ecbe8e286 /tensorflow/stream_executor/blas.h
parent 9201e2c002667047b1807745c4a7d6a8e5f2e9da (diff)
Add GPU support for float16 batched matmul (#18436)
* Add GPU support for float16 batched matmul
  - Uses cublasGemmBatchedEx introduced in CUDA 9.1.
  - Includes support for Tensor Op math.
  - Falls back to a loop over non-batched gemm calls on older CUDA versions
    or GPU architectures.
* Refactor GPU batched gemm into one internal func
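
The fallback strategy the commit message describes can be illustrated with a
short, self-contained cuBLAS sketch. This is a hedged sketch of the general
pattern, not the code from this PR: HalfGemmBatched and its parameters are
hypothetical, and it assumes the CUDA 9.x/10.x cublasGemmBatchedEx signature
(the computeType parameter became cublasComputeType_t in CUDA 11) with
column-major, non-transposed inputs.

// Minimal sketch of the fallback described in the commit message; names are
// hypothetical and error handling is left to the caller.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>

// Computes C[i] = alpha * A[i] * B[i] + beta * C[i] for i in [0, batch_count).
// `a`, `b`, `c` are host-side arrays of device pointers (used by the loop
// fallback); `dev_a`, `dev_b`, `dev_c` hold the same pointers copied into
// device memory, as cublasGemmBatchedEx requires.
cublasStatus_t HalfGemmBatched(cublasHandle_t handle, int m, int n, int k,
                               float alpha, const __half** a, int lda,
                               const __half** b, int ldb, float beta,
                               __half** c, int ldc, int batch_count,
                               const void** dev_a, const void** dev_b,
                               void** dev_c) {
#if CUDART_VERSION >= 9010
  // CUDA 9.1+: a single batched call. Computing in FP32 preserves accuracy;
  // Tensor Op math can additionally be requested via
  // cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH) on Volta-class GPUs.
  return cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                             &alpha, dev_a, CUDA_R_16F, lda,
                             dev_b, CUDA_R_16F, ldb, &beta,
                             dev_c, CUDA_R_16F, ldc, batch_count,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
#else
  // Older CUDA: fall back to a loop over non-batched FP16 gemm calls.
  // (Host-side __float2half needs CUDA 9.0+; earlier toolkits must convert
  // manually.)
  const __half alpha_h = __float2half(alpha);
  const __half beta_h = __float2half(beta);
  for (int i = 0; i < batch_count; ++i) {
    cublasStatus_t status =
        cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha_h,
                    a[i], lda, b[i], ldb, &beta_h, c[i], ldc);
    if (status != CUBLAS_STATUS_SUCCESS) return status;
  }
  return CUBLAS_STATUS_SUCCESS;
#endif
}
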
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r--  tensorflow/stream_executor/blas.h | 14
1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index be0b0bf5fb..ea87744b22 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1086,6 +1086,13 @@ class BlasSupport {
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ virtual bool DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
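
This hunk adds a pure virtual Eigen::half overload of DoBlasGemmBatched ahead
of the existing float overload; note that alpha and beta remain float even
though the matrices are half precision. A hedged sketch of a possible call
site follows. BatchedHalfGemm is hypothetical; it assumes compilation inside
the TensorFlow tree so that blas.h's types are in scope, and it goes through
Stream::ThenBlasGemmBatched, which is presumably extended with a matching
half overload elsewhere in this PR (not shown in this file).

// Hypothetical call site; only the argument shape is taken from this diff.
#include "tensorflow/stream_executor/stream.h"

namespace stream_executor {  // named perftools::gputools in older trees

// C[i] = 1.0 * A[i] * B[i] + 0.0 * C[i] for each of the batch_count matrices,
// all column-major and non-transposed.
bool BatchedHalfGemm(Stream* stream, uint64 m, uint64 n, uint64 k,
                     const port::ArraySlice<DeviceMemory<Eigen::half>*>& a,
                     const port::ArraySlice<DeviceMemory<Eigen::half>*>& b,
                     const port::ArraySlice<DeviceMemory<Eigen::half>*>& c,
                     int batch_count) {
  // Scaling factors stay float, matching the declaration added above.
  stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
                              blas::Transpose::kNoTranspose, m, n, k,
                              /*alpha=*/1.0f, a, /*lda=*/static_cast<int>(m),
                              b, /*ldb=*/static_cast<int>(k),
                              /*beta=*/0.0f, c, /*ldc=*/static_cast<int>(m),
                              batch_count);
  return stream->ok();
}

}  // namespace stream_executor
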
@@ -1948,6 +1955,13 @@ class BlasSupport {
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, float alpha, \
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, \
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, \
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, \
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+ bool DoBlasGemmBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \