diff options
author | 2017-02-07 02:10:35 -0800 | |
---|---|---|
committer | 2017-02-07 02:28:19 -0800 | |
commit | 764507ea71473d3a2676277402b9a79bf383e78c (patch) | |
tree | d74696efc705e5850b83e951345cf3e109248851 | |
parent | e62ef9064902ce91f97e2d7342bd97b16ae25a82 (diff) |
Adding support for non-fused Winograd algorithm from Cudnn 5.1.
Enabling this by default causes some unit tests to fail, so this change adds an
env-var "TF_ENABLE_WINOGRAD_NONFUSED" that lets users explicitly opt in.
Change: 146763809
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 809fb1c956..bd8aa4bacb 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/strcat.h" +#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -260,6 +261,9 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmType algorithm) { #if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD: #endif +#if CUDNN_VERSION >= 5100 + case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED: +#endif return algo; default: LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: " @@ -278,6 +282,9 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo( #if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD: #endif +#if CUDNN_VERSION >= 5100 + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED: +#endif return algo; default: LOG(FATAL) @@ -295,6 +302,11 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3: +#if CUDNN_VERSION >= 5100 + // Based on cudnn.h, the following is not implemented. + // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD: + case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED: +#endif return algo; default: LOG(FATAL) @@ -1952,6 +1964,33 @@ bool CudnnSupport::DoConvolveImpl( return true; } +// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms. +// Doing so by default make a few TensorFlow test cases to fail. 
Users can +// explicitly enable them through an env-var "TF_ENABLE_WINOGRAD_NONFUSED=1". +// https://github.com/tensorflow/tensorflow/pull/4901 +class WinogradNonfused { + public: + static bool IsEnabled() { + static bool is_enabled = IsEnabledImpl(); + return is_enabled; + } + + private: + static bool IsEnabledImpl() { + const char* tf_env_var_val = getenv("TF_ENABLE_WINOGRAD_NONFUSED"); + if (tf_env_var_val != nullptr) { + port::StringPiece tf_env_var_val_str(tf_env_var_val); + if (tf_env_var_val_str == "0") { + return false; + } + return true; + } + // TODO(zhengxq): turn the default to True when the test failure is + // resolved. + return false; + } +}; + bool CudnnSupport::GetConvolveAlgorithms( std::vector<dnn::AlgorithmType>* out_algorithms) { out_algorithms->assign({ @@ -1967,6 +2006,11 @@ bool CudnnSupport::GetConvolveAlgorithms( #endif // clang-format on }); +#if CUDNN_VERSION >= 5100 + if (WinogradNonfused::IsEnabled()) { + out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED); + } +#endif return true; } @@ -1983,6 +2027,12 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( #endif // clang-format on }); +#if CUDNN_VERSION >= 5100 + if (WinogradNonfused::IsEnabled()) { + out_algorithms->push_back( + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); + } +#endif return true; } @@ -1996,6 +2046,14 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, // clang-format on }); +#if CUDNN_VERSION >= 5100 + if (WinogradNonfused::IsEnabled()) { + out_algorithms->push_back( + // Based on cudnn.h, the following is not implemented. + // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED); + } +#endif return true; } |