diff options
author | Tim Shen <timshen@google.com> | 2018-09-06 18:42:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 18:46:08 -0700 |
commit | ed7dcd42076afe778e3ead8f86708cabd4e8ce10 (patch) | |
tree | 5ebeb419de0822d429b44f143cab54ce1b4f3f83 /tensorflow/stream_executor | |
parent | ed343f4a05ee16f3b354f647d89f21505ea45912 (diff) |
Zero out the result buffer for strided conv backward filter for NHWC layouts.
cuDNN 7.1.4 and 7.2 has non-determinisic bug if the buffer is not zeroed.
PiperOrigin-RevId: 211905127
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 207f22c931..3c533c7f99 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -3275,6 +3275,26 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl( "This configuration potentially produces incorrect results."); }()); + // Zero out the result buffer for strided conv backward filter for NHWC + // layouts. cuDNN 7.1.4 and 7.2 has non-determinisic bug if the buffer is not + // zeroed. + // + // This wrong result caused by the bug is very flaky. It needs to be run for + // up to 20 times to produce a mismatch. + // + // TODO(timshen): add a nvbugs link. + if (CUDNN_VERSION >= 7100 && + algorithm_config.algorithm().algo_id() == + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 && + cudnn_type == CUDNN_DATA_HALF && + input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && + filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput && + output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && + (convolution_descriptor.vertical_filter_stride() > 1 || + convolution_descriptor.horizontal_filter_stride() > 1)) { + stream->ThenMemZero(backward_filter_data, backward_filter_data->size()); + } + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter( cudnn.handle(), /*alpha=*/alpha, |