diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index d78362d4fb..99bed86a17 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -232,6 +232,7 @@ CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) __macro(cudnnRNNBackwardData) \ __macro(cudnnRNNBackwardWeights) \ __macro(cudnnSetRNNDescriptor) \ + __macro(cudnnSetRNNDescriptor_v6) \ __macro(cudnnGetFilterNdDescriptor) // clang-format on @@ -244,8 +245,7 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) // clang-format off #if CUDNN_VERSION >= 6000 #define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \ - __macro(cudnnConvolutionBiasActivationForward) \ - __macro(cudnnSetRNNDescriptor_v6) + __macro(cudnnConvolutionBiasActivationForward) // clang-format on CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) @@ -665,6 +665,7 @@ class ScopedPoolingDescriptor { LOG(FATAL) << "could not create cudnn pooling descriptor: " << ToString(status); } + const std::vector<int64> strides64 = pooling_descriptor.strides(); const std::vector<int64> padding64 = pooling_descriptor.padding(); const std::vector<int64> shape64 = pooling_descriptor.window(); @@ -679,14 +680,14 @@ class ScopedPoolingDescriptor { &CheckedNarrowing<int64, int>); std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing<int64, int>); - bool propagate_nans = pooling_descriptor.propagate_nans(); status = wrap::cudnnSetPoolingNdDescriptor( parent_, handle_, (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), #if CUDNN_VERSION >= 5000 - propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, + // Always propagate nans. + CUDNN_PROPAGATE_NAN, #endif nd, shape.data(), padding.data(), strides.data()); if (status != CUDNN_STATUS_SUCCESS) { |