diff options
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 56822f02be..f2f9ac0a8f 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -2209,7 +2209,7 @@ bool CudnnSupport::GetConvolveAlgorithms( // clang-format on }); #if CUDNN_VERSION >= 5100 - if (WinogradNonfused<false>::IsEnabled() && with_winograd_nonfused) { + if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) { out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED); } #endif @@ -2231,7 +2231,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( // clang-format on }); #if CUDNN_VERSION >= 5100 - if (WinogradNonfused<false>::IsEnabled() && with_winograd_nonfused) { + if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) { out_algorithms->push_back( CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); } @@ -2251,7 +2251,13 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( // clang-format on }); #if CUDNN_VERSION >= 5100 - if (WinogradNonfused<false>::IsEnabled() && with_winograd_nonfused) { +#if CUDNN_VERSION >= 5110 + static constexpr bool kDefaultFlagWinogradNonfused = true; +#else + static constexpr bool kDefaultFlagWinogradNonfused = false; +#endif + if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled() && + with_winograd_nonfused) { out_algorithms->push_back( // Based on cudnn.h, the following is not implemented. // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD, |