From 2ebbffe87059fdc8ed66aa59cfa810af87029abb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 15 Jun 2017 17:19:13 -0700 Subject: Enable setting remote configuration for cuda and python as an env variable PiperOrigin-RevId: 159176334 --- third_party/gpus/cuda_configure.bzl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) (limited to 'third_party/gpus') diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 77dc602fd9..83a377dde5 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -26,6 +26,7 @@ _TF_CUDA_VERSION = "TF_CUDA_VERSION" _TF_CUDNN_VERSION = "TF_CUDNN_VERSION" _CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH" _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _DEFAULT_CUDA_VERSION = "" _DEFAULT_CUDNN_VERSION = "" @@ -1001,8 +1002,7 @@ def _create_local_cuda_repository(repository_ctx): "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path, }) - -def _create_remote_cuda_repository(repository_ctx): +def _create_remote_cuda_repository(repository_ctx, remote_config_repo): """Creates pointers to a remotely configured repo set up to build with CUDA.""" _tpl(repository_ctx, "cuda:build_defs.bzl", { @@ -1013,10 +1013,10 @@ def _create_remote_cuda_repository(repository_ctx): }) _tpl(repository_ctx, "cuda:remote.BUILD", { - "%{remote_cuda_repo}": repository_ctx.attr.remote_config_repo, + "%{remote_cuda_repo}": remote_config_repo, }, "cuda/BUILD") _tpl(repository_ctx, "crosstool:remote.BUILD", { - "%{remote_cuda_repo}": repository_ctx.attr.remote_config_repo, + "%{remote_cuda_repo}": remote_config_repo, }, "crosstool/BUILD") def _cuda_autoconf_impl(repository_ctx): @@ -1024,8 +1024,12 @@ def _cuda_autoconf_impl(repository_ctx): if not _enable_cuda(repository_ctx): _create_dummy_repository(repository_ctx) else: - if repository_ctx.attr.remote_config_repo != "": - _create_remote_cuda_repository(repository_ctx) + if _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: + _create_remote_cuda_repository(repository_ctx, + repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO]) + elif repository_ctx.attr.remote_config_repo != "": + _create_remote_cuda_repository(repository_ctx, + repository_ctx.attr.remote_config_repo) else: _create_local_cuda_repository(repository_ctx) @@ -1043,6 +1047,7 @@ cuda_configure = repository_rule( _TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_COMPUTE_CAPABILITIES, + _TF_CUDA_CONFIG_REPO, ], ) -- cgit v1.2.3