aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.h
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-08-03 10:24:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 10:28:08 -0700
commite36c16c6720dd64ae5d8a1f8555102a1323af9ae (patch)
tree9076cc7fa753542490bac8b1fcd2bc3a3c2ac62a /tensorflow/stream_executor/stream.h
parent7935c176118f0e50aa657a1c68a85430b70d2245 (diff)
[XLA:GPU] Use strided batched gemm instead of building pointer tables.
This is mostly a huge amount of plumbing just to call into the cublas functions. blasGemmStridedBatched has been available since CUDA 8.0. For autotuning we'd need cublasGemmStridedBatchedEx, which is new in CUDA 9.2 so I didn't wire that up yet. PiperOrigin-RevId: 207285707
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r--tensorflow/stream_executor/stream.h32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 706442a666..62d0a2062d 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1557,6 +1557,38 @@ class Stream {
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count);
// See BlasSupport::DoBlasHemm.
Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,