diff options
Diffstat (limited to 'third_party/gpus/cuda_configure.bzl')
-rw-r--r-- | third_party/gpus/cuda_configure.bzl | 42 |
1 files changed, 35 insertions, 7 deletions
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 61932a8e6d..77dc602fd9 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -883,15 +883,16 @@ def _use_cuda_clang(repository_ctx): return enable_cuda == "1" return False -def _compute_cuda_extra_copts(repository_ctx, cuda_config): +def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): if _use_cuda_clang(repository_ctx): - capability_flags = ["--cuda-gpu-arch=sm_" + cap.replace(".", "") for cap in cuda_config.compute_capabilities] + capability_flags = ["--cuda-gpu-arch=sm_" + + cap.replace(".", "") for cap in compute_capabilities] else: # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc capability_flags = [] return str(capability_flags) -def _create_cuda_repository(repository_ctx): +def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" cuda_config = _get_cuda_config(repository_ctx) @@ -939,7 +940,8 @@ def _create_cuda_repository(repository_ctx): _tpl(repository_ctx, "cuda:build_defs.bzl", { "%{cuda_is_configured}": "True", - "%{cuda_extra_copts}": _compute_cuda_extra_copts(repository_ctx, cuda_config), + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, cuda_config.compute_capabilities), }) _tpl(repository_ctx, "cuda:BUILD", @@ -1000,17 +1002,39 @@ def _create_cuda_repository(repository_ctx): }) +def _create_remote_cuda_repository(repository_ctx): + """Creates pointers to a remotely configured repo set up to build with CUDA.""" + _tpl(repository_ctx, "cuda:build_defs.bzl", + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, _compute_capabilities(repository_ctx)), + + }) + _tpl(repository_ctx, "cuda:remote.BUILD", + { + "%{remote_cuda_repo}": repository_ctx.attr.remote_config_repo, + }, "cuda/BUILD") + _tpl(repository_ctx, "crosstool:remote.BUILD", { + "%{remote_cuda_repo}": repository_ctx.attr.remote_config_repo, + }, "crosstool/BUILD") + def _cuda_autoconf_impl(repository_ctx): """Implementation of the cuda_autoconf repository rule.""" if not _enable_cuda(repository_ctx): _create_dummy_repository(repository_ctx) else: - _create_cuda_repository(repository_ctx) - + if repository_ctx.attr.remote_config_repo != "": + _create_remote_cuda_repository(repository_ctx) + else: + _create_local_cuda_repository(repository_ctx) cuda_configure = repository_rule( implementation = _cuda_autoconf_impl, + attrs = { + "remote_config_repo": attr.string(mandatory = False, default =""), + }, environ = [ _GCC_HOST_COMPILER_PATH, "TF_NEED_CUDA", @@ -1027,9 +1051,13 @@ cuda_configure = repository_rule( Add the following to your WORKSPACE FILE: ```python -cuda_configure(name = "local_config_cuda") +cuda_configure( + name = "local_config_cuda" + remote_config_repo = "@remote_cuda_config_tf//" +) ``` Args: name: A unique name for this workspace rule. + remote_config_repo: Location of a pre-generated config (optional). """ |