From 4b052cc98201a9f07ff9e451913a8adfbb74ab11 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 18 May 2018 02:37:52 -0700 Subject: Dropping support for cuDNN < 6. Enable CUDNN_FFT_TILING_FORWARD for cuDNN >= 7. PiperOrigin-RevId: 197118212 --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 245 ++++------------------------ 1 file changed, 29 insertions(+), 216 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 7ace7fd303..d82d36c691 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -53,6 +53,8 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin); namespace { +static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher"); + // Converts (via narrowing) a type T value to a type U, and checks that the // value has no value change due to the conversion. template @@ -93,7 +95,6 @@ string ToString(cudnnStatus_t status) { } } -#if CUDNN_VERSION >= 6000 string ToString(libraryPropertyType type) { switch (type) { case MAJOR_VERSION: @@ -107,7 +108,6 @@ string ToString(libraryPropertyType type) { "(type), ">"); } } -#endif template cudnnDataType_t GetCudnnDataType(); @@ -213,12 +213,8 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) { case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT: case CUDNN_CONVOLUTION_FWD_ALGO_FFT: case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING: -#if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD: -#endif -#if CUDNN_VERSION >= 5100 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED: -#endif return algo; default: LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: " @@ -235,12 +231,8 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo( case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1: case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT: case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING: -#if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD: -#endif -#if CUDNN_VERSION >= 5100 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED: -#endif return algo; default: LOG(FATAL) @@ -258,11 +250,12 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3: -#if CUDNN_VERSION >= 5100 // Based on cudnn.h, the following is not implemented. // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED: -#endif + // Produces incorrect results for some shapes. Disabled for now, see + // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes. + // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING: return algo; default: LOG(FATAL) @@ -271,7 +264,6 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( } } -#if CUDNN_VERSION >= 6000 port::Status GetCudnnProperty(libraryPropertyType type, int* value) { cudnnStatus_t status = cudnnGetProperty(type, value); if (status != CUDNN_STATUS_SUCCESS) { @@ -300,19 +292,11 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) { } } } -#endif port::Status GetLoadedCudnnVersion(CudnnVersion* version) { -#if CUDNN_VERSION >= 6000 TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version)); TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version)); TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level)); -#else - size_t loaded_version = ::cudnnGetVersion(); - version->major_version = loaded_version / 1000; - version->minor_version = (loaded_version / 100) % 10; - version->patch_level = loaded_version % 100; -#endif return port::Status::OK(); } @@ -418,7 +402,6 @@ class ScopedTensorDescriptor { << " to cudnn tensor descriptor: " << ToString(status); } } break; -#if CUDNN_VERSION >= 6000 case dnn::DataLayout::kBatchDepthYX4: { status = cudnnSetTensor4dDescriptor( handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type, @@ -430,7 +413,6 @@ class ScopedTensorDescriptor { << " to cudnn tensor descriptor: " << ToString(status); } } break; -#endif default: LOG(FATAL) << "Unsupported tensor format " << DataLayoutString(batch_descriptor.layout()); @@ -466,7 +448,6 @@ class ScopedFilterDescriptor { << ToString(status); } -#if CUDNN_VERSION >= 5000 // TODO(b/23032134): Even if the filter layout is not supported, // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it // does not take layout as an input. Maybe force cuDNN by giving wrong @@ -476,17 +457,14 @@ class ScopedFilterDescriptor { case dnn::FilterLayout::kOutputInputYX: format = CUDNN_TENSOR_NCHW; break; -#if CUDNN_VERSION >= 6000 case dnn::FilterLayout::kOutputInputYX4: format = CUDNN_TENSOR_NCHW_VECT_C; break; -#endif default: LOG(FATAL) << "Unsupported filter format " << FilterLayoutString(filter_descriptor.layout()); break; } -#endif std::vector dims(2 + filter_descriptor.ndims()); dims[0] = filter_descriptor.output_feature_map_count(); @@ -494,11 +472,8 @@ class ScopedFilterDescriptor { const auto& spatial_dims = filter_descriptor.input_filter_dims(); std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2); - status = cudnnSetFilterNdDescriptor(handle_, elem_type, -#if CUDNN_VERSION >= 5000 - format, -#endif - dims.size(), dims.data()); + status = cudnnSetFilterNdDescriptor(handle_, elem_type, format, dims.size(), + dims.data()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not set cudnn filter descriptor: " << ToString(status); @@ -692,10 +667,8 @@ class ScopedPoolingDescriptor { (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), -#if CUDNN_VERSION >= 5000 - propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, -#endif - nd, shape.data(), padding.data(), strides.data()); + propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd, + shape.data(), padding.data(), strides.data()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not set cudnn pooling descriptor: " << ToString(status); @@ -771,7 +744,6 @@ class ScopedNormalizeDescriptor { SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor); }; -#if CUDNN_VERSION >= 5000 // Turns a ActivationDescriptor structure into a cudnn activation // descriptor handle within a scope. class ScopedActivationDescriptor { @@ -834,7 +806,6 @@ class ScopedActivationDescriptor { SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor); }; -#endif cudnnDataType_t ToCudnnDataType( dnn::DataType data_type, @@ -844,18 +815,14 @@ cudnnDataType_t ToCudnnDataType( case dnn::DataType::kDouble: case dnn::DataType::kHalf: return static_cast(data_type); -#if CUDNN_VERSION >= 6000 case dnn::DataType::kInt8: return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4 : CUDNN_DATA_INT8; -#endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); } } -#if CUDNN_VERSION >= 5000 - cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) { switch (input_mode) { case dnn::RnnInputMode::kRnnLinearSkip: @@ -903,15 +870,11 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) { } } -#endif // CUDNN_VERSION - template class MixinBase : public Base {}; template <> class MixinBase {}; -#if CUDNN_VERSION >= 5000 - #define CUDNN_RETURN_IF_FAIL(STATUS, ...) \ if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \ string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \ @@ -1042,9 +1005,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { hidden_size_(hidden_size), input_size_(input_size), batch_size_(batch_size), -#if CUDNN_VERSION >= 6000 rnn_plan_(nullptr), -#endif input_mode_(input_mode), direction_mode_(direction_mode), rnn_mode_(rnn_mode), @@ -1062,7 +1023,6 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { // Create the RNN handle cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_); CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor"); -#if CUDNN_VERSION >= 6000 // TODO: allow the user to choose an algorithm. rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm()); status = cudnnSetRNNDescriptor_v6( @@ -1084,16 +1044,6 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_); CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan."); } -#else - CHECK(algorithm_config_.is_default()) - << "Non-default algorithm not supported for CUDA version < 6.0"; - status = cudnnSetRNNDescriptor( - /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size, - /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(), - /*inputMode=*/input_mode, /*direction=*/direction_mode, - /*mode=*/rnn_mode, /*dataType=*/compute_type); - CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor"); -#endif // Create the params handle. cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this)); @@ -1106,12 +1056,10 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { ~CudnnRnnDescriptor() override { if (rnn_desc_) { cudnnStatus_t status; -#if CUDNN_VERSION >= 6000 if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) { status = cudnnDestroyPersistentRNNPlan(rnn_plan_); CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan."); } -#endif status = cudnnDestroyRNNDescriptor(rnn_desc_); CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor"); } @@ -1172,10 +1120,8 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC // algorithm. int batch_size_; -#if CUDNN_VERSION >= 6000 cudnnRNNAlgo_t rnn_algo_; cudnnPersistentRNNPlan_t rnn_plan_; -#endif cudnnRNNInputMode_t input_mode_; cudnnDirectionMode_t direction_mode_; cudnnRNNMode_t rnn_mode_; @@ -1806,8 +1752,6 @@ bool CudnnSupport::DoRnnBackwardImpl( return true; } -#endif // CUDNN_VERSION - port::StatusOr> CudnnSupport::createRnnDescriptor( int num_layers, int hidden_size, int input_size, int batch_size, @@ -1815,7 +1759,6 @@ CudnnSupport::createRnnDescriptor( dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, ScratchAllocator* state_allocator) { -#if CUDNN_VERSION >= 5000 // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's // not enqueueing anything into a stream, we pass in the null stream. auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr); @@ -1830,20 +1773,12 @@ CudnnSupport::createRnnDescriptor( } return port::StatusOr>( std::move(rnn_desc)); -#else - string error_msg = - port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ", - "Current Cudnn version: ", CUDNN_VERSION, ". "); - LOG(ERROR) << error_msg; - return port::Status(port::error::UNIMPLEMENTED, error_msg); -#endif // CUDNN_VERSION } port::StatusOr> CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, int data_size, dnn::DataType data_type) { -#if CUDNN_VERSION >= 5000 std::unique_ptr seq_desc( new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size, data_size, @@ -1853,20 +1788,12 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, } return port::StatusOr>( std::move(seq_desc)); -#else - string error_msg = port::StrCat( - "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ", - "Current Cudnn version: ", CUDNN_VERSION, ". "); - LOG(ERROR) << error_msg; - return port::Status(port::error::UNIMPLEMENTED, error_msg); -#endif // CUDNN_VERSION } port::StatusOr> CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) { -#if CUDNN_VERSION >= 5000 std::unique_ptr state_desc( new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size, data_size, ToCudnnDataType(data_type))); @@ -1875,13 +1802,6 @@ CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, } return port::StatusOr>( std::move(state_desc)); -#else - string error_msg = port::StrCat( - "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ", - "Current Cudnn version: ", CUDNN_VERSION, ". "); - LOG(ERROR) << error_msg; - return port::Status(port::error::UNIMPLEMENTED, error_msg); -#endif // CUDNN_VERSION } bool CudnnSupport::DoRnnForward( @@ -1902,7 +1822,6 @@ bool CudnnSupport::DoRnnForward( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -1924,9 +1843,6 @@ bool CudnnSupport::DoRnnForward( output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, output_c_data, is_training, reserve_space_allocator, workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION } bool CudnnSupport::DoRnnForward( @@ -1946,7 +1862,6 @@ bool CudnnSupport::DoRnnForward( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -1968,9 +1883,6 @@ bool CudnnSupport::DoRnnForward( output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, output_c_data, is_training, reserve_space_allocator, workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION } bool CudnnSupport::DoRnnForward( @@ -1991,7 +1903,6 @@ bool CudnnSupport::DoRnnForward( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2013,9 +1924,6 @@ bool CudnnSupport::DoRnnForward( output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, output_c_data, is_training, reserve_space_allocator, workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION } bool CudnnSupport::DoRnnBackward( @@ -2043,7 +1951,6 @@ bool CudnnSupport::DoRnnBackward( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2067,9 +1974,6 @@ bool CudnnSupport::DoRnnBackward( output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION } bool CudnnSupport::DoRnnBackward( @@ -2096,7 +2000,6 @@ bool CudnnSupport::DoRnnBackward( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2120,9 +2023,6 @@ bool CudnnSupport::DoRnnBackward( output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION } bool CudnnSupport::DoRnnBackward( @@ -2150,7 +2050,6 @@ bool CudnnSupport::DoRnnBackward( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION >= 5000 const CudnnRnnDescriptor& cudnn_rnn_desc = static_cast(rnn_desc); const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc = @@ -2174,9 +2073,6 @@ bool CudnnSupport::DoRnnBackward( output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, output_profile_result); -#else - return false; -#endif // CUDNN_VERSION } namespace { @@ -2311,16 +2207,12 @@ class CudnnEnvVar { }; // A helper struct to decide whether to enable the FFT_TILING algorithms for -// forward convolution. Before cudnn v5.1 it works fine but since cudnn v5.1 -// it is turned off due to memory corruption caused by some shapes with this -// algorithm. -// Before NVIDIA fixes the memory corruption bug, users can explicitly -// enable the algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1". +// forward convolution. It is disabled for cuDNN < 7 due to memory corruption +// caused by some shapes with this algorithm. Users can explicitly enable the +// algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1". struct FftTilingForward { static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD"; - // TODO(yangzihao): turn the default to True when the memory corruption bug - // is fixed. - static constexpr bool kDefaultFlag = CUDNN_VERSION < 5100; + static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000; }; // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms. @@ -2329,10 +2221,9 @@ struct FftTilingForward { // https://github.com/tensorflow/tensorflow/pull/4901 struct WinogradNonfused { static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED"; - // NVIDIA has fixed winograd nonfused bug for cudnn v>=7. - // For cudnn v>=5.1, we have a workaround and for any lower version, we - // disable it by default. - static constexpr bool kDefaultFlag = CUDNN_VERSION >= 5100; + // NVIDIA has fixed winograd nonfused bug for cudnn v>=7. For older versions, + // we have a workaround. + static constexpr bool kDefaultFlag = true; }; // A helper struct to decide whether to use FP32 as the internal compute type @@ -2621,11 +2512,6 @@ bool CudnnSupport::DoFusedConvolveImpl( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION < 6000 - LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only " - "supported for cuDNN version >= 6"; - return false; -#else ScopedTensorDescriptor conv_input_nd( conv_input_descriptor, static_cast(cudnn_data_type)); ScopedTensorDescriptor output_nd( @@ -2732,32 +2618,27 @@ bool CudnnSupport::DoFusedConvolveImpl( } return true; -#endif // CUDNN_VERSION < 6000 } bool CudnnSupport::GetConvolveAlgorithms( bool with_winograd_nonfused, int cc_major, int cc_minor, std::vector* out_algorithms) { std::vector algo_types = { - // clang-format off + // clang-format off CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, CUDNN_CONVOLUTION_FWD_ALGO_FFT, -#if CUDNN_VERSION >= 5000 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, -#endif - // clang-format on + // clang-format on }; if (CudnnEnvVar::IsEnabled()) { algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING); } -#if CUDNN_VERSION >= 5100 if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED); } -#endif out_algorithms->clear(); for (auto i : algo_types) { @@ -2772,13 +2653,11 @@ bool CudnnSupport::GetConvolveAlgorithms( bool CudnnSupport::GetRnnAlgorithms( std::vector* out_algorithms) { std::vector algo_types = { - // clang-format off -#if CUDNN_VERSION >= 6000 + // clang-format off CUDNN_RNN_ALGO_STANDARD, CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC, -#endif - // clang-format on + // clang-format on }; out_algorithms->clear(); @@ -2797,21 +2676,17 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( bool with_winograd_nonfused, int cc_major, int cc_minor, std::vector* out_algorithms) { std::vector algo_types = { - // clang-format off + // clang-format off CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, -#if CUDNN_VERSION >= 5000 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, -#endif - // clang-format on + // clang-format on }; -#if CUDNN_VERSION >= 5100 if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); } -#endif out_algorithms->clear(); for (auto i : algo_types) { @@ -2834,13 +2709,15 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, // Based on cudnn.h, the following is not implemented. // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD, + + // Produces incorrect results for some shapes. Disabled for now, see + // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes. + // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, // clang-format on }; -#if CUDNN_VERSION >= 5100 if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED); } -#endif out_algorithms->clear(); for (auto i : algo_types) { @@ -2939,17 +2816,8 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(), saved_inv_var->opaque()); -#if CUDNN_VERSION < 5000 - CHECK(inv_var_to_var); - inv_var_to_var(); -#endif } else { -#if CUDNN_VERSION < 5000 - CHECK(var_to_inv_var); - const void* maybe_inv_var = var_to_inv_var().opaque(); -#else const void* maybe_inv_var = estimated_variance.opaque(); -#endif status = cudnnBatchNormalizationForwardInference( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), @@ -3159,11 +3027,6 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { -#if CUDNN_VERSION < 6000 - LOG(WARNING) << "cudnnConvolutionBiasActivationForward() is only " - "supported for cuDNN version >= 6"; - return false; -#else int cc_major, cc_minor; stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); @@ -3179,7 +3042,6 @@ bool CudnnSupport::DoFusedConvolve( side_input_scale, bias_descriptor, biases, activation_mode, output_descriptor, output_data, scratch_allocator, algorithm_config, output_profile_result); -#endif } namespace { @@ -3428,13 +3290,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( timer->Start(AsCUDAStream(stream)); } -#if CUDNN_VERSION >= 5000 auto status = cudnnConvolutionBackwardData(cudnn.handle(), -#else - auto status = - cudnnConvolutionBackwardData_v3(cudnn.handle(), -#endif /*alpha=*/alpha, /*wDesc=*/filter.handle(), /*w=*/filter_data.opaque(), @@ -3697,13 +3554,8 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( timer->Start(AsCUDAStream(stream)); } -#if CUDNN_VERSION >= 5000 auto status = cudnnConvolutionBackwardFilter( cudnn.handle(), -#else - auto status = cudnnConvolutionBackwardFilter_v3( - cudnn.handle(), -#endif /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), @@ -4016,11 +3868,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, auto cudnn = cudnn_->GetHandle(parent_, stream); -#if CUDNN_VERSION >= 5000 auto status = cudnnAddTensor( -#else - auto status = cudnnAddTensor_v3( -#endif cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta, input_descriptor.handle(), output_data->opaque()); @@ -4038,37 +3886,8 @@ bool CudnnSupport::DoActivate(Stream* stream, const DeviceMemory& input_data, DeviceMemory* output_data, uint64 options) { -#if CUDNN_VERSION >= 5000 ScopedActivationDescriptor activation_desc( activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max()); -#else - cudnnActivationMode_t mode; - switch (activation_mode) { - case dnn::ActivationMode::kRelu6: - // TODO(leary) should probably do a post-pass to clip at 6? - LOG(WARNING) << "user requested Relu6, but providing Relu instead"; - mode = CUDNN_ACTIVATION_RELU; - break; - case dnn::ActivationMode::kReluX: - // TODO(broune) should probably do a post-pass to clip at X? - LOG(WARNING) << "user requested ReluX, but providing Relu instead"; - mode = CUDNN_ACTIVATION_RELU; - break; - case dnn::ActivationMode::kRelu: - mode = CUDNN_ACTIVATION_RELU; - break; - case dnn::ActivationMode::kSigmoid: - mode = CUDNN_ACTIVATION_SIGMOID; - break; - case dnn::ActivationMode::kTanh: - mode = CUDNN_ACTIVATION_TANH; - break; - default: - LOG(ERROR) << "unrecognized activation mode: " - << static_cast(activation_mode); - return false; - } -#endif ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT); // Alpha is the input scaling factor. @@ -4077,15 +3896,9 @@ bool CudnnSupport::DoActivate(Stream* stream, float beta = 0.0; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = - cudnnActivationForward(cudnn.handle(), -#if CUDNN_VERSION >= 5000 - activation_desc.handle(), -#else - mode, -#endif - &alpha, input_nd.handle(), input_data.opaque(), - &beta, input_nd.handle(), output_data->opaque()); + auto status = cudnnActivationForward( + cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(), + input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "stream " << stream << " could not enqueue activation: " << ToString(status); -- cgit v1.2.3