diff options
author | 2017-03-02 17:49:45 -0800 | |
---|---|---|
committer | 2017-03-02 18:08:01 -0800 | |
commit | 01194694948eb883e99af597d9dbbf3fc9f5c9e2 (patch) | |
tree | ab3517cf656259681283a90c6682c5b320ac36e3 /tensorflow/stream_executor/stream_executor_pimpl.cc | |
parent | e065b3093f4fec5a5f79ad9de81f6baab361962e (diff) |
[XLA] [StreamExecutor] Tune GEMMs when possible.
cublas 8 adds the cublasGemmEx function, which lets you specify an
explicit "algorithm" for the computation. This functions as an opaque
tuning hint to cublas.
This patch adds support for cublasGemmEx to StreamExecutor, and wires up
XLA's GemmThunk to use the new function.
This patch does not add GEMM autotuning support in TensorFlow proper,
only XLA.
Change: 149068961
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.cc | 9 |
1 file changed, 9 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index c498eecb3c..42fcd5867c 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -310,6 +310,15 @@ bool StreamExecutor::GetConvolveBackwardFilterAlgorithms( return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms); } +bool StreamExecutor::GetBlasGemmAlgorithms( + std::vector<blas::AlgorithmType> *out_algorithms) { + blas::BlasSupport *blas_support = AsBlas(); + if (!blas_support) { + return false; + } + return blas_support->GetBlasGemmAlgorithms(out_algorithms); +} + port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> StreamExecutor::createRnnDescriptor( int num_layers, int hidden_size, int input_size, |