diff options
author | Yangzihao Wang <yangzihao@google.com> | 2017-07-18 16:48:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-18 16:52:54 -0700 |
commit | 06acccabcb41513c76bbfffcd17817a7b136494b (patch) | |
tree | 82c2b379a5d8d4aa6cbbb653d3020ee2504bbd58 /tensorflow/stream_executor/stream.h | |
parent | 11dff5b05b3488520d3a415173d73ae91fded092 (diff) |
Add autotuning code for matmul operator.
Currently it is turned off by default.
PiperOrigin-RevId: 162423171
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- | tensorflow/stream_executor/stream.h | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 3c8b7ee894..e218873839 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -934,6 +934,31 @@ class Stream { std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy); + Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n, + float alpha, const DeviceMemory<float> &a, + int lda, const DeviceMemory<float> &x, + int incx, float beta, + DeviceMemory<float> *y, int incy, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n, + double alpha, const DeviceMemory<double> &a, + int lda, const DeviceMemory<double> &x, + int incx, double beta, + DeviceMemory<double> *y, int incy, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemvWithProfiling( + blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemvWithProfiling( + blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, DeviceMemory<std::complex<double>> *y, + int incy, blas::ProfileResult *output_profile_result); + // See BlasSupport::DoBlasGer. Stream &ThenBlasGer(uint64 m, uint64 n, float alpha, const DeviceMemory<float> &x, int incx, @@ -1249,6 +1274,44 @@ class Stream { std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc); + Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, + const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, + float beta, DeviceMemory<Eigen::half> *c, + int ldc, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, + const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, + float beta, DeviceMemory<float> *c, int ldc, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, + const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, + double beta, DeviceMemory<double> *c, + int ldc, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithProfiling( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithProfiling( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + blas::ProfileResult *output_profile_result); + // See BlasSupport::DoBlasGemmWithAlgorithm. Stream &ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, |