path: root/tensorflow/stream_executor/blas.h
author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-03-18 14:30:06 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-03-18 15:47:02 -0700
commit     05ea40f180e528dbfde36cd338a0b6ac3cca6dd9 (patch)
tree       e2228aac1d752bf7ae7f3ba834860295d238d845 /tensorflow/stream_executor/blas.h
parent     af6a33aa47557391b3af56372598187118caf366 (diff)
Support ScratchAllocator in BLAS Batched GEMM
Change: 117590857
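The patch below threads a ScratchAllocator* through every DoBlasGemmBatched overload, so batched GEMM implementations can request temporary device memory from the caller (for example, the device-side pointer arrays a batched BLAS call typically needs) instead of allocating it internally. The following is a minimal, self-contained sketch of that idea; the types and the helper DoBlasGemmBatchedSketch are simplified stand-ins, not the real stream_executor API.

#include <cstddef>
#include <cstdint>
#include <vector>

struct Stream {};  // stand-in for perftools::gputools::Stream

// Stand-in scratch allocator: hands out temporary byte buffers for a stream.
class ScratchAllocator {
 public:
  void* AllocateBytes(Stream* /*stream*/, std::size_t bytes) {
    buffers_.emplace_back(bytes);
    return buffers_.back().data();
  }

 private:
  std::vector<std::vector<std::uint8_t>> buffers_;
};

// Sketch of a batched-GEMM entry point after the change: the extra
// scratch_allocator parameter supplies the temporary storage (e.g. the
// arrays of per-batch matrix pointers) that the backend previously had
// to allocate on its own.
bool DoBlasGemmBatchedSketch(Stream* stream,
                             const std::vector<float*>& a,
                             const std::vector<float*>& b,
                             const std::vector<float*>& c,
                             int batch_count,
                             ScratchAllocator* scratch_allocator) {
  void* scratch = scratch_allocator->AllocateBytes(
      stream, 3 * static_cast<std::size_t>(batch_count) * sizeof(float*));
  if (scratch == nullptr) return false;
  // ... copy the a/b/c pointer arrays into `scratch` and launch the
  // batched GEMM on `stream` ...
  (void)a; (void)b; (void)c;
  return true;
}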
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r--  tensorflow/stream_executor/blas.h | 19
1 file changed, 10 insertions(+), 9 deletions(-)
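The second half of the patch below updates the TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES macro so the override declarations it stamps into backend classes stay in sync with the new base-class signatures. A simplified, self-contained stand-in of that pattern (the class and macro names here are illustrative, not the real headers):

struct Stream {};
class ScratchAllocator {};

// Base class declaring the pure virtual, as BlasSupport does.
class BlasSupportSketch {
 public:
  virtual ~BlasSupportSketch() = default;
  virtual bool DoBlasGemmBatched(Stream* stream, int batch_count,
                                 ScratchAllocator* scratch_allocator) = 0;
};

// Analogue of TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES: any
// signature change in the base class must be mirrored here, which is why
// the macro body is edited alongside the virtuals in the patch.
#define BLAS_SUPPORT_SKETCH_OVERRIDES                                  \
  bool DoBlasGemmBatched(Stream* stream, int batch_count,              \
                         ScratchAllocator* scratch_allocator) override;

class GpuBlasSketch : public BlasSupportSketch {
 public:
  BLAS_SUPPORT_SKETCH_OVERRIDES
};

bool GpuBlasSketch::DoBlasGemmBatched(Stream*, int, ScratchAllocator*) {
  return true;  // a real backend would launch the batched GEMM here
}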
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 94475817e0..1f1d427c45 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -50,6 +50,7 @@ namespace perftools {
namespace gputools {
class Stream;
+class ScratchAllocator;
template <typename ElemT>
class DeviceMemory;
@@ -880,14 +881,14 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, double alpha,
const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<float> alpha,
@@ -895,7 +896,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
std::complex<float> beta,
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<double> alpha,
@@ -903,7 +904,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
// Computes a matrix-matrix product where one input matrix is Hermitian:
//
@@ -1140,7 +1141,7 @@ class BlasSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// BlasSupport base class.
-#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
bool DoBlasAsum(Stream *stream, uint64 elem_count, \
const DeviceMemory<float> &x, int incx, \
DeviceMemory<float> *result) override; \
@@ -1626,14 +1627,14 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \
- int batch_count) override; \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, double alpha, \
const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \
const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \
- int batch_count) override; \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
@@ -1641,7 +1642,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
std::complex<float> beta, \
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
- int batch_count) override; \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
@@ -1650,7 +1651,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \
int ldb, std::complex<double> beta, \
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \
- int ldc, int batch_count) override; \
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
uint64 m, uint64 n, std::complex<float> alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \