aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/gpus
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/gpus')
-rw-r--r--third_party/gpus/cuda_configure.bzl76
1 files changed, 63 insertions, 13 deletions
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index c6deae05b8..61932a8e6d 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -5,7 +5,7 @@
* `TF_NEED_CUDA`: Whether to enable building with CUDA.
* `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
- * `TF_CUDA_CLANG`: Wheter to use clang as a cuda compiler.
+ * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
* `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
both host and device code compilation if TF_CUDA_CLANG is 1.
* `CUDA_TOOLKIT_PATH`: The path to the CUDA toolkit. Default is
@@ -41,8 +41,8 @@ def find_cc(repository_ctx):
"""Find the C++ compiler."""
# On Windows, we use Bazel's MSVC CROSSTOOL for GPU build
# Return a dummy value for GCC detection here to avoid error
- if _cpu_value(repository_ctx) == "Windows":
- return "/use/--config x64_windows_msvc/instead"
+ if _is_windows(repository_ctx):
+ return "/use/--config=win-cuda --cpu=x64_windows_msvc/instead"
if _use_cuda_clang(repository_ctx):
target_cc_name = "clang"
@@ -57,7 +57,7 @@ def find_cc(repository_ctx):
if cc_name_from_env:
cc_name = cc_name_from_env
if cc_name.startswith("/"):
- # Absolute path, maybe we should make this suported by our which function.
+ # Absolute path, maybe we should make this supported by our which function.
return cc_name
cc = repository_ctx.which(cc_name)
if cc == None:
@@ -122,10 +122,10 @@ def get_cxx_inc_directories(repository_ctx, cc):
def auto_configure_fail(msg):
- """Output failure message when auto configuration fails."""
+ """Output failure message when cuda configuration fails."""
red = "\033[0;31m"
no_color = "\033[0m"
- fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
+ fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
# END cc_configure common functions (see TODO above).
@@ -421,6 +421,10 @@ def _cpu_value(repository_ctx):
return result.stdout.strip()
+def _is_windows(repository_ctx):
+ """Returns true if the host operating system is windows."""
+ return _cpu_value(repository_ctx) == "Windows"
+
def _lib_name(lib, cpu_value, version="", static=False):
"""Constructs the platform-specific name of a library.
@@ -769,14 +773,48 @@ def _create_dummy_repository(repository_ctx):
repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
+def _execute(repository_ctx, cmdline, error_msg=None, error_details=None,
+ empty_stdout_fine=False):
+ """Executes an arbitrary shell command.
+
+ Args:
+ repository_ctx: the repository_ctx object
+ cmdline: list of strings, the command to execute
+ error_msg: string, a summary of the error if the command fails
+ error_details: string, details about the error or steps to fix it
+ empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
+ it's an error
+ Return:
+ the result of repository_ctx.execute(cmdline)
+ """
+ result = repository_ctx.execute(cmdline)
+ if result.stderr or not (empty_stdout_fine or result.stdout):
+ auto_configure_fail(
+ "\n".join([
+ error_msg.strip() if error_msg else "Repository command failed",
+ result.stderr.strip(),
+ error_details if error_details else ""]))
+ return result
+
+
+def _norm_path(path):
+ """Returns a path with '/' and remove the trailing slash."""
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ return path
+
+
def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
src_files = [], dest_files = []):
- """Returns a genrule to symlink a set of files.
+ """Returns a genrule to symlink(or copy if on Windows) a set of files.
If src_dir is passed, files will be read from the given directory; otherwise
we assume files are in src_files and dest_files
"""
if src_dir != None:
+ src_dir = _norm_path(src_dir)
+ dest_dir = _norm_path(dest_dir)
files = _read_dir(repository_ctx, src_dir)
# Create a list with the src_dir stripped to use for outputs.
dest_files = files.replace(src_dir, '').splitlines()
@@ -787,8 +825,10 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
if dest_files[i] != "":
# If we have only one file to link we do not want to use the dest_dir, as
# $(@D) will include the full path to the file.
- dest = ' $(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else ' $(@D)/' + dest_files[i]
- command.append('ln -s ' + src_files[i] + dest)
+ dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i]
+ # On Windows, symlink is not supported, so we just copy all the files.
+ cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s'
+ command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest))
outs.append(' "' + dest_dir + dest_files[i] + '",')
genrule = _genrule(src_dir, genrule_name, " && ".join(command),
"\n".join(outs))
@@ -821,10 +861,20 @@ def _read_dir(repository_ctx, src_dir):
symlinks. The returned string contains the full path of all files
separated by line breaks.
"""
- find_result = repository_ctx.execute([
- "find", src_dir, "-follow", "-type", "f"
- ])
- return find_result.stdout
+ if _is_windows(repository_ctx):
+ src_dir = src_dir.replace("/", "\\")
+ find_result = _execute(
+ repository_ctx, ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
+ empty_stdout_fine=True)
+ # src_files will be used in genrule.outs where the paths must
+ # use forward slashes.
+ result = find_result.stdout.replace("\\", "/")
+ else:
+ find_result = _execute(
+ repository_ctx, ["find", src_dir, "-follow", "-type", "f"],
+ empty_stdout_fine=True)
+ result = find_result.stdout
+ return result
def _use_cuda_clang(repository_ctx):