diff options
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.cc | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index fe32039d71..5e55169613 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -295,6 +295,24 @@ bool StreamExecutor::GetConvolveAlgorithms( return dnn_support->GetConvolveAlgorithms(out_algorithms); } +bool StreamExecutor::GetConvolveBackwardDataAlgorithms( + std::vector<dnn::AlgorithmType> *out_algorithms) { + dnn::DnnSupport *dnn_support = AsDnn(); + if (!dnn_support) { + return false; + } + return dnn_support->GetConvolveBackwardDataAlgorithms(out_algorithms); +} + +bool StreamExecutor::GetConvolveBackwardFilterAlgorithms( + std::vector<dnn::AlgorithmType> *out_algorithms) { + dnn::DnnSupport *dnn_support = AsDnn(); + if (!dnn_support) { + return false; + } + return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms); +} + dnn::DnnSupport *StreamExecutor::AsDnn() { mutex_lock lock{mu_}; if (dnn_ != nullptr) { |