diff options
Diffstat (limited to 'third_party/gpus/rocm_configure.bzl')
-rw-r--r-- | third_party/gpus/rocm_configure.bzl | 1147 |
1 files changed, 634 insertions, 513 deletions
diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 9371e33f97..9108639b0b 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -27,334 +27,377 @@ _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" _DEFAULT_ROCM_AMDGPU_TARGETS = ["gfx803", "gfx900"] def find_cc(repository_ctx): - """Find the C++ compiler.""" - # Return a dummy value for GCC detection here to avoid error - target_cc_name = "gcc" - cc_path_envvar = _GCC_HOST_COMPILER_PATH - cc_name = target_cc_name - - if cc_path_envvar in repository_ctx.os.environ: - cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip() - if cc_name_from_env: - cc_name = cc_name_from_env - if cc_name.startswith("/"): - # Absolute path, maybe we should make this supported by our which function. - return cc_name - cc = repository_ctx.which(cc_name) - if cc == None: - fail(("Cannot find {}, either correct your path or set the {}" + - " environment variable").format(target_cc_name, cc_path_envvar)) - return cc + """Find the C++ compiler.""" + + # Return a dummy value for GCC detection here to avoid error + target_cc_name = "gcc" + cc_path_envvar = _GCC_HOST_COMPILER_PATH + cc_name = target_cc_name + + if cc_path_envvar in repository_ctx.os.environ: + cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip() + if cc_name_from_env: + cc_name = cc_name_from_env + if cc_name.startswith("/"): + # Absolute path, maybe we should make this supported by our which function. + return cc_name + cc = repository_ctx.which(cc_name) + if cc == None: + fail(("Cannot find {}, either correct your path or set the {}" + + " environment variable").format(target_cc_name, cc_path_envvar)) + return cc _INC_DIR_MARKER_BEGIN = "#include <...>" def _cxx_inc_convert(path): - """Convert path returned by cc -E xc++ in a complete path.""" - path = path.strip() - return path + """Convert path returned by cc -E xc++ in a complete path.""" + path = path.strip() + return path def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): - """Compute the list of default C or C++ include directories.""" - if lang_is_cpp: - lang = "c++" - else: - lang = "c" - # TODO: We pass -no-canonical-prefixes here to match the compiler flags, - # but in rocm_clang CROSSTOOL file that is a `feature` and we should - # handle the case when it's disabled and no flag is passed - result = repository_ctx.execute([cc, "-no-canonical-prefixes", - "-E", "-x" + lang, "-", "-v"]) - index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN) - if index1 == -1: - return [] - index1 = result.stderr.find("\n", index1) - if index1 == -1: - return [] - index2 = result.stderr.rfind("\n ") - if index2 == -1 or index2 < index1: - return [] - index2 = result.stderr.find("\n", index2 + 1) - if index2 == -1: - inc_dirs = result.stderr[index1 + 1:] - else: - inc_dirs = result.stderr[index1 + 1:index2].strip() - - return [str(repository_ctx.path(_cxx_inc_convert(p))) - for p in inc_dirs.split("\n")] + """Compute the list of default C or C++ include directories.""" + if lang_is_cpp: + lang = "c++" + else: + lang = "c" + + # TODO: We pass -no-canonical-prefixes here to match the compiler flags, + # but in rocm_clang CROSSTOOL file that is a `feature` and we should + # handle the case when it's disabled and no flag is passed + result = repository_ctx.execute([ + cc, + "-no-canonical-prefixes", + "-E", + "-x" + lang, + "-", + "-v", + ]) + index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN) + if index1 == -1: + return [] + index1 = result.stderr.find("\n", index1) + if index1 == -1: + return [] + index2 = result.stderr.rfind("\n ") + if index2 == -1 or index2 < index1: + return [] + index2 = result.stderr.find("\n", index2 + 1) + if index2 == -1: + inc_dirs = result.stderr[index1 + 1:] + else: + inc_dirs = result.stderr[index1 + 1:index2].strip() + + return [ + str(repository_ctx.path(_cxx_inc_convert(p))) + for p in inc_dirs.split("\n") + ] def get_cxx_inc_directories(repository_ctx, cc): - """Compute the list of default C and C++ include directories.""" - # For some reason `clang -xc` sometimes returns include paths that are - # different from the ones from `clang -xc++`. (Symlink and a dir) - # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True) - includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False) + """Compute the list of default C and C++ include directories.""" - includes_cpp_set = depset(includes_cpp) - return includes_cpp + [inc for inc in includes_c - if inc not in includes_cpp_set] + # For some reason `clang -xc` sometimes returns include paths that are + # different from the ones from `clang -xc++`. (Symlink and a dir) + # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists + includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True) + includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False) + + includes_cpp_set = depset(includes_cpp) + return includes_cpp + [ + inc + for inc in includes_c + if inc not in includes_cpp_set + ] def auto_configure_fail(msg): - """Output failure message when rocm configuration fails.""" - red = "\033[0;31m" - no_color = "\033[0m" - fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg)) + """Output failure message when rocm configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg)) + # END cc_configure common functions (see TODO above). def _host_compiler_includes(repository_ctx, cc): - """Generates the cxx_builtin_include_directory entries for gcc inc dirs. + """Generates the cxx_builtin_include_directory entries for gcc inc dirs. - Args: - repository_ctx: The repository context. - cc: The path to the gcc host compiler. + Args: + repository_ctx: The repository context. + cc: The path to the gcc host compiler. - Returns: - A string containing the cxx_builtin_include_directory for each of the gcc - host compiler include directories, which can be added to the CROSSTOOL - file. - """ - inc_dirs = get_cxx_inc_directories(repository_ctx, cc) + Returns: + A string containing the cxx_builtin_include_directory for each of the gcc + host compiler include directories, which can be added to the CROSSTOOL + file. + """ + inc_dirs = get_cxx_inc_directories(repository_ctx, cc) - # Add numpy headers - inc_dirs.append("/usr/lib/python2.7/dist-packages/numpy/core/include") + # Add numpy headers + inc_dirs.append("/usr/lib/python2.7/dist-packages/numpy/core/include") - entries = [] - for inc_dir in inc_dirs: - entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir) + entries = [] + for inc_dir in inc_dirs: + entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir) - # define TENSORFLOW_USE_ROCM - entries.append(" unfiltered_cxx_flag: \"-DTENSORFLOW_USE_ROCM\"") + # define TENSORFLOW_USE_ROCM + entries.append(" unfiltered_cxx_flag: \"-DTENSORFLOW_USE_ROCM\"") - return "\n".join(entries) + return "\n".join(entries) def _rocm_include_path(repository_ctx, rocm_config): - """Generates the cxx_builtin_include_directory entries for rocm inc dirs. + """Generates the cxx_builtin_include_directory entries for rocm inc dirs. + + Args: + repository_ctx: The repository context. + cc: The path to the gcc host compiler. - Args: - repository_ctx: The repository context. - cc: The path to the gcc host compiler. + Returns: + A string containing the cxx_builtin_include_directory for each of the gcc + host compiler include directories, which can be added to the CROSSTOOL + file. + """ + inc_dirs = [] - Returns: - A string containing the cxx_builtin_include_directory for each of the gcc - host compiler include directories, which can be added to the CROSSTOOL - file. - """ - inc_dirs = [] + # general ROCm include path + inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") - # general ROCm include path - inc_dirs.append(rocm_config.rocm_toolkit_path + '/include') + # Add HSA headers + inc_dirs.append("/opt/rocm/hsa/include") - # Add HSA headers - inc_dirs.append("/opt/rocm/hsa/include") + # Add HIP headers + inc_dirs.append("/opt/rocm/include/hip") + inc_dirs.append("/opt/rocm/include/hip/hcc_detail") - # Add HIP headers - inc_dirs.append("/opt/rocm/include/hip") - inc_dirs.append("/opt/rocm/include/hip/hcc_detail") + # Add rocrand and hiprand headers + inc_dirs.append("/opt/rocm/rocrand/include") + inc_dirs.append("/opt/rocm/hiprand/include") - # Add rocrand and hiprand headers - inc_dirs.append("/opt/rocm/rocrand/include") - inc_dirs.append("/opt/rocm/hiprand/include") + # Add rocfft headers + inc_dirs.append("/opt/rocm/rocfft/include") - # Add rocfft headers - inc_dirs.append("/opt/rocm/rocfft/include") + # Add rocBLAS headers + inc_dirs.append("/opt/rocm/rocblas/include") - # Add rocBLAS headers - inc_dirs.append("/opt/rocm/rocblas/include") + # Add MIOpen headers + inc_dirs.append("/opt/rocm/miopen/include") - # Add MIOpen headers - inc_dirs.append("/opt/rocm/miopen/include") + # Add hcc headers + inc_dirs.append("/opt/rocm/hcc/include") + inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/7.0.0/include/") + inc_dirs.append("/opt/rocm/hcc/lib/clang/7.0.0/include") - # Add hcc headers - inc_dirs.append("/opt/rocm/hcc/include") - inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/7.0.0/include/") - inc_dirs.append("/opt/rocm/hcc/lib/clang/7.0.0/include") - # Newer hcc builds use/are based off of clang 8.0.0. - inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/8.0.0/include/") - inc_dirs.append("/opt/rocm/hcc/lib/clang/8.0.0/include") + # Newer hcc builds use/are based off of clang 8.0.0. + inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/8.0.0/include/") + inc_dirs.append("/opt/rocm/hcc/lib/clang/8.0.0/include") - inc_entries = [] - for inc_dir in inc_dirs: - inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir) - return "\n".join(inc_entries) + inc_entries = [] + for inc_dir in inc_dirs: + inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir) + return "\n".join(inc_entries) def _enable_rocm(repository_ctx): - if "TF_NEED_ROCM" in repository_ctx.os.environ: - enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip() - return enable_rocm == "1" - return False + if "TF_NEED_ROCM" in repository_ctx.os.environ: + enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip() + return enable_rocm == "1" + return False def _rocm_toolkit_path(repository_ctx): - """Finds the rocm toolkit directory. + """Finds the rocm toolkit directory. - Args: - repository_ctx: The repository context. + Args: + repository_ctx: The repository context. - Returns: - A speculative real path of the rocm toolkit install directory. - """ - rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH - if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ: - rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip() - if not repository_ctx.path(rocm_toolkit_path).exists: - auto_configure_fail("Cannot find rocm toolkit path.") - return str(repository_ctx.path(rocm_toolkit_path).realpath) + Returns: + A speculative real path of the rocm toolkit install directory. + """ + rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH + if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ: + rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip() + if not repository_ctx.path(rocm_toolkit_path).exists: + auto_configure_fail("Cannot find rocm toolkit path.") + return str(repository_ctx.path(rocm_toolkit_path).realpath) def _amdgpu_targets(repository_ctx): - """Returns a list of strings representing AMDGPU targets.""" - if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ: - return _DEFAULT_ROCM_AMDGPU_TARGETS - amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS] - amdgpu_targets = amdgpu_targets_str.split(",") - for amdgpu_target in amdgpu_targets: - if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit(): - auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target) - return amdgpu_targets + """Returns a list of strings representing AMDGPU targets.""" + if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ: + return _DEFAULT_ROCM_AMDGPU_TARGETS + amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS] + amdgpu_targets = amdgpu_targets_str.split(",") + for amdgpu_target in amdgpu_targets: + if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit(): + auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target) + return amdgpu_targets def _cpu_value(repository_ctx): - """Returns the name of the host operating system. - - Args: - repository_ctx: The repository context. - - Returns: - A string containing the name of the host operating system. - """ - os_name = repository_ctx.os.name.lower() - if os_name.startswith("mac os"): - return "Darwin" - if os_name.find("windows") != -1: - return "Windows" - result = repository_ctx.execute(["uname", "-s"]) - return result.stdout.strip() - -def _lib_name(lib, cpu_value, version="", static=False): - """Constructs the platform-specific name of a library. - - Args: - lib: The name of the library, such as "hip" - cpu_value: The name of the host operating system. - version: The version of the library. - static: True the library is static or False if it is a shared object. - - Returns: - The platform-specific name of the library. - """ - if cpu_value in ("Linux"): - if static: - return "lib%s.a" % lib + """Returns the name of the host operating system. + + Args: + repository_ctx: The repository context. + + Returns: + A string containing the name of the host operating system. + """ + os_name = repository_ctx.os.name.lower() + if os_name.startswith("mac os"): + return "Darwin" + if os_name.find("windows") != -1: + return "Windows" + result = repository_ctx.execute(["uname", "-s"]) + return result.stdout.strip() + +def _lib_name(lib, cpu_value, version = "", static = False): + """Constructs the platform-specific name of a library. + + Args: + lib: The name of the library, such as "hip" + cpu_value: The name of the host operating system. + version: The version of the library. + static: True the library is static or False if it is a shared object. + + Returns: + The platform-specific name of the library. + """ + if cpu_value in ("Linux"): + if static: + return "lib%s.a" % lib + else: + if version: + version = ".%s" % version + return "lib%s.so%s" % (lib, version) + elif cpu_value == "Windows": + return "%s.lib" % lib + elif cpu_value == "Darwin": + if static: + return "lib%s.a" % lib + elif version: + version = ".%s" % version + return "lib%s%s.dylib" % (lib, version) else: - if version: - version = ".%s" % version - return "lib%s.so%s" % (lib, version) - elif cpu_value == "Windows": - return "%s.lib" % lib - elif cpu_value == "Darwin": - if static: - return "lib%s.a" % lib - elif version: - version = ".%s" % version - return "lib%s%s.dylib" % (lib, version) - else: - auto_configure_fail("Invalid cpu_value: %s" % cpu_value) - -def _find_rocm_lib(lib, repository_ctx, cpu_value, basedir, version="", - static=False): - """Finds the given ROCm libraries on the system. - - Args: - lib: The name of the library, such as "hip" - repository_ctx: The repository context. - cpu_value: The name of the host operating system. - basedir: The install directory of ROCm. - version: The version of the library. - static: True if static library, False if shared object. - - Returns: - Returns a struct with the following fields: - file_name: The basename of the library found on the system. - path: The full path to the library. - """ - file_name = _lib_name(lib, cpu_value, version, static) - if cpu_value == "Linux": - path = repository_ctx.path("%s/lib64/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name)) + auto_configure_fail("Invalid cpu_value: %s" % cpu_value) + +def _find_rocm_lib( + lib, + repository_ctx, + cpu_value, + basedir, + version = "", + static = False): + """Finds the given ROCm libraries on the system. + + Args: + lib: The name of the library, such as "hip" + repository_ctx: The repository context. + cpu_value: The name of the host operating system. + basedir: The install directory of ROCm. + version: The version of the library. + static: True if static library, False if shared object. + + Returns: + Returns a struct with the following fields: + file_name: The basename of the library found on the system. + path: The full path to the library. + """ + file_name = _lib_name(lib, cpu_value, version, static) + if cpu_value == "Linux": + path = repository_ctx.path("%s/lib64/%s" % (basedir, file_name)) + if path.exists: + return struct(file_name = file_name, path = str(path.realpath)) + path = repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name)) + if path.exists: + return struct(file_name = file_name, path = str(path.realpath)) + path = repository_ctx.path( + "%s/lib/x86_64-linux-gnu/%s" % (basedir, file_name), + ) + if path.exists: + return struct(file_name = file_name, path = str(path.realpath)) + + path = repository_ctx.path("%s/lib/%s" % (basedir, file_name)) if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path( - "%s/lib/x86_64-linux-gnu/%s" % (basedir, file_name)) + return struct(file_name = file_name, path = str(path.realpath)) + path = repository_ctx.path("%s/%s" % (basedir, file_name)) if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) + return struct(file_name = file_name, path = str(path.realpath)) - path = repository_ctx.path("%s/lib/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - path = repository_ctx.path("%s/%s" % (basedir, file_name)) - if path.exists: - return struct(file_name=file_name, path=str(path.realpath)) - - auto_configure_fail("Cannot find rocm library %s" % file_name) + auto_configure_fail("Cannot find rocm library %s" % file_name) def _find_libs(repository_ctx, rocm_config): - """Returns the ROCm libraries on the system. - - Args: - repository_ctx: The repository context. - rocm_config: The ROCm config as returned by _get_rocm_config - - Returns: - Map of library names to structs of filename and path as returned by - _find_rocm_lib. - """ - cpu_value = rocm_config.cpu_value - return { - "hip": _find_rocm_lib( - "hip_hcc", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path), - "rocblas": _find_rocm_lib( - "rocblas", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path + "/rocblas"), - "rocfft": _find_rocm_lib( - "rocfft", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path + "/rocfft"), - "hiprand": _find_rocm_lib( - "hiprand", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path + "/hiprand"), - "miopen": _find_rocm_lib( - "MIOpen", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path + "/miopen"), - } + """Returns the ROCm libraries on the system. + + Args: + repository_ctx: The repository context. + rocm_config: The ROCm config as returned by _get_rocm_config + + Returns: + Map of library names to structs of filename and path as returned by + _find_rocm_lib. + """ + cpu_value = rocm_config.cpu_value + return { + "hip": _find_rocm_lib( + "hip_hcc", + repository_ctx, + cpu_value, + rocm_config.rocm_toolkit_path, + ), + "rocblas": _find_rocm_lib( + "rocblas", + repository_ctx, + cpu_value, + rocm_config.rocm_toolkit_path + "/rocblas", + ), + "rocfft": _find_rocm_lib( + "rocfft", + repository_ctx, + cpu_value, + rocm_config.rocm_toolkit_path + "/rocfft", + ), + "hiprand": _find_rocm_lib( + "hiprand", + repository_ctx, + cpu_value, + rocm_config.rocm_toolkit_path + "/hiprand", + ), + "miopen": _find_rocm_lib( + "MIOpen", + repository_ctx, + cpu_value, + rocm_config.rocm_toolkit_path + "/miopen", + ), + } def _get_rocm_config(repository_ctx): - """Detects and returns information about the ROCm installation on the system. - - Args: - repository_ctx: The repository context. - - Returns: - A struct containing the following fields: - rocm_toolkit_path: The ROCm toolkit installation directory. - amdgpu_targets: A list of the system's AMDGPU targets. - cpu_value: The name of the host operating system. - """ - cpu_value = _cpu_value(repository_ctx) - rocm_toolkit_path = _rocm_toolkit_path(repository_ctx) - return struct( - rocm_toolkit_path = rocm_toolkit_path, - amdgpu_targets = _amdgpu_targets(repository_ctx), - cpu_value = cpu_value) - -def _tpl(repository_ctx, tpl, substitutions={}, out=None): - if not out: - out = tpl.replace(":", "/") - repository_ctx.template( - out, - Label("//third_party/gpus/%s.tpl" % tpl), - substitutions) - + """Detects and returns information about the ROCm installation on the system. + + Args: + repository_ctx: The repository context. + + Returns: + A struct containing the following fields: + rocm_toolkit_path: The ROCm toolkit installation directory. + amdgpu_targets: A list of the system's AMDGPU targets. + cpu_value: The name of the host operating system. + """ + cpu_value = _cpu_value(repository_ctx) + rocm_toolkit_path = _rocm_toolkit_path(repository_ctx) + return struct( + rocm_toolkit_path = rocm_toolkit_path, + amdgpu_targets = _amdgpu_targets(repository_ctx), + cpu_value = cpu_value, + ) + +def _tpl(repository_ctx, tpl, substitutions = {}, out = None): + if not out: + out = tpl.replace(":", "/") + repository_ctx.template( + out, + Label("//third_party/gpus/%s.tpl" % tpl), + substitutions, + ) def _file(repository_ctx, label): - repository_ctx.template( - label.replace(":", "/"), - Label("//third_party/gpus/%s.tpl" % label), - {}) - + repository_ctx.template( + label.replace(":", "/"), + Label("//third_party/gpus/%s.tpl" % label), + {}, + ) _DUMMY_CROSSTOOL_BZL_FILE = """ def error_gpu_disabled(): @@ -375,7 +418,6 @@ def error_gpu_disabled(): ) """ - _DUMMY_CROSSTOOL_BUILD_FILE = """ load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled") @@ -383,259 +425,338 @@ error_gpu_disabled() """ def _create_dummy_repository(repository_ctx): - cpu_value = _cpu_value(repository_ctx) - - # Set up BUILD file for rocm/. - _tpl(repository_ctx, "rocm:build_defs.bzl", - { - "%{rocm_is_configured}": "False", - "%{rocm_extra_copts}": "[]" - }) - _tpl(repository_ctx, "rocm:BUILD", - { - "%{hip_lib}": _lib_name("hip", cpu_value), - "%{rocblas_lib}": _lib_name("rocblas", cpu_value), - "%{miopen_lib}": _lib_name("miopen", cpu_value), - "%{rocfft_lib}": _lib_name("rocfft", cpu_value), - "%{hiprand_lib}": _lib_name("hiprand", cpu_value), - "%{rocm_include_genrules}": '', - "%{rocm_headers}": '', - }) - - # Create dummy files for the ROCm toolkit since they are still required by - # tensorflow/core/platform/default/build_config:rocm. - repository_ctx.file("rocm/hip/include/hip/hip_runtime.h", "") - - # Set up rocm_config.h, which is used by - # tensorflow/stream_executor/dso_loader.cc. - _tpl(repository_ctx, "rocm:rocm_config.h", - { - "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, - }, "rocm/rocm/rocm_config.h") - - # If rocm_configure is not configured to build with GPU support, and the user - # attempts to build with --config=rocm, add a dummy build rule to intercept - # this and fail with an actionable error message. - repository_ctx.file("crosstool/error_gpu_disabled.bzl", - _DUMMY_CROSSTOOL_BZL_FILE) - repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) - -def _execute(repository_ctx, cmdline, error_msg=None, error_details=None, - empty_stdout_fine=False): - """Executes an arbitrary shell command. - - Args: - repository_ctx: the repository_ctx object - cmdline: list of strings, the command to execute - error_msg: string, a summary of the error if the command fails - error_details: string, details about the error or steps to fix it - empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise - it's an error - Return: - the result of repository_ctx.execute(cmdline) - """ - result = repository_ctx.execute(cmdline) - if result.stderr or not (empty_stdout_fine or result.stdout): - auto_configure_fail( - "\n".join([ - error_msg.strip() if error_msg else "Repository command failed", - result.stderr.strip(), - error_details if error_details else ""])) - return result + cpu_value = _cpu_value(repository_ctx) + + # Set up BUILD file for rocm/. + _tpl( + repository_ctx, + "rocm:build_defs.bzl", + { + "%{rocm_is_configured}": "False", + "%{rocm_extra_copts}": "[]", + }, + ) + _tpl( + repository_ctx, + "rocm:BUILD", + { + "%{hip_lib}": _lib_name("hip", cpu_value), + "%{rocblas_lib}": _lib_name("rocblas", cpu_value), + "%{miopen_lib}": _lib_name("miopen", cpu_value), + "%{rocfft_lib}": _lib_name("rocfft", cpu_value), + "%{hiprand_lib}": _lib_name("hiprand", cpu_value), + "%{rocm_include_genrules}": "", + "%{rocm_headers}": "", + }, + ) + + # Create dummy files for the ROCm toolkit since they are still required by + # tensorflow/core/platform/default/build_config:rocm. + repository_ctx.file("rocm/hip/include/hip/hip_runtime.h", "") + + # Set up rocm_config.h, which is used by + # tensorflow/stream_executor/dso_loader.cc. + _tpl( + repository_ctx, + "rocm:rocm_config.h", + { + "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, + }, + "rocm/rocm/rocm_config.h", + ) + + # If rocm_configure is not configured to build with GPU support, and the user + # attempts to build with --config=rocm, add a dummy build rule to intercept + # this and fail with an actionable error message. + repository_ctx.file( + "crosstool/error_gpu_disabled.bzl", + _DUMMY_CROSSTOOL_BZL_FILE, + ) + repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) + +def _execute( + repository_ctx, + cmdline, + error_msg = None, + error_details = None, + empty_stdout_fine = False): + """Executes an arbitrary shell command. + + Args: + repository_ctx: the repository_ctx object + cmdline: list of strings, the command to execute + error_msg: string, a summary of the error if the command fails + error_details: string, details about the error or steps to fix it + empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise + it's an error + Return: + the result of repository_ctx.execute(cmdline) + """ + result = repository_ctx.execute(cmdline) + if result.stderr or not (empty_stdout_fine or result.stdout): + auto_configure_fail( + "\n".join([ + error_msg.strip() if error_msg else "Repository command failed", + result.stderr.strip(), + error_details if error_details else "", + ]), + ) + return result def _norm_path(path): - """Returns a path with '/' and remove the trailing slash.""" - path = path.replace("\\", "/") - if path[-1] == "/": - path = path[:-1] - return path - -def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, - src_files = [], dest_files = []): - """Returns a genrule to symlink(or copy if on Windows) a set of files. - - If src_dir is passed, files will be read from the given directory; otherwise - we assume files are in src_files and dest_files - """ - if src_dir != None: - src_dir = _norm_path(src_dir) - dest_dir = _norm_path(dest_dir) - files = _read_dir(repository_ctx, src_dir) - # Create a list with the src_dir stripped to use for outputs. - dest_files = files.replace(src_dir, '').splitlines() - src_files = files.splitlines() - command = [] - # We clear folders that might have been generated previously to avoid - # undesired inclusions - command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi') - command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi') - outs = [] - for i in range(len(dest_files)): - if dest_files[i] != "": - # If we have only one file to link we do not want to use the dest_dir, as - # $(@D) will include the full path to the file. - dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i] - # On Windows, symlink is not supported, so we just copy all the files. - cmd = 'ln -s' - command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest)) - outs.append(' "' + dest_dir + dest_files[i] + '",') - genrule = _genrule(src_dir, genrule_name, " && ".join(command), - "\n".join(outs)) - return genrule + """Returns a path with '/' and remove the trailing slash.""" + path = path.replace("\\", "/") + if path[-1] == "/": + path = path[:-1] + return path + +def _symlink_genrule_for_dir( + repository_ctx, + src_dir, + dest_dir, + genrule_name, + src_files = [], + dest_files = []): + """Returns a genrule to symlink(or copy if on Windows) a set of files. + + If src_dir is passed, files will be read from the given directory; otherwise + we assume files are in src_files and dest_files + """ + if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) + files = _read_dir(repository_ctx, src_dir) + + # Create a list with the src_dir stripped to use for outputs. + dest_files = files.replace(src_dir, "").splitlines() + src_files = files.splitlines() + command = [] + + # We clear folders that might have been generated previously to avoid + # undesired inclusions + command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi') + command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi') + outs = [] + for i in range(len(dest_files)): + if dest_files[i] != "": + # If we have only one file to link we do not want to use the dest_dir, as + # $(@D) will include the full path to the file. + dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i] + + # On Windows, symlink is not supported, so we just copy all the files. + cmd = "ln -s" + command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) + outs.append(' "' + dest_dir + dest_files[i] + '",') + genrule = _genrule( + src_dir, + genrule_name, + " && ".join(command), + "\n".join(outs), + ) + return genrule def _genrule(src_dir, genrule_name, command, outs): - """Returns a string with a genrule. - - Genrule executes the given command and produces the given outputs. - """ - return ( - 'genrule(\n' + - ' name = "' + - genrule_name + '",\n' + - ' outs = [\n' + - outs + - '\n ],\n' + - ' cmd = """\n' + - command + - '\n """,\n' + - ')\n' - ) + """Returns a string with a genrule. + + Genrule executes the given command and produces the given outputs. + """ + return ( + "genrule(\n" + + ' name = "' + + genrule_name + '",\n' + + " outs = [\n" + + outs + + "\n ],\n" + + ' cmd = """\n' + + command + + '\n """,\n' + + ")\n" + ) def _read_dir(repository_ctx, src_dir): - """Returns a string with all files in a directory. - - Finds all files inside a directory, traversing subfolders and following - symlinks. The returned string contains the full path of all files - separated by line breaks. - """ - find_result = _execute( - repository_ctx, ["find", src_dir, "-follow", "-type", "f"], - empty_stdout_fine=True) - result = find_result.stdout - return result + """Returns a string with all files in a directory. + + Finds all files inside a directory, traversing subfolders and following + symlinks. The returned string contains the full path of all files + separated by line breaks. + """ + find_result = _execute( + repository_ctx, + ["find", src_dir, "-follow", "-type", "f"], + empty_stdout_fine = True, + ) + result = find_result.stdout + return result def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): - if False: - amdgpu_target_flags = ["--amdgpu-target=" + - amdgpu_target for amdgpu_target in amdgpu_targets] - else: - # AMDGPU targets are handled in the "crosstool_wrapper_driver_is_not_gcc" - amdgpu_target_flags = [] - return str(amdgpu_target_flags) + if False: + amdgpu_target_flags = ["--amdgpu-target=" + + amdgpu_target for amdgpu_target in amdgpu_targets] + else: + # AMDGPU targets are handled in the "crosstool_wrapper_driver_is_not_gcc" + amdgpu_target_flags = [] + return str(amdgpu_target_flags) def _create_local_rocm_repository(repository_ctx): - """Creates the repository containing files set up to build with ROCm.""" - rocm_config = _get_rocm_config(repository_ctx) - - # Set up symbolic links for the rocm toolkit by creating genrules to do - # symlinking. We create one genrule for each directory we want to track under - # rocm_toolkit_path - rocm_toolkit_path = rocm_config.rocm_toolkit_path - rocm_include_path = rocm_toolkit_path + "/include" - genrules = [_symlink_genrule_for_dir(repository_ctx, - rocm_include_path, "rocm/include", "rocm-include")] - genrules.append(_symlink_genrule_for_dir(repository_ctx, - rocm_toolkit_path + "/rocfft/include", "rocm/include/rocfft", "rocfft-include")) - genrules.append(_symlink_genrule_for_dir(repository_ctx, - rocm_toolkit_path + "/rocblas/include", "rocm/include/rocblas", "rocblas-include")) - genrules.append(_symlink_genrule_for_dir(repository_ctx, - rocm_toolkit_path + "/miopen/include", "rocm/include/miopen", "miopen-include")) - - rocm_libs = _find_libs(repository_ctx, rocm_config) - rocm_lib_src = [] - rocm_lib_dest = [] - for lib in rocm_libs.values(): - rocm_lib_src.append(lib.path) - rocm_lib_dest.append("rocm/lib/" + lib.file_name) - genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "rocm-lib", - rocm_lib_src, rocm_lib_dest)) - - included_files = _read_dir(repository_ctx, rocm_include_path).replace( - rocm_include_path, '').splitlines() - - # Set up BUILD file for rocm/ - _tpl(repository_ctx, "rocm:build_defs.bzl", - { - "%{rocm_is_configured}": "True", - "%{rocm_extra_copts}": _compute_rocm_extra_copts( - repository_ctx, rocm_config.amdgpu_targets), - - }) - _tpl(repository_ctx, "rocm:BUILD", - { - "%{hip_lib}": rocm_libs["hip"].file_name, - "%{rocblas_lib}": rocm_libs["rocblas"].file_name, - "%{rocfft_lib}": rocm_libs["rocfft"].file_name, - "%{hiprand_lib}": rocm_libs["hiprand"].file_name, - "%{miopen_lib}": rocm_libs["miopen"].file_name, - "%{rocm_include_genrules}": "\n".join(genrules), - "%{rocm_headers}": ('":rocm-include",\n' + - '":rocfft-include",\n' + - '":rocblas-include",\n' + - '":miopen-include",'), - }) - # Set up crosstool/ - _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty", "%{win_linker_files}": ":empty"}) - cc = find_cc(repository_ctx) - host_compiler_includes = _host_compiler_includes(repository_ctx, cc) - rocm_defines = { - "%{rocm_include_path}": _rocm_include_path(repository_ctx, - rocm_config), - "%{host_compiler_includes}": host_compiler_includes, - "%{clang_path}": str(cc), - } - - _tpl(repository_ctx, "crosstool:CROSSTOOL_hipcc", rocm_defines, out="crosstool/CROSSTOOL") - - _tpl(repository_ctx, - "crosstool:clang/bin/crosstool_wrapper_driver_rocm", - { - "%{cpu_compiler}": str(cc), - "%{hipcc_path}": "/opt/rocm/bin/hipcc", - "%{gcc_host_compiler_path}": str(cc), - "%{rocm_amdgpu_targets}": ",".join( - ["\"%s\"" % c for c in rocm_config.amdgpu_targets]), - }) - - # Set up rocm_config.h, which is used by - # tensorflow/stream_executor/dso_loader.cc. - _tpl(repository_ctx, "rocm:rocm_config.h", - { - "%{rocm_amdgpu_targets}": ",".join( - ["\"%s\"" % c for c in rocm_config.amdgpu_targets]), - "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path, - }, "rocm/rocm/rocm_config.h") - + """Creates the repository containing files set up to build with ROCm.""" + rocm_config = _get_rocm_config(repository_ctx) + + # Set up symbolic links for the rocm toolkit by creating genrules to do + # symlinking. We create one genrule for each directory we want to track under + # rocm_toolkit_path + rocm_toolkit_path = rocm_config.rocm_toolkit_path + rocm_include_path = rocm_toolkit_path + "/include" + genrules = [_symlink_genrule_for_dir( + repository_ctx, + rocm_include_path, + "rocm/include", + "rocm-include", + )] + genrules.append(_symlink_genrule_for_dir( + repository_ctx, + rocm_toolkit_path + "/rocfft/include", + "rocm/include/rocfft", + "rocfft-include", + )) + genrules.append(_symlink_genrule_for_dir( + repository_ctx, + rocm_toolkit_path + "/rocblas/include", + "rocm/include/rocblas", + "rocblas-include", + )) + genrules.append(_symlink_genrule_for_dir( + repository_ctx, + rocm_toolkit_path + "/miopen/include", + "rocm/include/miopen", + "miopen-include", + )) + + rocm_libs = _find_libs(repository_ctx, rocm_config) + rocm_lib_src = [] + rocm_lib_dest = [] + for lib in rocm_libs.values(): + rocm_lib_src.append(lib.path) + rocm_lib_dest.append("rocm/lib/" + lib.file_name) + genrules.append(_symlink_genrule_for_dir( + repository_ctx, + None, + "", + "rocm-lib", + rocm_lib_src, + rocm_lib_dest, + )) + + included_files = _read_dir(repository_ctx, rocm_include_path).replace( + rocm_include_path, + "", + ).splitlines() + + # Set up BUILD file for rocm/ + _tpl( + repository_ctx, + "rocm:build_defs.bzl", + { + "%{rocm_is_configured}": "True", + "%{rocm_extra_copts}": _compute_rocm_extra_copts( + repository_ctx, + rocm_config.amdgpu_targets, + ), + }, + ) + _tpl( + repository_ctx, + "rocm:BUILD", + { + "%{hip_lib}": rocm_libs["hip"].file_name, + "%{rocblas_lib}": rocm_libs["rocblas"].file_name, + "%{rocfft_lib}": rocm_libs["rocfft"].file_name, + "%{hiprand_lib}": rocm_libs["hiprand"].file_name, + "%{miopen_lib}": rocm_libs["miopen"].file_name, + "%{rocm_include_genrules}": "\n".join(genrules), + "%{rocm_headers}": ('":rocm-include",\n' + + '":rocfft-include",\n' + + '":rocblas-include",\n' + + '":miopen-include",'), + }, + ) + + # Set up crosstool/ + _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty", "%{win_linker_files}": ":empty"}) + cc = find_cc(repository_ctx) + host_compiler_includes = _host_compiler_includes(repository_ctx, cc) + rocm_defines = { + "%{rocm_include_path}": _rocm_include_path( + repository_ctx, + rocm_config, + ), + "%{host_compiler_includes}": host_compiler_includes, + "%{clang_path}": str(cc), + } + + _tpl(repository_ctx, "crosstool:CROSSTOOL_hipcc", rocm_defines, out = "crosstool/CROSSTOOL") + + _tpl( + repository_ctx, + "crosstool:clang/bin/crosstool_wrapper_driver_rocm", + { + "%{cpu_compiler}": str(cc), + "%{hipcc_path}": "/opt/rocm/bin/hipcc", + "%{gcc_host_compiler_path}": str(cc), + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), + }, + ) + + # Set up rocm_config.h, which is used by + # tensorflow/stream_executor/dso_loader.cc. + _tpl( + repository_ctx, + "rocm:rocm_config.h", + { + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), + "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path, + }, + "rocm/rocm/rocm_config.h", + ) def _create_remote_rocm_repository(repository_ctx, remote_config_repo): - """Creates pointers to a remotely configured repo set up to build with ROCm.""" - _tpl(repository_ctx, "rocm:build_defs.bzl", - { - "%{rocm_is_configured}": "True", - "%{rocm_extra_copts}": _compute_rocm_extra_copts( - repository_ctx, #_compute_capabilities(repository_ctx) + """Creates pointers to a remotely configured repo set up to build with ROCm.""" + _tpl( + repository_ctx, + "rocm:build_defs.bzl", + { + "%{rocm_is_configured}": "True", + "%{rocm_extra_copts}": _compute_rocm_extra_copts( + repository_ctx, #_compute_capabilities(repository_ctx) ), - - }) - _tpl(repository_ctx, "rocm:remote.BUILD", - { - "%{remote_rocm_repo}": remote_config_repo, - }, "rocm/BUILD") - _tpl(repository_ctx, "crosstool:remote.BUILD", { - "%{remote_rocm_repo}": remote_config_repo, - }, "crosstool/BUILD") + }, + ) + _tpl( + repository_ctx, + "rocm:remote.BUILD", + { + "%{remote_rocm_repo}": remote_config_repo, + }, + "rocm/BUILD", + ) + _tpl(repository_ctx, "crosstool:remote.BUILD", { + "%{remote_rocm_repo}": remote_config_repo, + }, "crosstool/BUILD") def _rocm_autoconf_impl(repository_ctx): - """Implementation of the rocm_autoconf repository rule.""" - if not _enable_rocm(repository_ctx): - _create_dummy_repository(repository_ctx) - else: - if _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ: - _create_remote_rocm_repository(repository_ctx, - repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO]) + """Implementation of the rocm_autoconf repository rule.""" + if not _enable_rocm(repository_ctx): + _create_dummy_repository(repository_ctx) + elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ: + _create_remote_rocm_repository( + repository_ctx, + repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO], + ) else: - _create_local_rocm_repository(repository_ctx) - + _create_local_rocm_repository(repository_ctx) rocm_configure = repository_rule( implementation = _rocm_autoconf_impl, |