aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_dnn.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.cc')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc87
1 files changed, 60 insertions, 27 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index ab5e6590e0..1aea0485fd 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include <functional>
#include <memory>
+#include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@@ -27,6 +29,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
+#include "tensorflow/stream_executor/cuda/cudnn_version.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
@@ -55,15 +58,6 @@ NarrowT CheckedNarrowing(const WideT& wide) {
return narrow;
}
-// Returns the "Compatibility" version number from the CuDNN version number.
-// This is the number that tries to indicate ABI compatibility.
-//
-// For example, if cudnn_version is 5107, the compatibility version
-// number will be 5100.
-size_t cudnnCompatibilityVersion(size_t cudnn_version) {
- return (cudnn_version / 100) * 100;
-}
-
} // namespace
namespace perftools {
@@ -109,6 +103,22 @@ string ToString(cudnnStatus_t status) {
}
}
+#if CUDNN_VERSION >= 6000
+string ToString(libraryPropertyType type) {
+ switch (type) {
+ case MAJOR_VERSION:
+ return "MAJOR_VERSION";
+ case MINOR_VERSION:
+ return "MINOR_VERSION";
+ case PATCH_LEVEL:
+ return "PATCH_LEVEL";
+ default:
+ return absl::StrCat(
+ "<unknown libraryPropertyType: ", static_cast<int>(type), ">");
+ }
+}
+#endif
+
template <typename T>
cudnnDataType_t GetCudnnDataType();
@@ -360,6 +370,34 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
}
}
+#if CUDNN_VERSION >= 6000
+port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
+ cudnnStatus_t status = cudnnGetProperty(type, value);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ const string error =
+ absl::StrCat("cudnnGetProperty failed for type: ", ToString(type),
+ " with status: ", ToString(status));
+ LOG(ERROR) << error;
+ return port::Status{port::error::INTERNAL, error};
+ }
+ return port::Status::OK();
+}
+#endif
+
+port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
+#if CUDNN_VERSION >= 6000
+ TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version));
+ TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version));
+ TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level));
+#else
+ size_t loaded_version = ::cudnnGetVersion();
+ version->major_version = loaded_version / 1000;
+ version->minor_version = (loaded_version / 100) % 10;
+ version->patch_level = loaded_version % 100;
+#endif
+ return port::Status::OK();
+}
+
} // namespace
CudnnSupport::CudnnSupport(CUDAExecutor* parent)
@@ -376,24 +414,19 @@ port::Status CudnnSupport::Init() {
auto status = wrap::cudnnCreate(
parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_));
if (status == CUDNN_STATUS_SUCCESS) {
- // Check whether loaded version of CuDNN matches what the source
- // was built with.
- size_t loaded_version = ::cudnnGetVersion();
- size_t loaded_compat_version = cudnnCompatibilityVersion(loaded_version);
- size_t compiled_compat_version = cudnnCompatibilityVersion(CUDNN_VERSION);
- bool library_loaded_matches_source =
- (loaded_compat_version == compiled_compat_version);
- if (!library_loaded_matches_source) {
- const string error =
- port::StrCat("Loaded runtime CuDNN library: ", loaded_version,
- " (compatibility version ", loaded_compat_version,
- ") but source was compiled with ", CUDNN_VERSION,
- " (compatibility version ", compiled_compat_version,
- "). If using a binary install, upgrade your CuDNN "
- "library to match. If building from sources, "
- "make sure the library loaded at runtime matches a "
- "compatible version specified during compile "
- "configuration.");
+ CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
+
+ CudnnVersion loaded_version;
+ TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
+ if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
+ const tensorflow::string error = absl::StrCat(
+ "Loaded runtime CuDNN library: ", loaded_version.ToString(),
+ " but source was compiled with: ", source_version.ToString(),
+ ". CuDNN library major and minor version needs to match or have "
+ "higher minor version in case of CuDNN 7.0 or later version. If "
+ "using a binary install, upgrade your CuDNN library. If building "
+ "from sources, make sure the library loaded at runtime is compatible "
+ "with the version specified during compile configuration.");
LOG(ERROR) << error;
return port::Status{port::error::INTERNAL, error};
}