From 76f6938bafeb81a4ca41b8dac2b9c83e1286fa95 Mon Sep 17 00:00:00 2001 From: Guangda Lai Date: Thu, 25 Jan 2018 23:59:19 -0800 Subject: Set up TensorRT configurations for external use, and add a test. PiperOrigin-RevId: 183347199 --- third_party/gpus/cuda_configure.bzl | 104 ++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 47 deletions(-) (limited to 'third_party/gpus') diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 2727fa5efe..8e1dd8a54f 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -236,7 +236,7 @@ def _cudnn_install_basedir(repository_ctx): return cudnn_install_path -def _matches_version(environ_version, detected_version): +def matches_version(environ_version, detected_version): """Checks whether the user-specified version matches the detected version. This function performs a weak matching so that if the user specifies only the @@ -317,7 +317,7 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value): environ_version = "" if _TF_CUDA_VERSION in repository_ctx.os.environ: environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip() - if environ_version and not _matches_version(environ_version, full_version): + if environ_version and not matches_version(environ_version, full_version): auto_configure_fail( ("CUDA version detected from nvcc (%s) does not match " + "TF_CUDA_VERSION (%s)") % (full_version, environ_version)) @@ -338,35 +338,49 @@ _DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR" _DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL" -def _find_cuda_define(repository_ctx, cudnn_header_dir, define): - """Returns the value of a #define in cudnn.h +def find_cuda_define(repository_ctx, header_dir, header_file, define): + """Returns the value of a #define in a header file. - Greps through cudnn.h and returns the value of the specified #define. If the - #define is not found, then raise an error. 
+ Greps through a header file and returns the value of the specified #define. + If the #define is not found, then raise an error. Args: repository_ctx: The repository context. - cudnn_header_dir: The directory containing the cuDNN header. + header_dir: The directory containing the header file. + header_file: The header file name. define: The #define to search for. Returns: - The value of the #define found in cudnn.h. + The value of the #define found in the header. """ - # Confirm location of cudnn.h and grep for the line defining CUDNN_MAJOR. - cudnn_h_path = repository_ctx.path("%s/cudnn.h" % cudnn_header_dir) - if not cudnn_h_path.exists: - auto_configure_fail("Cannot find cudnn.h at %s" % str(cudnn_h_path)) - result = repository_ctx.execute(["grep", "--color=never", "-E", define, str(cudnn_h_path)]) + # Confirm location of the header and grep for the line defining the macro. + h_path = repository_ctx.path("%s/%s" % (header_dir, header_file)) + if not h_path.exists: + auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path))) + result = repository_ctx.execute( + # Grep one extra line, as some #defines are split across two lines. + ["grep", "--color=never", "-A1", "-E", define, str(h_path)]) if result.stderr: - auto_configure_fail("Error reading %s: %s" % - (result.stderr, str(cudnn_h_path))) + auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr)) - # Parse the cuDNN major version from the line defining CUDNN_MAJOR - lines = result.stdout.splitlines() - if len(lines) == 0 or lines[0].find(define) == -1: + # Parse the version from the line defining the macro. + if result.stdout.find(define) == -1: auto_configure_fail("Cannot find line containing '%s' in %s" % - (define, str(cudnn_h_path))) - return lines[0].replace(define, "").strip() + (define, h_path)) + version = result.stdout + # Remove the new line and '\' character if any. 
+ version = version.replace("\\", " ") + version = version.replace("\n", " ") + version = version.replace(define, "").lstrip() + # Remove the code after the version number. + version_end = version.find(" ") + if version_end != -1: + if version_end == 0: + auto_configure_fail( + "Cannot extract the version from line containing '%s' in %s" % + (define, str(h_path))) + version = version[:version_end].strip() + return version def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value): @@ -382,12 +396,12 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value): """ cudnn_header_dir = _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir) - major_version = _find_cuda_define(repository_ctx, cudnn_header_dir, - _DEFINE_CUDNN_MAJOR) - minor_version = _find_cuda_define(repository_ctx, cudnn_header_dir, - _DEFINE_CUDNN_MINOR) - patch_version = _find_cuda_define(repository_ctx, cudnn_header_dir, - _DEFINE_CUDNN_PATCHLEVEL) + major_version = find_cuda_define( + repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR) + minor_version = find_cuda_define( + repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MINOR) + patch_version = find_cuda_define( + repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_PATCHLEVEL) full_version = "%s.%s.%s" % (major_version, minor_version, patch_version) # Check whether TF_CUDNN_VERSION was set by the user and fail if it does not @@ -395,7 +409,7 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value): environ_version = "" if _TF_CUDNN_VERSION in repository_ctx.os.environ: environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip() - if environ_version and not _matches_version(environ_version, full_version): + if environ_version and not matches_version(environ_version, full_version): cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" % cudnn_install_basedir) auto_configure_fail( @@ -427,7 +441,7 @@ def _compute_capabilities(repository_ctx): return capabilities 
-def _cpu_value(repository_ctx): +def get_cpu_value(repository_ctx): """Returns the name of the host operating system. Args: @@ -447,7 +461,7 @@ def _cpu_value(repository_ctx): def _is_windows(repository_ctx): """Returns true if the host operating system is windows.""" - return _cpu_value(repository_ctx) == "Windows" + return get_cpu_value(repository_ctx) == "Windows" def _lib_name(lib, cpu_value, version="", static=False): """Constructs the platform-specific name of a library. @@ -582,11 +596,8 @@ def _find_libs(repository_ctx, cuda_config): cuda_config: The CUDA config as returned by _get_cuda_config Returns: - Map of library names to structs of filename and path as returned by - _find_cuda_lib and _find_cupti_lib. + Map of library names to structs of filename and path. """ - cudnn_version = cuda_config.cudnn_version - cudnn_ext = ".%s" % cudnn_version if cudnn_version else "" cpu_value = cuda_config.cpu_value return { "cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path), @@ -611,7 +622,7 @@ def _find_libs(repository_ctx, cuda_config): "cudnn": _find_cuda_lib( "cudnn", repository_ctx, cpu_value, cuda_config.cudnn_install_basedir, cuda_config.cudnn_version), - "cupti": _find_cupti_lib(repository_ctx, cuda_config), + "cupti": _find_cupti_lib(repository_ctx, cuda_config) } @@ -654,7 +665,7 @@ def _get_cuda_config(repository_ctx): compute_capabilities: A list of the system's CUDA compute capabilities. cpu_value: The name of the host operating system. """ - cpu_value = _cpu_value(repository_ctx) + cpu_value = get_cpu_value(repository_ctx) cuda_toolkit_path = _cuda_toolkit_path(repository_ctx) cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value) cudnn_install_basedir = _cudnn_install_basedir(repository_ctx) @@ -712,13 +723,13 @@ error_gpu_disabled() def _create_dummy_repository(repository_ctx): - cpu_value = _cpu_value(repository_ctx) + cpu_value = get_cpu_value(repository_ctx) # Set up BUILD file for cuda/. 
_tpl(repository_ctx, "cuda:build_defs.bzl", { "%{cuda_is_configured}": "False", - "%{cuda_extra_copts}": "[]" + "%{cuda_extra_copts}": "[]", }) _tpl(repository_ctx, "cuda:BUILD", { @@ -805,8 +816,8 @@ def _norm_path(path): return path -def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, - src_files = [], dest_files = []): +def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, + src_files = [], dest_files = []): """Returns a genrule to symlink(or copy if on Windows) a set of files. If src_dir is passed, files will be read from the given directory; otherwise @@ -913,11 +924,11 @@ def _create_local_cuda_repository(repository_ctx): # cuda_toolkit_path cuda_toolkit_path = cuda_config.cuda_toolkit_path cuda_include_path = cuda_toolkit_path + "/include" - genrules = [_symlink_genrule_for_dir(repository_ctx, + genrules = [symlink_genrule_for_dir(repository_ctx, cuda_include_path, "cuda/include", "cuda-include")] - genrules.append(_symlink_genrule_for_dir(repository_ctx, + genrules.append(symlink_genrule_for_dir(repository_ctx, cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm")) - genrules.append(_symlink_genrule_for_dir(repository_ctx, + genrules.append(symlink_genrule_for_dir(repository_ctx, cuda_toolkit_path + "/extras/CUPTI/include", "cuda/extras/CUPTI/include", "cuda-extras")) @@ -927,15 +938,15 @@ def _create_local_cuda_repository(repository_ctx): for lib in cuda_libs.values(): cuda_lib_src.append(lib.path) cuda_lib_dest.append("cuda/lib/" + lib.file_name) - genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib", - cuda_lib_src, cuda_lib_dest)) + genrules.append(symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib", + cuda_lib_src, cuda_lib_dest)) - # Set up the symbolic links for cudnn if cudnn was was not installed to + # Set up the symbolic links for cudnn if cudnn was not installed to # CUDA_TOOLKIT_PATH. 
included_files = _read_dir(repository_ctx, cuda_include_path).replace( cuda_include_path, '').splitlines() if '/cudnn.h' not in included_files: - genrules.append(_symlink_genrule_for_dir(repository_ctx, None, + genrules.append(symlink_genrule_for_dir(repository_ctx, None, "cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"], ["cudnn.h"])) else: @@ -952,7 +963,6 @@ def _create_local_cuda_repository(repository_ctx): "%{cuda_is_configured}": "True", "%{cuda_extra_copts}": _compute_cuda_extra_copts( repository_ctx, cuda_config.compute_capabilities), - }) _tpl(repository_ctx, "cuda:BUILD", { -- cgit v1.2.3