diff options
author:    A. Unique TensorFlower <nobody@tensorflow.org>  2016-03-18 14:30:06 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-03-18 15:47:02 -0700
commit:    05ea40f180e528dbfde36cd338a0b6ac3cca6dd9 (patch)
tree:      e2228aac1d752bf7ae7f3ba834860295d238d845 /tensorflow/stream_executor/blas.h
parent:    af6a33aa47557391b3af56372598187118caf366 (diff)
Support ScratchAllocator in BLAS Batched GEMM
Change: 117590857
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r--  tensorflow/stream_executor/blas.h | 19
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 94475817e0..1f1d427c45 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -50,6 +50,7 @@ namespace perftools { namespace gputools { class Stream; +class ScratchAllocator; template <typename ElemT> class DeviceMemory; @@ -880,14 +881,14 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<float> *> &a, int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, - int batch_count) = 0; + int batch_count, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a, int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, - int batch_count) = 0; + int batch_count, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex<float> alpha, @@ -895,7 +896,7 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, std::complex<float> beta, const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, - int batch_count) = 0; + int batch_count, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, std::complex<double> alpha, @@ -903,7 +904,7 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, std::complex<double> beta, const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, - int batch_count) = 0; + int batch_count, ScratchAllocator *scratch_allocator) = 0; // Computes a 
matrix-matrix product where one input matrix is Hermitian: // @@ -1140,7 +1141,7 @@ class BlasSupport { // Macro used to quickly declare overrides for abstract virtuals in the // BlasSupport base class. -#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ bool DoBlasAsum(Stream *stream, uint64 elem_count, \ const DeviceMemory<float> &x, int incx, \ DeviceMemory<float> *result) override; \ @@ -1626,14 +1627,14 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \ const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \ const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \ - int batch_count) override; \ + int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, double alpha, \ const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \ const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \ const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \ - int batch_count) override; \ + int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ @@ -1641,7 +1642,7 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \ std::complex<float> beta, \ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \ - int batch_count) override; \ + int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ @@ -1650,7 +1651,7 @@ class BlasSupport { const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \ int ldb, 
std::complex<double> beta, \ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ - int ldc, int batch_count) override; \ + int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex<float> alpha, \ const DeviceMemory<std::complex<float>> &a, int lda, \ |