aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 56822f02be..f2f9ac0a8f 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -2209,7 +2209,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
// clang-format on
});
#if CUDNN_VERSION >= 5100
- if (WinogradNonfused<false>::IsEnabled() && with_winograd_nonfused) {
+ if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
}
#endif
@@ -2231,7 +2231,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
// clang-format on
});
#if CUDNN_VERSION >= 5100
- if (WinogradNonfused<false>::IsEnabled() && with_winograd_nonfused) {
+ if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
out_algorithms->push_back(
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
}
@@ -2251,7 +2251,13 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
// clang-format on
});
#if CUDNN_VERSION >= 5100
- if (WinogradNonfused<false>::IsEnabled() && with_winograd_nonfused) {
+#if CUDNN_VERSION >= 5110
+ static constexpr bool kDefaultFlagWinogradNonfused = true;
+#else
+ static constexpr bool kDefaultFlagWinogradNonfused = false;
+#endif
+ if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled() &&
+ with_winograd_nonfused) {
out_algorithms->push_back(
// Based on cudnn.h, the following is not implemented.
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,