aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Xiaoqiang Zheng <zhengxq@google.com>2017-02-07 02:10:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-07 02:28:19 -0800
commit764507ea71473d3a2676277402b9a79bf383e78c (patch)
treed74696efc705e5850b83e951345cf3e109248851
parente62ef9064902ce91f97e2d7342bd97b16ae25a82 (diff)
Adding support for non-fused Winograd algorithm from Cudnn 5.1.
Enabling this by default makes some unit tests to fail. So adding an env-var "TF_ENABLE_WINOGRAD_NONFUSED" so users can explicitly choose to enable. Change: 146763809
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 809fb1c956..bd8aa4bacb 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/plugin_registry.h"
@@ -260,6 +261,9 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmType algorithm) {
#if CUDNN_VERSION >= 5000
case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
#endif
+#if CUDNN_VERSION >= 5100
+ case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
+#endif
return algo;
default:
LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
@@ -278,6 +282,9 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
#if CUDNN_VERSION >= 5000
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
#endif
+#if CUDNN_VERSION >= 5100
+ case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
+#endif
return algo;
default:
LOG(FATAL)
@@ -295,6 +302,11 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
+#if CUDNN_VERSION >= 5100
+ // Based on cudnn.h, the following is not implemented.
+ // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
+ case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
+#endif
return algo;
default:
LOG(FATAL)
@@ -1952,6 +1964,33 @@ bool CudnnSupport::DoConvolveImpl(
return true;
}
+// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms.
+// Doing so by default make a few TensorFlow test cases to fail. Users can
+// explicitly enable them through an env-var "TF_ENABLE_WINOGRAD_NONFUSED=1".
+// https://github.com/tensorflow/tensorflow/pull/4901
+class WinogradNonfused {
+ public:
+ static bool IsEnabled() {
+ static bool is_enabled = IsEnabledImpl();
+ return is_enabled;
+ }
+
+ private:
+ static bool IsEnabledImpl() {
+ const char* tf_env_var_val = getenv("TF_ENABLE_WINOGRAD_NONFUSED");
+ if (tf_env_var_val != nullptr) {
+ port::StringPiece tf_env_var_val_str(tf_env_var_val);
+ if (tf_env_var_val_str == "0") {
+ return false;
+ }
+ return true;
+ }
+ // TODO(zhengxq): turn the default to True when the test failure is
+ // resolved.
+ return false;
+ }
+};
+
bool CudnnSupport::GetConvolveAlgorithms(
std::vector<dnn::AlgorithmType>* out_algorithms) {
out_algorithms->assign({
@@ -1967,6 +2006,11 @@ bool CudnnSupport::GetConvolveAlgorithms(
#endif
// clang-format on
});
+#if CUDNN_VERSION >= 5100
+ if (WinogradNonfused::IsEnabled()) {
+ out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
+ }
+#endif
return true;
}
@@ -1983,6 +2027,12 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
#endif
// clang-format on
});
+#if CUDNN_VERSION >= 5100
+ if (WinogradNonfused::IsEnabled()) {
+ out_algorithms->push_back(
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
+ }
+#endif
return true;
}
@@ -1996,6 +2046,14 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
// clang-format on
});
+#if CUDNN_VERSION >= 5100
+ if (WinogradNonfused::IsEnabled()) {
+ out_algorithms->push_back(
+ // Based on cudnn.h, the following is not implemented.
+ // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
+ }
+#endif
return true;
}