aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party
diff options
context:
space:
mode:
authorGravatar Jason Furmanek <furmanek@us.ibm.com>2018-09-26 04:44:12 +0000
committerGravatar Jason Furmanek <furmanek@us.ibm.com>2018-09-26 04:44:12 +0000
commit7c2341501a583ca625c976f118090e495cdcbe07 (patch)
tree3c6cba5366c4f0f119df312d7e4d26f3f4119b4e /third_party
parent6666516f390f125ed70ddbd4e6f89b83d953c408 (diff)
Find NCCL2 debians in Tensorflow configure
Diffstat (limited to 'third_party')
-rw-r--r--third_party/nccl/nccl_configure.bzl14
-rw-r--r--third_party/nccl/system.BUILD.tpl4
2 files changed, 13 insertions, 5 deletions
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
index ce9447096e..0713b36724 100644
--- a/third_party/nccl/nccl_configure.bzl
+++ b/third_party/nccl/nccl_configure.bzl
@@ -5,6 +5,7 @@
* `TF_NCCL_VERSION`: The NCCL version.
* `NCCL_INSTALL_PATH`: The installation path of the NCCL library.
+ * `NCCL_HDR_PATH`: The installation path of the NCCL header files.
"""
load(
@@ -15,6 +16,7 @@ load(
)
_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
+_NCCL_HDR_PATH = "NCCL_HDR_PATH"
_TF_NCCL_VERSION = "TF_NCCL_VERSION"
_TF_NCCL_CONFIG_REPO = "TF_NCCL_CONFIG_REPO"
@@ -68,7 +70,7 @@ def _find_nccl_header(repository_ctx, nccl_install_path):
return header_path
-def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version):
+def _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version):
"""Checks whether the header file matches the specified version of NCCL.
Args:
@@ -79,7 +81,9 @@ def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version):
Returns:
A string containing the library version of NCCL.
"""
- header_path = _find_nccl_header(repository_ctx, nccl_install_path)
+ header_path = repository_ctx.path("%s/nccl.h" % nccl_hdr_path)
+ if not header_path.exists:
+ header_path = _find_nccl_header(repository_ctx, nccl_install_path)
header_dir = str(header_path.realpath.dirname)
major_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
_DEFINE_NCCL_MAJOR)
@@ -109,6 +113,7 @@ def _find_nccl_lib(repository_ctx, nccl_install_path, nccl_version):
"""
lib_path = repository_ctx.path("%s/lib/libnccl.so.%s" % (nccl_install_path,
nccl_version))
+
if not lib_path.exists:
auto_configure_fail("Cannot find NCCL library %s" % str(lib_path))
return lib_path
@@ -138,10 +143,12 @@ def _nccl_configure_impl(repository_ctx):
else:
# Create target for locally installed NCCL.
nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
- _check_nccl_version(repository_ctx, nccl_install_path, nccl_version)
+ nccl_hdr_path = repository_ctx.os.environ[_NCCL_HDR_PATH].strip()
+ _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version)
repository_ctx.template("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE, {
"%{version}": nccl_version,
"%{install_path}": nccl_install_path,
+ "%{hdr_path}": nccl_hdr_path,
})
@@ -149,6 +156,7 @@ nccl_configure = repository_rule(
implementation=_nccl_configure_impl,
environ=[
_NCCL_INSTALL_PATH,
+ _NCCL_HDR_PATH,
_TF_NCCL_VERSION,
],
)
diff --git a/third_party/nccl/system.BUILD.tpl b/third_party/nccl/system.BUILD.tpl
index 7ca835dedf..a07f54955f 100644
--- a/third_party/nccl/system.BUILD.tpl
+++ b/third_party/nccl/system.BUILD.tpl
@@ -20,7 +20,7 @@ genrule(
"libnccl.so.%{version}",
"nccl.h",
],
- cmd = """cp "%{install_path}/include/nccl.h" "$(@D)/nccl.h" &&
- cp "%{install_path}/lib/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
+ cmd = """cp "%{hdr_path}/nccl.h" "$(@D)/nccl.h" &&
+ cp "%{install_path}/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
)