diff options
author | Justin Lebar <jlebar@google.com> | 2017-03-02 17:49:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-02 18:08:01 -0800 |
commit | 01194694948eb883e99af597d9dbbf3fc9f5c9e2 (patch) | |
tree | ab3517cf656259681283a90c6682c5b320ac36e3 /tensorflow/stream_executor/stream.h | |
parent | e065b3093f4fec5a5f79ad9de81f6baab361962e (diff) |
[XLA] [StreamExecutor] Tune GEMMs when possible.
cublas 8 adds the cublasGemmEx function, which lets you specify an
explicit "algorithm" for the computation. This functions as an opaque
tuning hint to cublas.
This patch adds support for cublasGemmEx to StreamExecutor, and wires up
XLA's GemmThunk to use the new function.
This patch does not add GEMM autotuning support in TensorFlow proper,
only XLA.
Change: 149068961
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- | tensorflow/stream_executor/stream.h | 41 |
1 file changed, 41 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index a092e87004..f22fba1d74 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1179,6 +1179,47 @@ class Stream { std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc); + // See BlasSupport::DoBlasGemmWithAlgorithm. + Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, + int lda, const DeviceMemory<Eigen::half> &b, int ldb, + const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, + const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, + float beta, DeviceMemory<float> *c, int ldc, + blas::ComputationType computation_type, + blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithAlgorithm( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + // See BlasSupport::DoBlasGemmBatched. Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, |