aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/gpus/cuda_configure.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/gpus/cuda_configure.bzl')
-rw-r--r--third_party/gpus/cuda_configure.bzl47
1 files changed, 37 insertions, 10 deletions
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 1e47bfac78..ddd376cddb 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -86,9 +86,33 @@ def get_cxx_inc_directories(repository_ctx, cc):
return [repository_ctx.path(_cxx_inc_convert(p))
for p in inc_dirs.split("\n")]
+def auto_configure_fail(msg):
+ """Output failure message when auto configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
# END cc_configure common functions (see TODO above).
+def _gcc_host_compiler_includes(repository_ctx, cc):
+ """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
+
+ Args:
+ repository_ctx: The repository context.
+ cc: The path to the gcc host compiler.
+
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the gcc
+ host compiler include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
+ inc_entries = []
+ for inc_dir in inc_dirs:
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+ return "\n".join(inc_entries)
+
+
def _enable_cuda(repository_ctx):
if "TF_NEED_CUDA" in repository_ctx.os.environ:
enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
@@ -102,7 +126,7 @@ def _cuda_toolkit_path(repository_ctx):
if _CUDA_TOOLKIT_PATH in repository_ctx.os.environ:
cuda_toolkit_path = repository_ctx.os.environ[_CUDA_TOOLKIT_PATH].strip()
if not repository_ctx.path(cuda_toolkit_path).exists:
- fail("Cannot find cuda toolkit path.")
+ auto_configure_fail("Cannot find cuda toolkit path.")
return cuda_toolkit_path
@@ -112,7 +136,7 @@ def _cudnn_install_basedir(repository_ctx):
if _CUDNN_INSTALL_PATH in repository_ctx.os.environ:
cudnn_install_path = repository_ctx.os.environ[_CUDNN_INSTALL_PATH].strip()
if not repository_ctx.path(cudnn_install_path).exists:
- fail("Cannot find cudnn install path.")
+ auto_configure_fail("Cannot find cudnn install path.")
return cudnn_install_path
@@ -144,7 +168,7 @@ def _compute_capabilities(repository_ctx):
# if re.match("[0-9]+.[0-9]+", capability) == None:
parts = capability.split(".")
if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
- fail("Invalid compute capability: %s" % capability)
+ auto_configure_fail("Invalid compute capability: %s" % capability)
return capabilities
@@ -186,7 +210,7 @@ def _cuda_symlink_files(cpu_value, cuda_version, cudnn_version):
cuda_fft_lib = "lib/libcufft%s.dylib" % cuda_ext,
cuda_cupti_lib = "extras/CUPTI/lib/libcupti%s.dylib" % cuda_ext)
else:
- fail("Not supported CPU value %s" % cpu_value)
+ auto_configure_fail("Not supported CPU value %s" % cpu_value)
def _check_lib(repository_ctx, cuda_toolkit_path, cuda_lib):
@@ -199,7 +223,7 @@ def _check_lib(repository_ctx, cuda_toolkit_path, cuda_lib):
"""
lib_path = cuda_toolkit_path + "/" + cuda_lib
if not repository_ctx.path(lib_path).exists:
- fail("Cannot find %s" % lib_path)
+ auto_configure_fail("Cannot find %s" % lib_path)
def _check_dir(repository_ctx, directory):
@@ -210,7 +234,7 @@ def _check_dir(repository_ctx, directory):
directory: The directory to check the existence of.
"""
if not repository_ctx.path(directory).exists:
- fail("Cannot find dir: %s" % directory)
+ auto_configure_fail("Cannot find dir: %s" % directory)
def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
@@ -230,7 +254,7 @@ def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
return cudnn_install_basedir + "/include"
if repository_ctx.path("/usr/include/cudnn.h").exists:
return "/usr/include"
- fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
+ auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
def _find_cudnn_lib_path(repository_ctx, cudnn_install_basedir, symlink_files):
@@ -252,7 +276,7 @@ def _find_cudnn_lib_path(repository_ctx, cudnn_install_basedir, symlink_files):
if repository_ctx.path(alt_lib_dir).exists:
return alt_lib_dir
- fail("Cannot find %s or %s under %s" %
+ auto_configure_fail("Cannot find %s or %s under %s" %
(symlink_files.cuda_dnn_lib, symlink_files.cuda_dnn_lib_alt,
cudnn_install_basedir))
@@ -380,15 +404,18 @@ def _create_cuda_repository(repository_ctx):
# Set up crosstool/
_file(repository_ctx, "crosstool:BUILD")
+ cc = find_cc(repository_ctx)
+ gcc_host_compiler_includes = _gcc_host_compiler_includes(repository_ctx, cc)
_tpl(repository_ctx, "crosstool:CROSSTOOL",
{
"%{cuda_version}": ("-%s" % cuda_version) if cuda_version else "",
+ "%{gcc_host_compiler_includes}": gcc_host_compiler_includes,
})
_tpl(repository_ctx,
"crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
{
- "%{cpu_compiler}": str(find_cc(repository_ctx)),
- "%{gcc_host_compiler_path}": str(find_cc(repository_ctx)),
+ "%{cpu_compiler}": str(cc),
+ "%{gcc_host_compiler_path}": str(cc),
"%{cuda_compute_capabilities}": ", ".join(
["\"%s\"" % c for c in compute_capabilities]),
})