diff options
Diffstat (limited to 'third_party/gpus')
-rw-r--r-- | third_party/gpus/cuda_configure.bzl | 76 |
1 files changed, 63 insertions, 13 deletions
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index c6deae05b8..61932a8e6d 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -5,7 +5,7 @@ * `TF_NEED_CUDA`: Whether to enable building with CUDA. * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path - * `TF_CUDA_CLANG`: Wheter to use clang as a cuda compiler. + * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler. * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for both host and device code compilation if TF_CUDA_CLANG is 1. * `CUDA_TOOLKIT_PATH`: The path to the CUDA toolkit. Default is @@ -41,8 +41,8 @@ def find_cc(repository_ctx): """Find the C++ compiler.""" # On Windows, we use Bazel's MSVC CROSSTOOL for GPU build # Return a dummy value for GCC detection here to avoid error - if _cpu_value(repository_ctx) == "Windows": - return "/use/--config x64_windows_msvc/instead" + if _is_windows(repository_ctx): + return "/use/--config=win-cuda --cpu=x64_windows_msvc/instead" if _use_cuda_clang(repository_ctx): target_cc_name = "clang" @@ -57,7 +57,7 @@ def find_cc(repository_ctx): if cc_name_from_env: cc_name = cc_name_from_env if cc_name.startswith("/"): - # Absolute path, maybe we should make this suported by our which function. + # Absolute path, maybe we should make this supported by our which function. return cc_name cc = repository_ctx.which(cc_name) if cc == None: @@ -122,10 +122,10 @@ def get_cxx_inc_directories(repository_ctx, cc): def auto_configure_fail(msg): - """Output failure message when auto configuration fails.""" + """Output failure message when cuda configuration fails.""" red = "\033[0;31m" no_color = "\033[0m" - fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg)) + fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg)) # END cc_configure common functions (see TODO above). @@ -421,6 +421,10 @@ def _cpu_value(repository_ctx): return result.stdout.strip() +def _is_windows(repository_ctx): + """Returns true if the host operating system is windows.""" + return _cpu_value(repository_ctx) == "Windows" + def _lib_name(lib, cpu_value, version="", static=False): """Constructs the platform-specific name of a library. @@ -769,14 +773,48 @@ def _create_dummy_repository(repository_ctx): repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) +def _execute(repository_ctx, cmdline, error_msg=None, error_details=None, + empty_stdout_fine=False): + """Executes an arbitrary shell command. + + Args: + repository_ctx: the repository_ctx object + cmdline: list of strings, the command to execute + error_msg: string, a summary of the error if the command fails + error_details: string, details about the error or steps to fix it + empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise + it's an error + Return: + the result of repository_ctx.execute(cmdline) + """ + result = repository_ctx.execute(cmdline) + if result.stderr or not (empty_stdout_fine or result.stdout): + auto_configure_fail( + "\n".join([ + error_msg.strip() if error_msg else "Repository command failed", + result.stderr.strip(), + error_details if error_details else ""])) + return result + + +def _norm_path(path): + """Returns a path with '/' and remove the trailing slash.""" + path = path.replace("\\", "/") + if path[-1] == "/": + path = path[:-1] + return path + + def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, src_files = [], dest_files = []): - """Returns a genrule to symlink a set of files. + """Returns a genrule to symlink(or copy if on Windows) a set of files. If src_dir is passed, files will be read from the given directory; otherwise we assume files are in src_files and dest_files """ if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) files = _read_dir(repository_ctx, src_dir) # Create a list with the src_dir stripped to use for outputs. dest_files = files.replace(src_dir, '').splitlines() @@ -787,8 +825,10 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, if dest_files[i] != "": # If we have only one file to link we do not want to use the dest_dir, as # $(@D) will include the full path to the file. - dest = ' $(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else ' $(@D)/' + dest_files[i] - command.append('ln -s ' + src_files[i] + dest) + dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i] + # On Windows, symlink is not supported, so we just copy all the files. + cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s' + command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest)) outs.append(' "' + dest_dir + dest_files[i] + '",') genrule = _genrule(src_dir, genrule_name, " && ".join(command), "\n".join(outs)) @@ -821,10 +861,20 @@ def _read_dir(repository_ctx, src_dir): symlinks. The returned string contains the full path of all files separated by line breaks. """ - find_result = repository_ctx.execute([ - "find", src_dir, "-follow", "-type", "f" - ]) - return find_result.stdout + if _is_windows(repository_ctx): + src_dir = src_dir.replace("/", "\\") + find_result = _execute( + repository_ctx, ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"], + empty_stdout_fine=True) + # src_files will be used in genrule.outs where the paths must + # use forward slashes. + result = find_result.stdout.replace("\\", "/") + else: + find_result = _execute( + repository_ctx, ["find", src_dir, "-follow", "-type", "f"], + empty_stdout_fine=True) + result = find_result.stdout + return result def _use_cuda_clang(repository_ctx): |