diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 87 |
1 files changed, 60 insertions, 27 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index ab5e6590e0..1aea0485fd 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -18,7 +18,9 @@ limitations under the License. #include <functional> #include <memory> +#include "absl/strings/str_cat.h" #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/cuda/cuda_activation.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" @@ -27,6 +29,7 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/cuda/cuda_timer.h" +#include "tensorflow/stream_executor/cuda/cudnn_version.h" #include "tensorflow/stream_executor/dnn.h" #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/error.h" @@ -55,15 +58,6 @@ NarrowT CheckedNarrowing(const WideT& wide) { return narrow; } -// Returns the "Compatibility" version number from the CuDNN version number. -// This is the number that tries to indicate ABI compatibility. -// -// For example, if cudnn_version is 5107, the compatibility version -// number will be 5100. -size_t cudnnCompatibilityVersion(size_t cudnn_version) { - return (cudnn_version / 100) * 100; -} - } // namespace namespace perftools { @@ -109,6 +103,22 @@ string ToString(cudnnStatus_t status) { } } +#if CUDNN_VERSION >= 6000 +string ToString(libraryPropertyType type) { + switch (type) { + case MAJOR_VERSION: + return "MAJOR_VERSION"; + case MINOR_VERSION: + return "MINOR_VERSION"; + case PATCH_LEVEL: + return "PATCH_LEVEL"; + default: + return absl::StrCat( + "<unknown libraryPropertyType: ", static_cast<int>(type), ">"); + } +} +#endif + template <typename T> cudnnDataType_t GetCudnnDataType(); @@ -360,6 +370,34 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( } } +#if CUDNN_VERSION >= 6000 +port::Status GetCudnnProperty(libraryPropertyType type, int* value) { + cudnnStatus_t status = cudnnGetProperty(type, value); + if (status != CUDNN_STATUS_SUCCESS) { + const string error = + absl::StrCat("cudnnGetProperty failed for type: ", ToString(type), + " with status: ", ToString(status)); + LOG(ERROR) << error; + return port::Status{port::error::INTERNAL, error}; + } + return port::Status::OK(); +} +#endif + +port::Status GetLoadedCudnnVersion(CudnnVersion* version) { +#if CUDNN_VERSION >= 6000 + TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version)); + TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version)); + TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level)); +#else + size_t loaded_version = ::cudnnGetVersion(); + version->major_version = loaded_version / 1000; + version->minor_version = (loaded_version / 100) % 10; + version->patch_level = loaded_version % 100; +#endif + return port::Status::OK(); +} + } // namespace CudnnSupport::CudnnSupport(CUDAExecutor* parent) @@ -376,24 +414,19 @@ port::Status CudnnSupport::Init() { auto status = wrap::cudnnCreate( parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_)); if (status == CUDNN_STATUS_SUCCESS) { - // Check whether loaded version of CuDNN matches what the source - // was built with. - size_t loaded_version = ::cudnnGetVersion(); - size_t loaded_compat_version = cudnnCompatibilityVersion(loaded_version); - size_t compiled_compat_version = cudnnCompatibilityVersion(CUDNN_VERSION); - bool library_loaded_matches_source = - (loaded_compat_version == compiled_compat_version); - if (!library_loaded_matches_source) { - const string error = - port::StrCat("Loaded runtime CuDNN library: ", loaded_version, - " (compatibility version ", loaded_compat_version, - ") but source was compiled with ", CUDNN_VERSION, - " (compatibility version ", compiled_compat_version, - "). If using a binary install, upgrade your CuDNN " - "library to match. If building from sources, " - "make sure the library loaded at runtime matches a " - "compatible version specified during compile " - "configuration."); + CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); + + CudnnVersion loaded_version; + TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version)); + if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) { + const tensorflow::string error = absl::StrCat( + "Loaded runtime CuDNN library: ", loaded_version.ToString(), + " but source was compiled with: ", source_version.ToString(), + ". CuDNN library major and minor version needs to match or have " + "higher minor version in case of CuDNN 7.0 or later version. If " + "using a binary install, upgrade your CuDNN library. If building " + "from sources, make sure the library loaded at runtime is compatible " + "with the version specified during compile configuration."); LOG(ERROR) << error; return port::Status{port::error::INTERNAL, error}; } |