aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-05-28 14:20:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-28 14:23:23 -0700
commite78e5ec8a8c862e65b6a194e9caea377120d7207 (patch)
treeef713546bfdacced118056167b49de0c4c2f1a36 /tensorflow/stream_executor/dnn.h
parent3f9b69a50f40154f6078e1610ce7d3afa94bd07c (diff)
Set winograd nofused flag to be true by default.
Disable winograd nonfused conv for certain input params to avoid a known bug in cuDNNv5 and cuDNNv6. PiperOrigin-RevId: 157352847
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index c5805064f3..8e56933ba3 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -952,7 +952,7 @@ class DnnSupport {
// Return a list of algorithms supported by the forward convolution pass.
virtual bool GetConvolveAlgorithms(
- std::vector<AlgorithmType>* out_algorithms);
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
// Enqueues a double-precision convolution operation onto the stream.
// See DoConvolve above for argument details.
@@ -1056,7 +1056,7 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// data.
virtual bool GetConvolveBackwardDataAlgorithms(
- std::vector<AlgorithmType>* out_algorithms);
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor,
@@ -1104,7 +1104,7 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// filters.
virtual bool GetConvolveBackwardFilterAlgorithms(
- std::vector<AlgorithmType>* out_algorithms);
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor,