aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.cc
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-05-28 14:20:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-28 14:23:23 -0700
commite78e5ec8a8c862e65b6a194e9caea377120d7207 (patch)
treeef713546bfdacced118056167b49de0c4c2f1a36 /tensorflow/stream_executor/stream_executor_pimpl.cc
parent3f9b69a50f40154f6078e1610ce7d3afa94bd07c (diff)
Set winograd nofused flag to be true by default.
Disable winograd nonfused conv for certain input params to avoid a known bug in cuDNNv5 and cuDNNv6. PiperOrigin-RevId: 157352847
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index fe5da12639..b3eefe0299 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -285,30 +285,36 @@ bool StreamExecutor::SupportsDnn() const {
}
bool StreamExecutor::GetConvolveAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveAlgorithms(out_algorithms);
+ return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused,
+ out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveBackwardDataAlgorithms(out_algorithms);
+ return dnn_support->GetConvolveBackwardDataAlgorithms(with_winograd_nonfused,
+ out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
+ return dnn_support->GetConvolveBackwardFilterAlgorithms(
+ with_winograd_nonfused, out_algorithms);
}
bool StreamExecutor::GetBlasGemmAlgorithms(