diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 99bed86a17..d78362d4fb 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -232,7 +232,6 @@ CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) __macro(cudnnRNNBackwardData) \ __macro(cudnnRNNBackwardWeights) \ __macro(cudnnSetRNNDescriptor) \ - __macro(cudnnSetRNNDescriptor_v6) \ __macro(cudnnGetFilterNdDescriptor) // clang-format on @@ -245,7 +244,8 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) // clang-format off #if CUDNN_VERSION >= 6000 #define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \ - __macro(cudnnConvolutionBiasActivationForward) + __macro(cudnnConvolutionBiasActivationForward) \ + __macro(cudnnSetRNNDescriptor_v6) // clang-format on CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) @@ -665,7 +665,6 @@ class ScopedPoolingDescriptor { LOG(FATAL) << "could not create cudnn pooling descriptor: " << ToString(status); } - const std::vector<int64> strides64 = pooling_descriptor.strides(); const std::vector<int64> padding64 = pooling_descriptor.padding(); const std::vector<int64> shape64 = pooling_descriptor.window(); @@ -680,14 +679,14 @@ class ScopedPoolingDescriptor { &CheckedNarrowing<int64, int>); std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing<int64, int>); + bool propagate_nans = pooling_descriptor.propagate_nans(); status = wrap::cudnnSetPoolingNdDescriptor( parent_, handle_, (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), #if CUDNN_VERSION >= 5000 - // Always propagate nans. - CUDNN_PROPAGATE_NAN, + propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, #endif nd, shape.data(), padding.data(), strides.data()); if (status != CUDNN_STATUS_SUCCESS) { |