aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/gpus
diff options
context:
space:
mode:
authorGravatar Guangda Lai <laigd@google.com>2018-01-25 23:59:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 00:02:44 -0800
commit76f6938bafeb81a4ca41b8dac2b9c83e1286fa95 (patch)
tree75dd6285db559b19d7a0449065973780a748ef7c /third_party/gpus
parentffda6079ed619df8fd3edb4db71ffc7d005c2430 (diff)
Set up TensorRT configurations for external use, and add a test.
PiperOrigin-RevId: 183347199
Diffstat (limited to 'third_party/gpus')
-rw-r--r--third_party/gpus/cuda_configure.bzl104
1 files changed, 57 insertions, 47 deletions
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 2727fa5efe..8e1dd8a54f 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -236,7 +236,7 @@ def _cudnn_install_basedir(repository_ctx):
return cudnn_install_path
-def _matches_version(environ_version, detected_version):
+def matches_version(environ_version, detected_version):
"""Checks whether the user-specified version matches the detected version.
This function performs a weak matching so that if the user specifies only the
@@ -317,7 +317,7 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
environ_version = ""
if _TF_CUDA_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
- if environ_version and not _matches_version(environ_version, full_version):
+ if environ_version and not matches_version(environ_version, full_version):
auto_configure_fail(
("CUDA version detected from nvcc (%s) does not match " +
"TF_CUDA_VERSION (%s)") % (full_version, environ_version))
@@ -338,35 +338,49 @@ _DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
-def _find_cuda_define(repository_ctx, cudnn_header_dir, define):
- """Returns the value of a #define in cudnn.h
+def find_cuda_define(repository_ctx, header_dir, header_file, define):
+ """Returns the value of a #define in a header file.
- Greps through cudnn.h and returns the value of the specified #define. If the
- #define is not found, then raise an error.
+ Greps through a header file and returns the value of the specified #define.
+ If the #define is not found, then raise an error.
Args:
repository_ctx: The repository context.
- cudnn_header_dir: The directory containing the cuDNN header.
+ header_dir: The directory containing the header file.
+ header_file: The header file name.
define: The #define to search for.
Returns:
- The value of the #define found in cudnn.h.
+ The value of the #define found in the header.
"""
- # Confirm location of cudnn.h and grep for the line defining CUDNN_MAJOR.
- cudnn_h_path = repository_ctx.path("%s/cudnn.h" % cudnn_header_dir)
- if not cudnn_h_path.exists:
- auto_configure_fail("Cannot find cudnn.h at %s" % str(cudnn_h_path))
- result = repository_ctx.execute(["grep", "--color=never", "-E", define, str(cudnn_h_path)])
+ # Confirm location of the header and grep for the line defining the macro.
+ h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
+ if not h_path.exists:
+ auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
+ result = repository_ctx.execute(
+ # Grep one more lines as some #defines are splitted into two lines.
+ ["grep", "--color=never", "-A1", "-E", define, str(h_path)])
if result.stderr:
- auto_configure_fail("Error reading %s: %s" %
- (result.stderr, str(cudnn_h_path)))
+ auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
- # Parse the cuDNN major version from the line defining CUDNN_MAJOR
- lines = result.stdout.splitlines()
- if len(lines) == 0 or lines[0].find(define) == -1:
+ # Parse the version from the line defining the macro.
+ if result.stdout.find(define) == -1:
auto_configure_fail("Cannot find line containing '%s' in %s" %
- (define, str(cudnn_h_path)))
- return lines[0].replace(define, "").strip()
+ (define, h_path))
+ version = result.stdout
+ # Remove the new line and '\' character if any.
+ version = version.replace("\\", " ")
+ version = version.replace("\n", " ")
+ version = version.replace(define, "").lstrip()
+ # Remove the code after the version number.
+ version_end = version.find(" ")
+ if version_end != -1:
+ if version_end == 0:
+ auto_configure_fail(
+ "Cannot extract the version from line containing '%s' in %s" %
+ (define, str(h_path)))
+ version = version[:version_end].strip()
+ return version
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
@@ -382,12 +396,12 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
"""
cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
cudnn_install_basedir)
- major_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_MAJOR)
- minor_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_MINOR)
- patch_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_PATCHLEVEL)
+ major_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR)
+ minor_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MINOR)
+ patch_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_PATCHLEVEL)
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
# Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
@@ -395,7 +409,7 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
environ_version = ""
if _TF_CUDNN_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
- if environ_version and not _matches_version(environ_version, full_version):
+ if environ_version and not matches_version(environ_version, full_version):
cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
cudnn_install_basedir)
auto_configure_fail(
@@ -427,7 +441,7 @@ def _compute_capabilities(repository_ctx):
return capabilities
-def _cpu_value(repository_ctx):
+def get_cpu_value(repository_ctx):
"""Returns the name of the host operating system.
Args:
@@ -447,7 +461,7 @@ def _cpu_value(repository_ctx):
def _is_windows(repository_ctx):
"""Returns true if the host operating system is windows."""
- return _cpu_value(repository_ctx) == "Windows"
+ return get_cpu_value(repository_ctx) == "Windows"
def _lib_name(lib, cpu_value, version="", static=False):
"""Constructs the platform-specific name of a library.
@@ -582,11 +596,8 @@ def _find_libs(repository_ctx, cuda_config):
cuda_config: The CUDA config as returned by _get_cuda_config
Returns:
- Map of library names to structs of filename and path as returned by
- _find_cuda_lib and _find_cupti_lib.
+ Map of library names to structs of filename and path.
"""
- cudnn_version = cuda_config.cudnn_version
- cudnn_ext = ".%s" % cudnn_version if cudnn_version else ""
cpu_value = cuda_config.cpu_value
return {
"cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
@@ -611,7 +622,7 @@ def _find_libs(repository_ctx, cuda_config):
"cudnn": _find_cuda_lib(
"cudnn", repository_ctx, cpu_value, cuda_config.cudnn_install_basedir,
cuda_config.cudnn_version),
- "cupti": _find_cupti_lib(repository_ctx, cuda_config),
+ "cupti": _find_cupti_lib(repository_ctx, cuda_config)
}
@@ -654,7 +665,7 @@ def _get_cuda_config(repository_ctx):
compute_capabilities: A list of the system's CUDA compute capabilities.
cpu_value: The name of the host operating system.
"""
- cpu_value = _cpu_value(repository_ctx)
+ cpu_value = get_cpu_value(repository_ctx)
cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
@@ -712,13 +723,13 @@ error_gpu_disabled()
def _create_dummy_repository(repository_ctx):
- cpu_value = _cpu_value(repository_ctx)
+ cpu_value = get_cpu_value(repository_ctx)
# Set up BUILD file for cuda/.
_tpl(repository_ctx, "cuda:build_defs.bzl",
{
"%{cuda_is_configured}": "False",
- "%{cuda_extra_copts}": "[]"
+ "%{cuda_extra_copts}": "[]",
})
_tpl(repository_ctx, "cuda:BUILD",
{
@@ -805,8 +816,8 @@ def _norm_path(path):
return path
-def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
- src_files = [], dest_files = []):
+def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
+ src_files = [], dest_files = []):
"""Returns a genrule to symlink(or copy if on Windows) a set of files.
If src_dir is passed, files will be read from the given directory; otherwise
@@ -913,11 +924,11 @@ def _create_local_cuda_repository(repository_ctx):
# cuda_toolkit_path
cuda_toolkit_path = cuda_config.cuda_toolkit_path
cuda_include_path = cuda_toolkit_path + "/include"
- genrules = [_symlink_genrule_for_dir(repository_ctx,
+ genrules = [symlink_genrule_for_dir(repository_ctx,
cuda_include_path, "cuda/include", "cuda-include")]
- genrules.append(_symlink_genrule_for_dir(repository_ctx,
+ genrules.append(symlink_genrule_for_dir(repository_ctx,
cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm"))
- genrules.append(_symlink_genrule_for_dir(repository_ctx,
+ genrules.append(symlink_genrule_for_dir(repository_ctx,
cuda_toolkit_path + "/extras/CUPTI/include",
"cuda/extras/CUPTI/include", "cuda-extras"))
@@ -927,15 +938,15 @@ def _create_local_cuda_repository(repository_ctx):
for lib in cuda_libs.values():
cuda_lib_src.append(lib.path)
cuda_lib_dest.append("cuda/lib/" + lib.file_name)
- genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
- cuda_lib_src, cuda_lib_dest))
+ genrules.append(symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
+ cuda_lib_src, cuda_lib_dest))
- # Set up the symbolic links for cudnn if cudnn was was not installed to
+ # Set up the symbolic links for cudnn if cndnn was not installed to
# CUDA_TOOLKIT_PATH.
included_files = _read_dir(repository_ctx, cuda_include_path).replace(
cuda_include_path, '').splitlines()
if '/cudnn.h' not in included_files:
- genrules.append(_symlink_genrule_for_dir(repository_ctx, None,
+ genrules.append(symlink_genrule_for_dir(repository_ctx, None,
"cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"],
["cudnn.h"]))
else:
@@ -952,7 +963,6 @@ def _create_local_cuda_repository(repository_ctx):
"%{cuda_is_configured}": "True",
"%{cuda_extra_copts}": _compute_cuda_extra_copts(
repository_ctx, cuda_config.compute_capabilities),
-
})
_tpl(repository_ctx, "cuda:BUILD",
{