author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-18 10:42:50 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-18 10:45:22 -0700
commit    4fce5d6c88982ca06d16b55fac98cb29d0a87081 (patch)
tree      7e07b76461580d672e1052d6459b885a63cbe50a /tensorflow/stream_executor
parent    487d2ab835286f4eea891d93bc32adfd5543aef8 (diff)
Automated g4 rollback of changelist 197118212
PiperOrigin-RevId: 197167501
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc | 245
1 file changed, 216 insertions(+), 29 deletions(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index d82d36c691..7ace7fd303 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -53,8 +53,6 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
namespace {
-static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");
-
// Converts (via narrowing) a value of type WideT to type NarrowT, and checks
// that the conversion does not change the value.
template <typename WideT, typename NarrowT>
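The function body is elided by the diff context. For reference, a minimal sketch of what a checked-narrowing helper of this shape typically looks like (not part of the patch; assert stands in for the CHECK-style macros used in this file):

#include <cassert>

template <typename WideT, typename NarrowT>
NarrowT CheckedNarrowing(const WideT& wide) {
  NarrowT narrow = static_cast<NarrowT>(wide);
  // Round-trip through the wide type; any mismatch means the narrowing
  // lost information.
  assert(static_cast<WideT>(narrow) == wide);
  return narrow;
}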
@@ -95,6 +93,7 @@ string ToString(cudnnStatus_t status) {
}
}
+#if CUDNN_VERSION >= 6000
string ToString(libraryPropertyType type) {
switch (type) {
case MAJOR_VERSION:
@@ -108,6 +107,7 @@ string ToString(libraryPropertyType type) {
"<unknown libraryPropertyType: ", static_cast<int>(type), ">");
}
}
+#endif
template <typename T>
cudnnDataType_t GetCudnnDataType();
@@ -213,8 +213,12 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
+#if CUDNN_VERSION >= 5000
case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
+#endif
+#if CUDNN_VERSION >= 5100
case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
+#endif
return algo;
default:
LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
@@ -231,8 +235,12 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
+#if CUDNN_VERSION >= 5000
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
+#endif
+#if CUDNN_VERSION >= 5100
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
+#endif
return algo;
default:
LOG(FATAL)
@@ -250,12 +258,11 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
+#if CUDNN_VERSION >= 5100
// Based on cudnn.h, the following is not implemented.
// case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
- // Produces incorrect results for some shapes. Disabled for now, see
- // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
- // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
+#endif
return algo;
default:
LOG(FATAL)
@@ -264,6 +271,7 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
}
}
+#if CUDNN_VERSION >= 6000
port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
cudnnStatus_t status = cudnnGetProperty(type, value);
if (status != CUDNN_STATUS_SUCCESS) {
@@ -292,11 +300,19 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) {
}
}
}
+#endif
port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
+#if CUDNN_VERSION >= 6000
TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version));
TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version));
TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level));
+#else
+ size_t loaded_version = ::cudnnGetVersion();
+ version->major_version = loaded_version / 1000;
+ version->minor_version = (loaded_version / 100) % 10;
+ version->patch_level = loaded_version % 100;
+#endif
return port::Status::OK();
}
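A quick check of the pre-6.0 fallback above: cuDNN packs its version as major*1000 + minor*100 + patch, so the arithmetic recovers the three components. A standalone example (the value is illustrative):

#include <cstddef>
#include <cstdio>

int main() {
  std::size_t loaded_version = 5110;  // e.g. ::cudnnGetVersion() for cuDNN 5.1.10
  std::printf("major=%zu minor=%zu patch=%zu\n",
              loaded_version / 1000,         // 5
              (loaded_version / 100) % 10,   // 1
              loaded_version % 100);         // 10
  return 0;
}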
@@ -402,6 +418,7 @@ class ScopedTensorDescriptor {
<< " to cudnn tensor descriptor: " << ToString(status);
}
} break;
+#if CUDNN_VERSION >= 6000
case dnn::DataLayout::kBatchDepthYX4: {
status = cudnnSetTensor4dDescriptor(
handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type,
@@ -413,6 +430,7 @@ class ScopedTensorDescriptor {
<< " to cudnn tensor descriptor: " << ToString(status);
}
} break;
+#endif
default:
LOG(FATAL) << "Unsupported tensor format "
<< DataLayoutString(batch_descriptor.layout());
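For context on the kBatchDepthYX4 / CUDNN_TENSOR_NCHW_VECT_C case gated above: that format stores int8 data with channels packed four at a time. A hypothetical helper (not in the patch) showing the resulting index math:

#include <cstddef>

// Linear offset of element (n, c, h, w) in an int8 NCHW_VECT_C tensor of
// logical shape N x C x H x W, with channels packed in groups of four.
inline std::size_t NchwVectCOffset(std::size_t n, std::size_t c,
                                   std::size_t h, std::size_t w,
                                   std::size_t C, std::size_t H,
                                   std::size_t W) {
  const std::size_t c_outer = c / 4;       // which 4-channel group
  const std::size_t c_inner = c % 4;       // position inside the group
  const std::size_t groups = (C + 3) / 4;  // number of channel groups
  return (((n * groups + c_outer) * H + h) * W + w) * 4 + c_inner;
}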
@@ -448,6 +466,7 @@ class ScopedFilterDescriptor {
<< ToString(status);
}
+#if CUDNN_VERSION >= 5000
// TODO(b/23032134): Even if the filter layout is not supported,
// cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it
// does not take layout as an input. Maybe force cuDNN by giving wrong
@@ -457,14 +476,17 @@ class ScopedFilterDescriptor {
case dnn::FilterLayout::kOutputInputYX:
format = CUDNN_TENSOR_NCHW;
break;
+#if CUDNN_VERSION >= 6000
case dnn::FilterLayout::kOutputInputYX4:
format = CUDNN_TENSOR_NCHW_VECT_C;
break;
+#endif
default:
LOG(FATAL) << "Unsupported filter format "
<< FilterLayoutString(filter_descriptor.layout());
break;
}
+#endif
std::vector<int> dims(2 + filter_descriptor.ndims());
dims[0] = filter_descriptor.output_feature_map_count();
@@ -472,8 +494,11 @@ class ScopedFilterDescriptor {
const auto& spatial_dims = filter_descriptor.input_filter_dims();
std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
- status = cudnnSetFilterNdDescriptor(handle_, elem_type, format, dims.size(),
- dims.data());
+ status = cudnnSetFilterNdDescriptor(handle_, elem_type,
+#if CUDNN_VERSION >= 5000
+ format,
+#endif
+ dims.size(), dims.data());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn filter descriptor: "
<< ToString(status);
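Splicing #if into an argument list, as above, is the pattern this rollback uses throughout. An alternative sketch (purely illustrative, not part of the patch) that confines the version check to one wrapper so every call site stays well-formed:

// Hypothetical wrapper around the two cudnnSetFilterNdDescriptor signatures.
cudnnStatus_t SetFilterNdDesc(cudnnFilterDescriptor_t handle,
                              cudnnDataType_t elem_type,
                              cudnnTensorFormat_t format, int nd,
                              const int* dims) {
#if CUDNN_VERSION >= 5000
  return cudnnSetFilterNdDescriptor(handle, elem_type, format, nd, dims);
#else
  (void)format;  // the pre-5.0 API takes no layout argument
  return cudnnSetFilterNdDescriptor(handle, elem_type, nd, dims);
#endif
}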
@@ -667,8 +692,10 @@ class ScopedPoolingDescriptor {
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
? CUDNN_POOLING_MAX
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
- propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
- shape.data(), padding.data(), strides.data());
+#if CUDNN_VERSION >= 5000
+ propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN,
+#endif
+ nd, shape.data(), padding.data(), strides.data());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn pooling descriptor: "
<< ToString(status);
@@ -744,6 +771,7 @@ class ScopedNormalizeDescriptor {
SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
};
+#if CUDNN_VERSION >= 5000
// Turns a ActivationDescriptor structure into a cudnn activation
// descriptor handle within a scope.
class ScopedActivationDescriptor {
@@ -806,6 +834,7 @@ class ScopedActivationDescriptor {
SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};
+#endif
cudnnDataType_t ToCudnnDataType(
dnn::DataType data_type,
@@ -815,14 +844,18 @@ cudnnDataType_t ToCudnnDataType(
case dnn::DataType::kDouble:
case dnn::DataType::kHalf:
return static_cast<cudnnDataType_t>(data_type);
+#if CUDNN_VERSION >= 6000
case dnn::DataType::kInt8:
return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
: CUDNN_DATA_INT8;
+#endif
default:
LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
}
}
+#if CUDNN_VERSION >= 5000
+
cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
switch (input_mode) {
case dnn::RnnInputMode::kRnnLinearSkip:
@@ -870,11 +903,15 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
}
}
+#endif // CUDNN_VERSION
+
template <typename Base>
class MixinBase : public Base {};
template <>
class MixinBase<void> {};
+#if CUDNN_VERSION >= 5000
+
#define CUDNN_RETURN_IF_FAIL(STATUS, ...) \
if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \
string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
@@ -1005,7 +1042,9 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
hidden_size_(hidden_size),
input_size_(input_size),
batch_size_(batch_size),
+#if CUDNN_VERSION >= 6000
rnn_plan_(nullptr),
+#endif
input_mode_(input_mode),
direction_mode_(direction_mode),
rnn_mode_(rnn_mode),
@@ -1023,6 +1062,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
// Create the RNN handle
cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
+#if CUDNN_VERSION >= 6000
// TODO: allow the user to choose an algorithm.
rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm());
status = cudnnSetRNNDescriptor_v6(
@@ -1044,6 +1084,16 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_);
CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan.");
}
+#else
+ CHECK(algorithm_config_.is_default())
+        << "Non-default algorithm not supported for cuDNN version < 6.0";
+ status = cudnnSetRNNDescriptor(
+ /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
+ /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
+ /*inputMode=*/input_mode, /*direction=*/direction_mode,
+ /*mode=*/rnn_mode, /*dataType=*/compute_type);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
+#endif
// Create the params handle.
cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this));
@@ -1056,10 +1106,12 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
~CudnnRnnDescriptor() override {
if (rnn_desc_) {
cudnnStatus_t status;
+#if CUDNN_VERSION >= 6000
if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) {
status = cudnnDestroyPersistentRNNPlan(rnn_plan_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan.");
}
+#endif
status = cudnnDestroyRNNDescriptor(rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
}
@@ -1120,8 +1172,10 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
// batch_size_ is set to -1 when not using the CUDNN_RNN_ALGO_PERSIST_DYNAMIC
// algorithm.
int batch_size_;
+#if CUDNN_VERSION >= 6000
cudnnRNNAlgo_t rnn_algo_;
cudnnPersistentRNNPlan_t rnn_plan_;
+#endif
cudnnRNNInputMode_t input_mode_;
cudnnDirectionMode_t direction_mode_;
cudnnRNNMode_t rnn_mode_;
@@ -1752,6 +1806,8 @@ bool CudnnSupport::DoRnnBackwardImpl(
return true;
}
+#endif // CUDNN_VERSION
+
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
CudnnSupport::createRnnDescriptor(
int num_layers, int hidden_size, int input_size, int batch_size,
@@ -1759,6 +1815,7 @@ CudnnSupport::createRnnDescriptor(
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
ScratchAllocator* state_allocator) {
+#if CUDNN_VERSION >= 5000
// Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
// not enqueueing anything into a stream, we pass in the null stream.
auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
@@ -1773,12 +1830,20 @@ CudnnSupport::createRnnDescriptor(
}
return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
std::move(rnn_desc));
+#else
+ string error_msg =
+ port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ",
+ "Current Cudnn version: ", CUDNN_VERSION, ". ");
+ LOG(ERROR) << error_msg;
+ return port::Status(port::error::UNIMPLEMENTED, error_msg);
+#endif // CUDNN_VERSION
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
int data_size,
dnn::DataType data_type) {
+#if CUDNN_VERSION >= 5000
std::unique_ptr<CudnnRnnSequenceTensorDescriptor> seq_desc(
new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size,
data_size,
@@ -1788,12 +1853,20 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
}
return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
std::move(seq_desc));
+#else
+ string error_msg = port::StrCat(
+ "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ",
+ "Current Cudnn version: ", CUDNN_VERSION, ". ");
+ LOG(ERROR) << error_msg;
+ return port::Status(port::error::UNIMPLEMENTED, error_msg);
+#endif // CUDNN_VERSION
}
port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
int data_size,
dnn::DataType data_type) {
+#if CUDNN_VERSION >= 5000
std::unique_ptr<CudnnRnnStateTensorDescriptor> state_desc(
new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
data_size, ToCudnnDataType(data_type)));
@@ -1802,6 +1875,13 @@ CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
}
return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
std::move(state_desc));
+#else
+ string error_msg = port::StrCat(
+ "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ",
+ "Current Cudnn version: ", CUDNN_VERSION, ". ");
+ LOG(ERROR) << error_msg;
+ return port::Status(port::error::UNIMPLEMENTED, error_msg);
+#endif // CUDNN_VERSION
}
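The three #else branches above differ only in the function name. A hypothetical consolidation (the helper name is an assumption; the calls mirror the patch's own):

// Not part of the patch: build the UNIMPLEMENTED status returned by each
// createRnn* fallback above.
port::Status RnnApiUnsupported(const char* function_name) {
  string error_msg =
      port::StrCat(function_name, " needs at least Cudnn 5.0 to work. ",
                   "Current Cudnn version: ", CUDNN_VERSION, ". ");
  LOG(ERROR) << error_msg;
  return port::Status(port::error::UNIMPLEMENTED, error_msg);
}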
bool CudnnSupport::DoRnnForward(
@@ -1822,6 +1902,7 @@ bool CudnnSupport::DoRnnForward(
ScratchAllocator* reserve_space_allocator,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1843,6 +1924,9 @@ bool CudnnSupport::DoRnnForward(
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
output_c_data, is_training, reserve_space_allocator, workspace_allocator,
output_profile_result);
+#else
+ return false;
+#endif // CUDNN_VERSION
}
bool CudnnSupport::DoRnnForward(
@@ -1862,6 +1946,7 @@ bool CudnnSupport::DoRnnForward(
ScratchAllocator* reserve_space_allocator,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1883,6 +1968,9 @@ bool CudnnSupport::DoRnnForward(
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
output_c_data, is_training, reserve_space_allocator, workspace_allocator,
output_profile_result);
+#else
+ return false;
+#endif // CUDNN_VERSION
}
bool CudnnSupport::DoRnnForward(
@@ -1903,6 +1991,7 @@ bool CudnnSupport::DoRnnForward(
ScratchAllocator* reserve_space_allocator,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1924,6 +2013,9 @@ bool CudnnSupport::DoRnnForward(
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
output_c_data, is_training, reserve_space_allocator, workspace_allocator,
output_profile_result);
+#else
+ return false;
+#endif // CUDNN_VERSION
}
bool CudnnSupport::DoRnnBackward(
@@ -1951,6 +2043,7 @@ bool CudnnSupport::DoRnnBackward(
DeviceMemory<uint8>* reserve_space_data,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -1974,6 +2067,9 @@ bool CudnnSupport::DoRnnBackward(
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
input_c_backprop_data, params_backprop_data, reserve_space_data,
workspace_allocator, output_profile_result);
+#else
+ return false;
+#endif // CUDNN_VERSION
}
bool CudnnSupport::DoRnnBackward(
@@ -2000,6 +2096,7 @@ bool CudnnSupport::DoRnnBackward(
DeviceMemory<uint8>* reserve_space_data,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -2023,6 +2120,9 @@ bool CudnnSupport::DoRnnBackward(
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
input_c_backprop_data, params_backprop_data, reserve_space_data,
workspace_allocator, output_profile_result);
+#else
+ return false;
+#endif // CUDNN_VERSION
}
bool CudnnSupport::DoRnnBackward(
@@ -2050,6 +2150,7 @@ bool CudnnSupport::DoRnnBackward(
DeviceMemory<uint8>* reserve_space_data,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
@@ -2073,6 +2174,9 @@ bool CudnnSupport::DoRnnBackward(
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
input_c_backprop_data, params_backprop_data, reserve_space_data,
workspace_allocator, output_profile_result);
+#else
+ return false;
+#endif // CUDNN_VERSION
}
namespace {
@@ -2207,12 +2311,16 @@ class CudnnEnvVar {
};
// A helper struct to decide whether to enable the FFT_TILING algorithms for
-// forward convolution. It is disabled for cuDNN < 7 due to memory corruption
-// caused by some shapes with this algorithm. Users can explicitly enable the
-// algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
+// forward convolution. The algorithm works fine before cuDNN v5.1, but
+// starting with cuDNN v5.1 it is disabled by default due to memory
+// corruption caused by some shapes with this algorithm.
+// Until NVIDIA fixes the memory corruption bug, users can explicitly
+// enable the algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
struct FftTilingForward {
static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
- static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
+ // TODO(yangzihao): set the default back to true once the memory
+ // corruption bug is fixed.
+ static constexpr bool kDefaultFlag = CUDNN_VERSION < 5100;
};
// A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
@@ -2221,9 +2329,10 @@ struct FftTilingForward {
// https://github.com/tensorflow/tensorflow/pull/4901
struct WinogradNonfused {
static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
- // NVIDIA has fixed winograd nonfused bug for cudnn v>=7. For older versions,
- // we have a workaround.
- static constexpr bool kDefaultFlag = true;
+ // NVIDIA has fixed the winograd nonfused bug for cuDNN v>=7. For cuDNN
+ // v>=5.1 we have a workaround, and for any lower version we disable it
+ // by default.
+ static constexpr bool kDefaultFlag = CUDNN_VERSION >= 5100;
};
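CudnnEnvVar<T>::IsEnabled() itself is not shown in this diff. A self-contained sketch of the pattern these policy structs plug into (the parsing rule here is an assumption):

#include <cstdlib>
#include <cstring>

// Sketch: the policy struct supplies kName and kDefaultFlag; the flag is
// read from the environment once and cached.
template <typename EnvVar>
struct CudnnEnvVarSketch {
  static bool IsEnabled() {
    static const bool is_enabled = [] {
      const char* value = std::getenv(EnvVar::kName);
      if (value == nullptr) return EnvVar::kDefaultFlag;
      return std::strcmp(value, "0") != 0;  // assumed: any value but "0" enables
    }();
    return is_enabled;
  }
};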
// A helper struct to decide whether to use FP32 as the internal compute type
@@ -2512,6 +2621,11 @@ bool CudnnSupport::DoFusedConvolveImpl(
DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION < 6000
+ LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
+ "supported for cuDNN version >= 6";
+ return false;
+#else
ScopedTensorDescriptor conv_input_nd(
conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
ScopedTensorDescriptor output_nd(
@@ -2618,27 +2732,32 @@ bool CudnnSupport::DoFusedConvolveImpl(
}
return true;
+#endif // CUDNN_VERSION < 6000
}
bool CudnnSupport::GetConvolveAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
- // clang-format off
+ // clang-format off
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+#if CUDNN_VERSION >= 5000
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
- // clang-format on
+#endif
+ // clang-format on
};
if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
}
+#if CUDNN_VERSION >= 5100
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
}
+#endif
out_algorithms->clear();
for (auto i : algo_types) {
@@ -2653,11 +2772,13 @@ bool CudnnSupport::GetConvolveAlgorithms(
bool CudnnSupport::GetRnnAlgorithms(
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
- // clang-format off
+ // clang-format off
+#if CUDNN_VERSION >= 6000
CUDNN_RNN_ALGO_STANDARD,
CUDNN_RNN_ALGO_PERSIST_STATIC,
CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
- // clang-format on
+#endif
+ // clang-format on
};
out_algorithms->clear();
@@ -2676,17 +2797,21 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
- // clang-format off
+ // clang-format off
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
+#if CUDNN_VERSION >= 5000
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
- // clang-format on
+#endif
+ // clang-format on
};
+#if CUDNN_VERSION >= 5100
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
}
+#endif
out_algorithms->clear();
for (auto i : algo_types) {
@@ -2709,15 +2834,13 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
// Based on cudnn.h, the following is not implemented.
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
-
- // Produces incorrect results for some shapes. Disabled for now, see
- // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
- // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
// clang-format on
};
+#if CUDNN_VERSION >= 5100
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
}
+#endif
out_algorithms->clear();
for (auto i : algo_types) {
@@ -2816,8 +2939,17 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque,
batch_var_opaque, epsilon, saved_mean->opaque(),
saved_inv_var->opaque());
+#if CUDNN_VERSION < 5000
+ CHECK(inv_var_to_var);
+ inv_var_to_var();
+#endif
} else {
+#if CUDNN_VERSION < 5000
+ CHECK(var_to_inv_var);
+ const void* maybe_inv_var = var_to_inv_var().opaque();
+#else
const void* maybe_inv_var = estimated_variance.opaque();
+#endif
status = cudnnBatchNormalizationForwardInference(
cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
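For the CUDNN_VERSION < 5000 branches in the hunk above: var_to_inv_var and inv_var_to_var bridge the fact that older cuDNN batch norm traffics in inverse standard deviation where callers hold plain variance. A sketch of the conversions such callbacks would perform (the exact formulas are an assumption):

#include <cmath>

float VarToInvVar(float variance, float epsilon) {
  return 1.0f / std::sqrt(variance + epsilon);
}

float InvVarToVar(float inv_var, float epsilon) {
  float stddev = 1.0f / inv_var;     // invert back to a standard deviation
  return stddev * stddev - epsilon;  // square and drop the epsilon term
}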
@@ -3027,6 +3159,11 @@ bool CudnnSupport::DoFusedConvolve(
DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION < 6000
+ LOG(WARNING) << "cudnnConvolutionBiasActivationForward() is only "
+ "supported for cuDNN version >= 6";
+ return false;
+#else
int cc_major, cc_minor;
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
&cc_minor);
@@ -3042,6 +3179,7 @@ bool CudnnSupport::DoFusedConvolve(
side_input_scale, bias_descriptor, biases, activation_mode,
output_descriptor, output_data, scratch_allocator, algorithm_config,
output_profile_result);
+#endif
}
namespace {
@@ -3290,8 +3428,13 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
timer->Start(AsCUDAStream(stream));
}
+#if CUDNN_VERSION >= 5000
auto status =
cudnnConvolutionBackwardData(cudnn.handle(),
+#else
+ auto status =
+ cudnnConvolutionBackwardData_v3(cudnn.handle(),
+#endif
/*alpha=*/alpha,
/*wDesc=*/filter.handle(),
/*w=*/filter_data.opaque(),
@@ -3554,8 +3697,13 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
timer->Start(AsCUDAStream(stream));
}
+#if CUDNN_VERSION >= 5000
auto status = cudnnConvolutionBackwardFilter(
cudnn.handle(),
+#else
+ auto status = cudnnConvolutionBackwardFilter_v3(
+ cudnn.handle(),
+#endif
/*alpha=*/alpha,
/*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(),
@@ -3868,7 +4016,11 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
auto cudnn = cudnn_->GetHandle(parent_, stream);
+#if CUDNN_VERSION >= 5000
auto status = cudnnAddTensor(
+#else
+ auto status = cudnnAddTensor_v3(
+#endif
cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta,
input_descriptor.handle(), output_data->opaque());
@@ -3886,8 +4038,37 @@ bool CudnnSupport::DoActivate(Stream* stream,
const DeviceMemory<float>& input_data,
DeviceMemory<float>* output_data,
uint64 options) {
+#if CUDNN_VERSION >= 5000
ScopedActivationDescriptor activation_desc(
activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
+#else
+ cudnnActivationMode_t mode;
+ switch (activation_mode) {
+ case dnn::ActivationMode::kRelu6:
+ // TODO(leary) should probably do a post-pass to clip at 6?
+ LOG(WARNING) << "user requested Relu6, but providing Relu instead";
+ mode = CUDNN_ACTIVATION_RELU;
+ break;
+ case dnn::ActivationMode::kReluX:
+ // TODO(broune) should probably do a post-pass to clip at X?
+ LOG(WARNING) << "user requested ReluX, but providing Relu instead";
+ mode = CUDNN_ACTIVATION_RELU;
+ break;
+ case dnn::ActivationMode::kRelu:
+ mode = CUDNN_ACTIVATION_RELU;
+ break;
+ case dnn::ActivationMode::kSigmoid:
+ mode = CUDNN_ACTIVATION_SIGMOID;
+ break;
+ case dnn::ActivationMode::kTanh:
+ mode = CUDNN_ACTIVATION_TANH;
+ break;
+ default:
+ LOG(ERROR) << "unrecognized activation mode: "
+ << static_cast<int>(activation_mode);
+ return false;
+ }
+#endif
ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
// Alpha is the input scaling factor.
@@ -3896,9 +4077,15 @@ bool CudnnSupport::DoActivate(Stream* stream,
float beta = 0.0;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnActivationForward(
- cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
- input_data.opaque(), &beta, input_nd.handle(), output_data->opaque());
+ auto status =
+ cudnnActivationForward(cudnn.handle(),
+#if CUDNN_VERSION >= 5000
+ activation_desc.handle(),
+#else
+ mode,
+#endif
+ &alpha, input_nd.handle(), input_data.opaque(),
+ &beta, input_nd.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "stream " << stream
<< " could not enqueue activation: " << ToString(status);